Neural Network DSL


ExTorch provides an Elixir DSL for defining neural network architectures. Models can be created with random weights for experimentation, or loaded with pre-trained weights from TorchScript files.

Defining a module

defmodule MyClassifier do
  use ExTorch.NN.Module

  deflayer :conv1, ExTorch.NN.Conv2d, in_channels: 1, out_channels: 32, kernel_size: 3
  deflayer :relu1, ExTorch.NN.ReLU
  deflayer :pool, ExTorch.NN.MaxPool2d, kernel_size: 2
  deflayer :flatten, ExTorch.NN.Flatten
  deflayer :fc, ExTorch.NN.Linear, in_features: 32 * 13 * 13, out_features: 10

  def forward(model, x) do
    x
    |> layer(model, :conv1)
    |> layer(model, :relu1)
    |> layer(model, :pool)
    |> layer(model, :flatten)
    |> layer(model, :fc)
  end
end

deflayer declares a named layer with its type and options at compile time. layer/3 applies a layer during the forward pass.
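The in_features: 32 * 13 * 13 value above follows from the input size: a 28x28 image through an unpadded 3x3 convolution becomes 26x26, and a 2x2 max pool halves that to 13x13, across 32 output channels. A standalone arithmetic sketch in plain Elixir (no ExTorch required; conv_out and pool_out are illustrative helpers, not library functions):

```elixir
# Spatial size after an unpadded, stride-1 convolution: n - k + 1
conv_out = fn n, k -> n - k + 1 end
# Spatial size after max pooling, where stride defaults to kernel_size
pool_out = fn n, k -> div(n, k) end

side = 28 |> conv_out.(3) |> pool_out.(2)
# side == 13, so the flattened feature count is 32 * 13 * 13 = 5408
32 * side * side
```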

Creating and using a model

# Random weights
model = MyClassifier.new()
input = ExTorch.randn({1, 1, 28, 28})
output = MyClassifier.forward(model, input)
# => %ExTorch.Tensor{size: {1, 10}, ...}

Inspecting parameters

params = MyClassifier.parameters(model)
# => [
#   {"conv1.weight", #Tensor<[32, 1, 3, 3]>},
#   {"conv1.bias", #Tensor<[32]>},
#   {"fc.weight", #Tensor<[10, 5408]>},
#   {"fc.bias", #Tensor<[10]>}
# ]
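With the parameter list in hand, a common check is the total trainable-parameter count. This sketch assumes each tensor exposes its shape through the size field shown in the forward example above; the exact accessor is an assumption and may differ:

```elixir
# Number of elements in one {name, tensor} parameter entry
# (assumes tensor.size is a tuple of dimensions, as shown above)
count = fn {_name, tensor} ->
  tensor.size |> Tuple.to_list() |> Enum.product()
end

total =
  model
  |> MyClassifier.parameters()
  |> Enum.map(count)
  |> Enum.sum()
# 32*1*3*3 + 32 + 10*5408 + 10 = 54_410 for the layers listed above
```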

Loading pre-trained weights

There are two ways to use weights trained in Python:

Option A: from_jit -- Use the JIT model directly

model = MyClassifier.from_jit("trained_classifier.pt")
output = MyClassifier.predict(model, [input])

The JIT model's forward method handles all computation. The DSL definition is validated against the .pt file's submodules at load time -- if the architectures don't match, you get a clear error.

Option B: load_weights -- Copy weights into DSL layers

model = MyClassifier.load_weights("trained_classifier.pt")
output = MyClassifier.forward(model, input)

This creates the DSL layers, then copies matching parameter tensors from the .pt file. The result is a regular DSL model that runs through your Elixir forward/2 function.
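One concrete reason to prefer load_weights is a forward pass that diverges from the Python one. As a sketch, a hypothetical embed/2 (not part of ExTorch) defined alongside forward/2 in MyClassifier could stop before the classifier head, reusing only the documented layer/3 helper:

```elixir
# Hypothetical variant forward: return the flattened convolutional
# features instead of class logits by skipping the final :fc layer.
def embed(model, x) do
  x
  |> layer(model, :conv1)
  |> layer(model, :relu1)
  |> layer(model, :pool)
  |> layer(model, :flatten)
end
```

This kind of customization is only possible with load_weights, since from_jit runs the TorchScript forward as a black box.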

When to use which:

|          | from_jit                       | load_weights                                                     |
|----------|--------------------------------|------------------------------------------------------------------|
| Runs     | Python's forward logic         | Your Elixir forward logic                                        |
| Use when | You want exact Python behavior | Your forward differs (custom post-processing, different dropout) |
| Returns  | %JITBackedModel{}              | %{layer => %Layer{}}                                             |

Generating DSL from existing models

Don't have a DSL definition yet? ExTorch can introspect any .pt file and generate the Elixir source:

model = ExTorch.JIT.load("resnet18.pt")
source = ExTorch.NN.Introspect.to_elixir(model, "ResNet18")
IO.puts(source)

Output:

defmodule ResNet18 do
  use ExTorch.NN.Module

  deflayer :conv1, ExTorch.NN.Conv2d, in_channels: 3, out_channels: 64, kernel_size: 7
  deflayer :bn1, ExTorch.NN.BatchNorm2d
  deflayer :relu, ExTorch.NN.ReLU
  # ... (full architecture)

  def forward(model, x) do
    x
    |> layer(model, :conv1)
    |> layer(model, :bn1)
    |> layer(model, :relu)
    # ...
  end
end

You can paste this into your project and customize it.

Available layers

Convolutions

Pooling

Normalization

Recurrent

Attention

Activations

ReLU, LeakyReLU, GELU, ELU, SiLU (Swish), Mish, PReLU, Sigmoid, Tanh, Softmax, LogSoftmax

Other