Table of Contents
- What You'll Learn
- Prerequisites
- Lesson 1: Tensors — The Building Blocks
- Lesson 2: Your First Neural Network
- Lesson 3: Training a Classifier
- Lesson 4: Understanding Loss Functions
- Lesson 5: Optimizers and Learning Rates
- Lesson 6: Overfitting and Regularization
- Lesson 7: Working with Real Data
- Lesson 8: Inference and Deployment
- Lesson 9: GPU-Accelerated Numerical Functions
- Lesson 10: Putting It All Together
What You'll Learn
This guide teaches deep learning fundamentals through hands-on ExBurn examples. By the end, you'll be able to:
- Create and manipulate tensors (the core data structure of deep learning)
- Build neural network architectures using Axon
- Train models with different optimizers and learning rate strategies
- Prevent overfitting with regularization techniques
- Preprocess real-world data
- Run inference and deploy models
- Write GPU-accelerated numerical functions with defn
Each lesson builds on the previous one. Code examples are complete and runnable.
Prerequisites
- Elixir ~> 1.18 and OTP 27+
- Rust stable (for NIF compilation)
- Basic Elixir knowledge (modules, functions, pipes)
- No prior deep learning experience required
Add to your mix.exs:
def deps do
[
{:ex_burn, "~> 0.3"},
{:nx, ">= 0.12.0"},
{:axon, "~> 0.8"},
{:ex_cubecl, ">= 0.4.0"}
]
end
Then fetch and compile the dependencies:
mix deps.get
mix compile
Check that your GPU is available:
ExBurn.default_device() # :gpu or :cpu
ExBurn.device_name() # e.g. "CUDA (NVIDIA GPU)" or "Metal (Apple GPU)"
ExBurn.summary() # full environment summary
Lesson 1: Tensors — The Building Blocks
What is a Tensor?
A tensor is a multi-dimensional array of numbers. Deep learning is essentially tensor math:
| Tensor rank | Example | Shape |
|---|---|---|
| 0 (scalar) | 5.0 | {} |
| 1 (vector) | [1.0, 2.0, 3.0] | {3} |
| 2 (matrix) | [[1, 2], [3, 4]] | {2, 2} |
| 4 (image batch) | batch of 8 RGB 32x32 images | {8, 3, 32, 32} |
Creating Tensors
import Nx
# From a list
t = Nx.tensor([1.0, 2.0, 3.0])
# 2D tensor (matrix)
m = Nx.tensor([[1.0, 2.0], [3.0, 4.0]])
# With explicit type
t_f64 = Nx.tensor([1.0, 2.0], type: {:f, 64})
t_i32 = Nx.tensor([1, 2, 3], type: {:s, 32})
# Useful constructors
zeros = Nx.broadcast(0.0, {3, 4}) # 3x4 matrix of zeros
ones = Nx.broadcast(1.0, {3, 4}) # 3x4 matrix of ones
iota = Nx.iota({5}) # [0, 1, 2, 3, 4]
eye = Nx.eye(3) # 3x3 identity matrix
Inspecting Tensors
Nx.shape(t) # {3} — the shape
Nx.type(t) # {:f, 32} — the element type
Nx.rank(t) # 1 — number of dimensions
Nx.size(t) # 3 — total number of elements
Nx.to_list(t) # [1.0, 2.0, 3.0] — convert to Elixir list
Element-wise Operations
a = Nx.tensor([1.0, 2.0, 3.0])
b = Nx.tensor([4.0, 5.0, 6.0])
Nx.add(a, b) # [5.0, 7.0, 9.0]
Nx.subtract(a, b) # [-3.0, -3.0, -3.0]
Nx.multiply(a, b) # [4.0, 10.0, 18.0]
Nx.divide(a, b) # [0.25, 0.4, 0.5]
Nx.negate(a) # [-1.0, -2.0, -3.0]
Nx.abs(a) # [1.0, 2.0, 3.0]
Nx.exp(a) # [2.718, 7.389, 20.086]
Nx.log(a) # [0.0, 0.693, 1.099]
Nx.sqrt(a) # [1.0, 1.414, 1.732]
Broadcasting
When shapes don't match, Nx automatically broadcasts the smaller tensor:
a = Nx.tensor([[1.0, 2.0], [3.0, 4.0]]) # shape {2, 2}
b = Nx.tensor([10.0, 20.0]) # shape {2}
Nx.add(a, b)
# [[11.0, 22.0],
# [13.0, 24.0]]
# b is broadcast across rows
Reductions
Collapse dimensions to produce summaries:
m = Nx.tensor([[1.0, 2.0], [3.0, 4.0]])
Nx.sum(m) # 10.0 — sum all elements
Nx.mean(m) # 2.5 — mean of all elements
Nx.reduce_max(m) # 4.0 — maximum value
Nx.reduce_min(m) # 1.0 — minimum value
# Reduce along a specific axis
Nx.sum(m, axes: [0]) # [4.0, 6.0] — sum along rows (column sums)
Nx.sum(m, axes: [1]) # [3.0, 7.0] — sum along columns (row sums)
Shape Manipulation
t = Nx.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
Nx.reshape(t, {2, 3})
# [[1.0, 2.0, 3.0],
# [4.0, 5.0, 6.0]]
m = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
Nx.transpose(m)
# [[1.0, 4.0],
# [2.0, 5.0],
# [3.0, 6.0]]
# Concatenation
a = Nx.tensor([1.0, 2.0])
b = Nx.tensor([3.0, 4.0])
Nx.concatenate([a, b]) # [1.0, 2.0, 3.0, 4.0]
Linear Algebra
a = Nx.tensor([[1.0, 2.0], [3.0, 4.0]])
b = Nx.tensor([[5.0, 6.0], [7.0, 8.0]])
Nx.dot(a, b)
# Matrix multiplication:
# [[1*5+2*7, 1*6+2*8],
# [3*5+4*7, 3*6+4*8]]
# = [[19.0, 22.0], [43.0, 50.0]]
# Dot product of vectors
x = Nx.tensor([1.0, 2.0, 3.0])
y = Nx.tensor([4.0, 5.0, 6.0])
Nx.dot(x, y) # 1*4 + 2*5 + 3*6 = 32.0
Try It Yourself
# Create a 3x3 matrix, transpose it, then multiply by the original
m = Nx.iota({3, 3}) |> Nx.as_type(:f32)
mt = Nx.transpose(m)
result = Nx.dot(m, mt)
Nx.to_list(result)
Lesson 2: Your First Neural Network
What is a Neural Network?
A neural network is a function that transforms input data into predictions through a series of learned transformations:
input → [Linear → Activation] × N → output
Each Linear layer computes output = input × weights + bias. The Activation function introduces non-linearity, enabling the network to learn complex patterns.
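To make the formula concrete, here is a minimal sketch of one Linear → Activation step in plain Elixir lists, without Nx or Axon. The module name `DenseSketch` and the sample weights are made up for illustration; a real layer learns its weights and runs on tensors.

```elixir
defmodule DenseSketch do
  # Computes relu(input × weights + bias) for a single sample.
  # input: n floats; weights: n rows of m floats; bias: m floats.
  def forward(input, weights, bias) do
    weights
    |> transpose()
    |> Enum.zip(bias)
    |> Enum.map(fn {column, b} ->
      z = dot(input, column) + b
      # ReLU activation: negative pre-activations become 0.0
      max(z, 0.0)
    end)
  end

  defp dot(xs, ys) do
    xs |> Enum.zip(ys) |> Enum.map(fn {x, y} -> x * y end) |> Enum.sum()
  end

  defp transpose(rows), do: rows |> Enum.zip() |> Enum.map(&Tuple.to_list/1)
end

# 2 input features feeding 3 neurons (hypothetical values)
weights = [[1.0, -1.0, 0.5], [2.0, 0.0, -1.0]]
bias = [0.0, 0.5, 0.0]
IO.inspect(DenseSketch.forward([1.0, 2.0], weights, bias))
# [5.0, 0.0, 0.0]
```

The second and third neurons have negative pre-activations (-0.5 and -1.5), so ReLU clamps them to zero.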
Defining a Model with Axon
Axon provides a functional, Keras-like API for building models:
model =
Axon.input("input", shape: {nil, 4})
|> Axon.dense(8, activation: :relu)
|> Axon.dense(3, activation: :softmax)
Breaking this down:
- Axon.input("input", shape: {nil, 4}): defines the input. nil means "any batch size"; 4 means 4 features per sample.
- Axon.dense(8, activation: :relu): a fully-connected layer with 8 neurons and ReLU activation.
- Axon.dense(3, activation: :softmax): output layer with 3 neurons (one per class) and softmax activation.
Understanding Layer Shapes
# Input: {batch_size, 4}
# ↓ Dense(8) — learns a {4, 8} weight matrix + {8} bias
# Hidden: {batch_size, 8}
# ↓ Dense(3) — learns a {8, 3} weight matrix + {3} bias
# Output: {batch_size, 3}
The nil in the input shape is the batch dimension; it can be any size.
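The weight and bias shapes above also fix the number of trainable parameters. As a quick sanity check, the following sketch counts them for a stack of dense layers; `ParamCount` is a hypothetical helper, not part of Axon or ExBurn.

```elixir
defmodule ParamCount do
  # layer_sizes lists the width of every layer, input first, e.g. [4, 8, 3].
  # Each dense layer has inputs * outputs weights plus one bias per output.
  def dense(layer_sizes) do
    layer_sizes
    |> Enum.chunk_every(2, 1, :discard)
    |> Enum.map(fn [inputs, outputs] -> inputs * outputs + outputs end)
    |> Enum.sum()
  end
end

IO.inspect(ParamCount.dense([4, 8, 3]))
# (4*8 + 8) + (8*3 + 3) = 40 + 27 = 67
```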
Compiling the Model
Before training, we need to compile the model. This initializes parameters and sets up the optimizer:
compiled = ExBurn.Model.compile(model,
loss: :cross_entropy,
optimizer: :adam,
learning_rate: 0.01
)
Inspecting the Model
# Keras/PyTorch-style summary
IO.puts(ExBurn.Model.summary(compiled))
# Get model info
info = ExBurn.Model.info(compiled)
IO.puts("Total parameters: #{info.total_params}")
IO.puts("Layers: #{info.layer_count}")
IO.puts("Memory: #{info.estimated_memory_mb} MB")
# Access individual components
ExBurn.Model.parameters(compiled) # parameter map
ExBurn.Model.loss_function(compiled) # :cross_entropy
ExBurn.Model.optimizer(compiled) # :adam
Forward Pass (Inference)
# Create some dummy input
input = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
# Run inference
{:ok, output} = ExBurn.Model.predict(compiled, input)
Nx.to_list(output)
# e.g. [[0.2, 0.5, 0.3]] — class probabilities from softmax
Activation Functions
Activation functions introduce non-linearity. Without them, stacking linear layers would be equivalent to a single linear layer:
# Common activations in Axon:
Axon.dense(64, activation: :relu) # ReLU: max(0, x) — most common
Axon.dense(64, activation: :sigmoid) # Sigmoid: 1/(1+e^-x) — outputs in (0, 1)
Axon.dense(64, activation: :tanh) # Tanh: outputs in (-1, 1)
Axon.dense(64, activation: :softmax) # Softmax: normalizes to probabilities
ReLU (Rectified Linear Unit) is the default choice for hidden layers. It's simple and fast, and because its gradient is 1 for every positive input it mitigates the vanishing gradient problem.
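If you want to see what these activations compute, the sketch below writes them out on plain floats using only the Erlang :math module. `ActivationSketch` is an illustrative name; Axon's own implementations operate on tensors.

```elixir
defmodule ActivationSketch do
  def relu(x), do: max(x, 0.0)
  def sigmoid(x), do: 1.0 / (1.0 + :math.exp(-x))
  def tanh(x), do: :math.tanh(x)

  # Softmax works on a whole vector. Subtracting the maximum first
  # keeps :math.exp/1 from overflowing on large inputs.
  def softmax(xs) do
    m = Enum.max(xs)
    exps = Enum.map(xs, fn x -> :math.exp(x - m) end)
    total = Enum.sum(exps)
    Enum.map(exps, fn e -> e / total end)
  end
end

IO.inspect(ActivationSketch.relu(-2.0))     # 0.0
IO.inspect(ActivationSketch.sigmoid(0.0))   # 0.5
IO.inspect(ActivationSketch.tanh(0.0))      # 0.0
probs = ActivationSketch.softmax([1.0, 2.0, 3.0])
IO.inspect(Enum.sum(probs))                 # 1.0, up to floating-point rounding
```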
Try It Yourself
# Build a model with 2 hidden layers
model =
Axon.input("x", shape: {nil, 10})
|> Axon.dense(32, activation: :relu, name: "hidden1")
|> Axon.dense(16, activation: :relu, name: "hidden2")
|> Axon.dense(5, name: "output")
compiled = ExBurn.Model.compile(model)
IO.puts(ExBurn.Model.summary(compiled))
Lesson 3: Training a Classifier
The Training Loop
Training is the process of adjusting the model's parameters to minimize the loss function. Each iteration:
- Forward pass: Compute predictions from input data
- Loss computation: Measure how wrong the predictions are
- Backward pass: Compute gradients (how to adjust each parameter)
- Optimizer step: Update parameters to reduce loss
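The four steps can be sketched on the smallest possible model: a single weight w fitted to y = w * x with plain gradient descent. This is an illustration in plain Elixir, with the gradient derived by hand; `LoopSketch` is a hypothetical module and says nothing about how ExBurn computes gradients.

```elixir
defmodule LoopSketch do
  # Fits y = w * x by gradient descent on mean squared error.
  def fit(xs, ys, w, lr, steps) do
    Enum.reduce(1..steps, w, fn _step, w ->
      # 1. Forward pass
      preds = Enum.map(xs, fn x -> w * x end)
      # 2. Loss computation (the errors are what the gradient needs)
      errors = Enum.zip_with(preds, ys, fn p, y -> p - y end)
      # 3. Backward pass: d/dw mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
      grad = mean(Enum.zip_with(errors, xs, fn e, x -> 2.0 * e * x end))
      # 4. Optimizer step
      w - lr * grad
    end)
  end

  def loss(xs, ys, w) do
    mean(Enum.zip_with(xs, ys, fn x, y -> (w * x - y) * (w * x - y) end))
  end

  defp mean(values), do: Enum.sum(values) / length(values)
end

xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]   # generated with w = 2
w = LoopSketch.fit(xs, ys, 0.0, 0.05, 100)
IO.inspect({w, LoopSketch.loss(xs, ys, w)})
```

After 100 iterations w is very close to 2.0 and the loss is close to zero. A real network repeats the same loop over many parameters and mini-batches.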
Complete Training Example
import Nx
# ── Step 1: Create synthetic data ──────────────────────────
# 100 samples, 4 features, 3 classes
num_samples = 100
num_features = 4
num_classes = 3
# Random features
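# Nx.random_uniform is not available in recent Nx releases; Nx.Random takes an explicit key
key = Nx.Random.key(42)
{x, key} = Nx.Random.uniform(key, shape: {num_samples, num_features})
# Random integer labels (0, 1, or 2)
{y, _key} = Nx.Random.randint(key, 0, num_classes, shape: {num_samples})
y = Nx.as_type(y, {:s, 64})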
# ── Step 2: Split into train/validation ────────────────────
{train, val} = ExBurn.Dataset.split({x, y}, val_split: 0.2, shuffle: false)
{train_x, train_y} = train
{val_x, val_y} = val
# ── Step 3: Define the model ──────────────────────────────
model =
Axon.input("input", shape: {nil, num_features})
|> Axon.dense(8, activation: :relu)
|> Axon.dense(num_classes)
# ── Step 4: Compile ───────────────────────────────────────
compiled = ExBurn.Model.compile(model,
loss: :cross_entropy,
optimizer: :adam,
learning_rate: 0.01
)
# ── Step 5: Train ─────────────────────────────────────────
trained = ExBurn.Training.fit(compiled, {train_x, train_y},
epochs: 20,
batch_size: 16,
validation_data: {val_x, val_y},
verbose: true
)
# ── Step 6: Evaluate ──────────────────────────────────────
{loss, accuracy} = ExBurn.Training.evaluate(trained, {val_x, val_y}, true)
IO.puts("Validation loss: #{loss}, accuracy: #{accuracy}")
Understanding the Output
When verbose: true, you'll see output like:
Training: 80 samples, 5 batches/epoch, 20 epochs
batch_size=16, effective_batch_size=16, optimizer=adam
Epoch 1: loss=1.0986 (1250 samples/s, 64ms) ETA=1s
Epoch 2: loss=1.0852 (1300 samples/s, 61ms) ETA=1s
...
Epoch 20: loss=0.5234 (1350 samples/s, 59ms)
Key metrics:
- loss: The average loss per batch (lower is better)
- samples/s: Training throughput
- ETA: Estimated time remaining
Batch Size
The batch_size controls how many samples are processed before updating parameters:
# Small batch: noisier gradients, slower training, less memory
batch_size: 8
# Large batch: smoother gradients, faster training, more memory
batch_size: 64
Epochs
One epoch = one full pass through the training data. More epochs = more training, but too many can cause overfitting.
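Batch size and epochs together determine how many parameter updates a run performs. The arithmetic is sketched below, assuming the last partial batch is kept rather than dropped; `EpochMath` is a hypothetical helper.

```elixir
defmodule EpochMath do
  # Number of batches (and optimizer updates) in one epoch,
  # rounding up so a partial final batch still counts.
  def batches_per_epoch(num_samples, batch_size) do
    div(num_samples + batch_size - 1, batch_size)
  end

  def total_updates(num_samples, batch_size, epochs) do
    batches_per_epoch(num_samples, batch_size) * epochs
  end
end

IO.inspect(EpochMath.batches_per_epoch(80, 16))   # 5, matching the training log
IO.inspect(EpochMath.total_updates(80, 16, 20))   # 100
IO.inspect(EpochMath.total_updates(80, 64, 20))   # 2 batches per epoch, so 40
```

With the same number of epochs, a larger batch size means fewer (but smoother) updates.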
Try It Yourself
# Experiment: try different batch sizes and learning rates
# Which combination converges fastest?
# Which gives the best final accuracy?Lesson 4: Understanding Loss Functions
What is a Loss Function?
A loss function measures how far the model's predictions are from the true values. Training aims to minimize this value.
Cross-Entropy Loss (Classification)
Used for multi-class classification. Measures the difference between predicted class probabilities and true labels:
# Target as integer class indices
pred = Nx.tensor([[2.0, 1.0, 0.1]]) # model logits for 3 classes
target = Nx.tensor([0]) # true class is 0
# Or target as one-hot encoded
target_onehot = Nx.tensor([[1.0, 0.0, 0.0]])
The loss is lower when the model assigns high probability to the correct class:
model = Axon.input("x", shape: {nil, 3}) |> Axon.dense(3)
compiled = ExBurn.Model.compile(model, loss: :cross_entropy)
# Good prediction → low loss
good_pred = Nx.tensor([[10.0, 0.1, 0.1]]) # confident and correct
{:ok, loss} = ExBurn.Model.compute_loss(compiled, good_pred, Nx.tensor([0]))
# loss ≈ 0.0001
# Bad prediction → high loss
bad_pred = Nx.tensor([[0.1, 0.1, 10.0]]) # confident but wrong
{:ok, loss} = ExBurn.Model.compute_loss(compiled, bad_pred, Nx.tensor([0]))
# loss ≈ 9.9
Mean Squared Error (Regression)
Used for regression tasks where the target is a continuous value:
model = Axon.input("x", shape: {nil, 5}) |> Axon.dense(1)
compiled = ExBurn.Model.compile(model, loss: :mse)
pred = Nx.tensor([[3.0]])
target = Nx.tensor([[5.0]])
{:ok, loss} = ExBurn.Model.compute_loss(compiled, pred, target)
# MSE = (3-5)² = 4.0
Binary Cross-Entropy (Binary Classification)
Used when there are exactly two classes:
model = Axon.input("x", shape: {nil, 10}) |> Axon.dense(1)
compiled = ExBurn.Model.compile(model, loss: :binary_cross_entropy)
# Targets are 0.0 or 1.0
pred = Nx.tensor([[0.9]]) # model predicts class 1 with 90% confidence
target = Nx.tensor([[1.0]]) # true class is 1
{:ok, loss} = ExBurn.Model.compute_loss(compiled, pred, target)
# loss ≈ 0.105 (low, because prediction matches target)
Choosing the Right Loss
| Task | Loss Function | Target Format |
|---|---|---|
| Multi-class classification | :cross_entropy | Integer indices or one-hot |
| Binary classification | :binary_cross_entropy | 0.0 or 1.0 |
| Regression | :mse | Continuous values |
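The three losses in this lesson can be checked by hand. The sketch below uses the standard textbook definitions in plain Elixir, with cross-entropy taking raw logits and binary cross-entropy taking a probability. `LossSketch` is an illustrative module; ExBurn's implementations may differ in details such as reduction over the batch.

```elixir
defmodule LossSketch do
  # Cross-entropy from raw logits for one sample: -log(softmax(logits)[target]),
  # computed through a numerically stable log-sum-exp.
  def cross_entropy(logits, target) do
    m = Enum.max(logits)
    sum_exp = logits |> Enum.map(fn z -> :math.exp(z - m) end) |> Enum.sum()
    m + :math.log(sum_exp) - Enum.at(logits, target)
  end

  # Mean squared error over paired predictions and targets.
  def mse(preds, targets) do
    squared = Enum.zip_with(preds, targets, fn p, t -> (p - t) * (p - t) end)
    Enum.sum(squared) / length(squared)
  end

  # Binary cross-entropy for one predicted probability p and a 0.0/1.0 target.
  def binary_cross_entropy(p, target) do
    -(target * :math.log(p) + (1.0 - target) * :math.log(1.0 - p))
  end
end

IO.inspect(LossSketch.cross_entropy([10.0, 0.1, 0.1], 0))  # about 0.0001
IO.inspect(LossSketch.cross_entropy([0.1, 0.1, 10.0], 0))  # about 9.9
IO.inspect(LossSketch.mse([3.0], [5.0]))                   # 4.0
IO.inspect(LossSketch.binary_cross_entropy(0.9, 1.0))      # about 0.105
```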
Lesson 5: Optimizers and Learning Rates
What is an Optimizer?
An optimizer determines how to update the model's parameters based on the computed gradients. Different optimizers have different strategies.
Adam (Default)
Adam adapts the learning rate for each parameter individually. It's a good default for most tasks:
ExBurn.Model.compile(model,
optimizer: :adam,
learning_rate: 0.001 # good starting point
)
When to use: Default choice. Works well with minimal tuning.
Tips:
- If loss oscillates → reduce learning rate (try 0.0001)
- If convergence is very slow → increase learning rate (try 0.01)
SGD with Momentum
SGD with momentum accumulates a velocity vector in directions of consistent gradient:
ExBurn.Model.compile(model,
optimizer: :sgd,
learning_rate: 0.01 # needs higher LR than Adam
)
# With Nesterov momentum (often converges faster):
ExBurn.Training.fit(model, data, nesterov: true)
When to use: When you need maximum generalization and have time to tune.
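To show what "accumulating a velocity" means, here is one common formulation of the momentum update for a single parameter, in plain Elixir. The coefficient name `mu` and the module `MomentumSketch` are assumptions for illustration; this guide does not state which formulation ExBurn uses.

```elixir
defmodule MomentumSketch do
  # One SGD-with-momentum update for a single parameter.
  # velocity is a decaying sum of past gradients; mu is the momentum coefficient.
  def step({w, velocity}, grad, lr, mu) do
    velocity = mu * velocity + grad
    {w - lr * velocity, velocity}
  end
end

# Minimise f(w) = w^2, whose gradient is 2w, starting from w = 1.0.
{w, _velocity} =
  Enum.reduce(1..200, {1.0, 0.0}, fn _i, {w, v} ->
    MomentumSketch.step({w, v}, 2.0 * w, 0.01, 0.9)
  end)

IO.inspect(w)
```

After 200 steps w ends up very close to the minimum at 0. With mu set to 0.0 the update reduces to plain gradient descent.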
RMSprop
RMSprop adapts learning rates based on the magnitude of recent gradients:
ExBurn.Model.compile(model,
optimizer: :rmsprop,
learning_rate: 0.001
)
When to use: RNNs, LSTMs, or when Adam diverges.
Learning Rate Schedules
Instead of a fixed learning rate, you can vary it during training:
# Step decay: halve LR every 10 epochs
ExBurn.Training.fit(model, data,
lr_schedule: {:step, 0.001, 10, 0.5}
)
# Exponential decay: multiply LR by 0.95 each epoch
ExBurn.Training.fit(model, data,
lr_schedule: {:exponential, 0.001, 0.95}
)
# Cosine annealing: smooth decay (often best results)
ExBurn.Training.fit(model, data,
lr_schedule: {:cosine, 0.001, 1.0e-5}
)
Visual comparison:
[Figure: learning rate against epochs, falling from 0.001 towards 0.00001. Step decay drops suddenly at fixed intervals; cosine annealing decays smoothly.]
Warmup
Gradually increase the learning rate at the start of training for stability:
ExBurn.Training.fit(model, data,
callbacks: [
ExBurn.Training.WarmupCallback.linear(5, 1.0e-5, 0.001)
]
)
This ramps the LR from 1.0e-5 to 0.001 over the first 5 epochs.
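The schedules and the warmup ramp are simple functions of the epoch number. The sketch below writes out the usual textbook formulas in plain Elixir; `ScheduleSketch`, its argument order, and the assumption that cosine annealing spans the whole run are illustrative and may not match ExBurn's exact definitions.

```elixir
defmodule ScheduleSketch do
  # Step decay: multiply the base rate by `factor` every `every` epochs.
  def step(base, every, factor, epoch), do: base * :math.pow(factor, div(epoch, every))

  # Exponential decay: multiply by `gamma` once per epoch.
  def exponential(base, gamma, epoch), do: base * :math.pow(gamma, epoch)

  # Cosine annealing from `base` at epoch 0 down to `min_lr` at epoch `total`.
  def cosine(base, min_lr, epoch, total) do
    min_lr + 0.5 * (base - min_lr) * (1.0 + :math.cos(:math.pi() * epoch / total))
  end

  # Linear warmup from `start` to `target` over `ramp_epochs`, then constant.
  def warmup(start, target, ramp_epochs, epoch) when epoch < ramp_epochs do
    start + (target - start) * epoch / ramp_epochs
  end

  def warmup(_start, target, _ramp_epochs, _epoch), do: target
end

for epoch <- [0, 5, 10, 25, 50] do
  IO.inspect(
    {epoch, ScheduleSketch.step(0.001, 10, 0.5, epoch),
     ScheduleSketch.cosine(0.001, 1.0e-5, epoch, 50)}
  )
end
```

Printing a few epochs like this is a cheap way to confirm a schedule behaves as intended before starting a long run.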
Reduce on Plateau
Automatically reduce the learning rate when validation loss stops improving:
ExBurn.Training.fit(model, data,
callbacks: [
ExBurn.Training.ReduceLROnPlateauCallback.new(
patience: 5,
factor: 0.5,
min_lr: 1.0e-6
)
]
)
Try It Yourself
# Compare optimizers on the same data:
# 1. Adam with lr=0.001
# 2. SGD with lr=0.01 and nesterov=true
# 3. Adam with cosine annealing
# Which converges fastest? Which gives the best final loss?
Lesson 6: Overfitting and Regularization
What is Overfitting?
Overfitting happens when the model memorizes the training data instead of learning general patterns. Signs:
- Training loss keeps decreasing, but validation loss starts increasing
- Large gap between training and validation accuracy
[Figure: loss against epochs. Training loss keeps decreasing, while validation loss falls at first and then starts increasing once the model overfits.]
Technique 1: Dropout
Randomly "drops" (sets to zero) a fraction of neurons during training. Forces the network to not rely on any single neuron:
model =
Axon.input("x", shape: {nil, 10})
|> Axon.dense(64, activation: :relu)
|> Axon.dropout(rate: 0.5) # drop 50% of neurons
|> Axon.dense(64, activation: :relu)
|> Axon.dropout(rate: 0.3) # drop 30% of neurons
|> Axon.dense(3)
Rule of thumb: Use rate: 0.2-0.5 for hidden layers. Don't use dropout on the output layer.
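A common way to implement dropout is "inverted dropout": surviving activations are scaled up during training so that nothing has to change at inference time. The sketch below shows the idea on a plain list; `DropoutSketch` is a hypothetical module, and whether Axon or ExBurn scale in exactly this way is an assumption here.

```elixir
defmodule DropoutSketch do
  # Zero each activation with probability `rate` and scale the survivors
  # by 1 / (1 - rate) so the expected value of each activation is unchanged.
  def drop(activations, rate) when rate >= 0.0 and rate < 1.0 do
    keep_scale = 1.0 / (1.0 - rate)

    Enum.map(activations, fn a ->
      if :rand.uniform() < rate, do: 0.0, else: a * keep_scale
    end)
  end
end

:rand.seed(:exsss, {1, 2, 3})
# Each element comes out as either 0.0 (dropped) or 2.0 (kept and scaled)
IO.inspect(DropoutSketch.drop([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.5))
```

At inference time dropout is simply switched off, which corresponds to calling this function with a rate of 0.0.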
Technique 2: Weight Decay (L2 Regularization)
Penalizes large weights, encouraging the model to learn simpler patterns:
ExBurn.Model.compile(model,
weight_decay: 1.0e-4 # L2 regularization coefficient
)
Rule of thumb:
- 1.0e-4: good default
- 1.0e-5: small datasets (less regularization needed)
- 1.0e-3: large models that overfit
Technique 3: Early Stopping
Stop training when validation loss stops improving:
ExBurn.Training.fit(model, data,
validation_data: val_data,
callbacks: [
ExBurn.Training.EarlyStoppingCallback.wait(5, 1.0e-4)
]
)
This stops training after 5 epochs without at least 1.0e-4 improvement in validation loss.
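The patience and minimum-improvement rule can be sketched as a small function over a list of per-epoch validation losses. This is an illustration of the rule as described above, not ExBurn's callback code; `EarlyStopSketch` and the sample losses are made up.

```elixir
defmodule EarlyStopSketch do
  # Returns the 1-based epoch at which training would stop, or :completed.
  # An epoch counts as an improvement only if it beats the best loss so far
  # by more than min_delta.
  def stop_epoch(val_losses, patience, min_delta) do
    val_losses
    |> Enum.with_index(1)
    |> Enum.reduce_while({:infinity, 0}, fn {loss, epoch}, {best, waited} ->
      cond do
        best == :infinity or best - loss > min_delta -> {:cont, {loss, 0}}
        waited + 1 >= patience -> {:halt, {:stopped, epoch}}
        true -> {:cont, {best, waited + 1}}
      end
    end)
    |> case do
      {:stopped, epoch} -> epoch
      _ -> :completed
    end
  end
end

losses = [1.00, 0.80, 0.70, 0.71, 0.70, 0.72, 0.69995, 0.73, 0.60]
IO.inspect(EarlyStopSketch.stop_epoch(losses, 5, 1.0e-4))
# 8
```

The best loss is reached at epoch 3. Epoch 7 improves by only 0.00005, which is below the threshold, so epochs 4 through 8 all count as "no improvement" and training stops before the better epoch 9 is ever seen. That is the trade-off a short patience makes.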
Technique 4: Gradient Clipping
Prevents exploding gradients (which cause NaN loss):
ExBurn.Training.fit(model, data,
clip_norm: 1.0, # clip gradient norm to 1.0
clip_value: 5.0 # also clip individual gradient values to [-5, 5]
)
Technique 5: Freezing Layers
When fine-tuning a pre-trained model, freeze early layers to preserve learned features:
# Freeze the first layer
frozen_model = ExBurn.Model.freeze(model, ["hidden1"])
# Check which layers are frozen
ExBurn.Model.frozen_layers(frozen_model) # MapSet.new(["hidden1"])
# Unfreeze later
unfrozen_model = ExBurn.Model.unfreeze(frozen_model, ["hidden1"])
Try It Yourself
# Train a model WITHOUT regularization → observe overfitting
# Then add dropout + weight decay + early stopping → compare
Lesson 7: Working with Real Data
Data Splitting
Always split your data into training, validation, and test sets:
# Split into 80% train, 20% validation
{train, val} = ExBurn.Dataset.split({x, y}, val_split: 0.2, shuffle: true, seed: 42)
# For a three-way split:
{train, temp} = ExBurn.Dataset.split({x, y}, val_split: 0.3, seed: 42)
{val, test} = ExBurn.Dataset.split(temp, val_split: 0.5, seed: 42)
# Result: 70% train, 15% val, 15% test
Use seed for reproducible splits.
Data Loading
Create a batched data loader for efficient training:
loader = ExBurn.Dataset.loader({x, y},
batch_size: 32,
shuffle: true,
drop_last: false # keep partial last batch
)
# Iterate through batches
Enum.each(loader, fn {batch_x, batch_y} ->
# process batch
end)
Normalization
Neural networks train better when input features are on a similar scale:
# Standard normalization: zero mean, unit variance
{train_norm, stats} = ExBurn.Dataset.normalize(train_x, method: :standard)
# Apply the same transformation to validation/test data
val_norm = ExBurn.Dataset.normalize_with_stats(val_x, stats)
Three normalization methods:
| Method | What it does | When to use |
|---|---|---|
| :standard | (x - mean) / std | Default for most features |
| :minmax | (x - min) / (max - min) | When you need values in [0, 1] |
| :l2 | x / ‖x‖₂ | When direction matters more than magnitude |
Important: Always compute statistics on training data only, then apply them to validation/test data.
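The "fit on train, apply everywhere" rule is sketched below for a single feature column in plain Elixir. `NormSketch` is a hypothetical module, and it uses the population standard deviation; whether ExBurn uses the population or sample formula is not stated in this guide.

```elixir
defmodule NormSketch do
  # Mean and population standard deviation of one feature column.
  def fit(values) do
    n = length(values)
    mean = Enum.sum(values) / n
    variance = Enum.sum(Enum.map(values, fn v -> (v - mean) * (v - mean) end)) / n
    %{mean: mean, std: :math.sqrt(variance)}
  end

  # Applies previously fitted statistics: (x - mean) / std.
  # Note: a constant column has std 0.0 and would need special handling.
  def transform(values, %{mean: mean, std: std}) do
    Enum.map(values, fn v -> (v - mean) / std end)
  end
end

train = [2.0, 4.0, 6.0, 8.0]
val = [5.0, 10.0]

stats = NormSketch.fit(train)                   # statistics from training data only
IO.inspect(stats.mean)                          # 5.0
IO.inspect(NormSketch.transform(train, stats))
IO.inspect(NormSketch.transform(val, stats))    # reuses the training statistics
```

Computing separate statistics for the validation set would leak information about it into preprocessing and make validation scores look better than they should.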
One-Hot Encoding
Convert integer class labels to one-hot vectors:
labels = Nx.tensor([0, 2, 1, 3])
one_hot = ExBurn.Dataset.one_hot(labels, num_classes: 4)
# [[1, 0, 0, 0],
# [0, 0, 1, 0],
# [0, 1, 0, 0],
# [0, 0, 0, 1]]
Dataset Statistics
stats = ExBurn.Dataset.stats({x, y})
# %{num_samples: 100, input_shape: {100, 4}, target_shape: {100},
# input_type: {:f, 32}, target_type: {:s, 64}}
Complete Data Pipeline Example
# 1. Load your data (however you get it)
# x = ... # your features
# y = ... # your labels
# 2. Split
{train, val} = ExBurn.Dataset.split({x, y}, val_split: 0.2, seed: 42)
# 3. Normalize
{train_x_norm, norm_stats} = ExBurn.Dataset.normalize(elem(train, 0), method: :standard)
val_x_norm = ExBurn.Dataset.normalize_with_stats(elem(val, 0), norm_stats)
# 4. Train
compiled = ExBurn.Model.compile(model)
trained =
  ExBurn.Training.fit(compiled, {train_x_norm, elem(train, 1)},
    validation_data: {val_x_norm, elem(val, 1)},
    epochs: 50
  )
Lesson 8: Inference and Deployment
Running Inference
After training, use the model to make predictions:
# Single prediction
input = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
{:ok, output} = ExBurn.Model.predict(trained_model, input)
Nx.argmax(output, axis: -1) # predicted class for each row
# Batch prediction
batch = Nx.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])
{:ok, outputs} = ExBurn.Model.predict(trained_model, batch)
GPU vs CPU Inference
# GPU inference (via defn compiler)
{:ok, output} = ExBurn.Model.forward(trained_model, input)
# CPU inference (via Axon predict)
{:ok, output} = ExBurn.Model.predict(trained_model, input)
Batched Concurrent Inference with Serving
For production use, Nx.Serving handles concurrent batching:
serving = ExBurn.Serving.build(trained_model,
batch_size: 32,
batch_timeout: 50,
partitions: System.schedulers_online()
)
# Run inference
output = Nx.Serving.run(serving, input)
Saving and Loading Models
# Save to file
ExBurn.Model.save(trained_model, "my_model.bin")
# Load from file
{:ok, loaded_model} = ExBurn.Model.load(compiled_model, "my_model.bin")
# Serialize to binary (for network transfer)
binary = ExBurn.Model.serialize_params(trained_model)
{:ok, params} = ExBurn.Model.deserialize_params(binary)
Export Formats
# Compressed Erlang terms (default, portable)
ExBurn.Model.export(model, "model.etf", format: :elixir_terms)
# JSON (human-readable, larger)
ExBurn.Model.export(model, "model.json", format: :json)
# Import
{:ok, model} = ExBurn.Model.import_params(model, "model.etf")
{:ok, model} = ExBurn.Model.import_params(model, "model.json", format: :json)
Model Quantization
Reduce model size for deployment:
# Convert to half precision (f16) — 2x smaller
quantized = ExBurn.Model.quantize(trained_model, :f16)
# Or brain float 16 (bf16) — better range than f16
quantized = ExBurn.Model.quantize(trained_model, :bf16)
Benchmarking
Measure inference speed:
results = ExBurn.Model.benchmark(trained_model, input, warmup: 3, runs: 10)
# %{avg_ms: 1.234, min_ms: 1.100, max_ms: 1.500,
# median_ms: 1.200, std_ms: 0.120, runs: 10, warmup: 3}
Lesson 9: GPU-Accelerated Numerical Functions
What is defn?
defn lets you write numerical functions that run on the GPU. The ExBurn.Defn.Compiler traces your function and compiles it to Burn GPU kernels.
Setup
Nx.default_backend(ExBurn.Backend)
Nx.Defn.global_default_options(compiler: ExBurn.Defn.Compiler)
Writing defn Functions
defmodule MyMath do
import Nx.Defn
# Element-wise sigmoid: 1 / (1 + e^(-x))
defn sigmoid(x) do
Nx.divide(1.0, Nx.add(1.0, Nx.exp(Nx.negate(x))))
end
# Linear regression prediction: X @ w + b
defn predict(x, w, b) do
Nx.add(Nx.dot(x, w), b)
end
# Mean squared error
defn mse_loss(y_true, y_pred) do
diff = Nx.subtract(y_true, y_pred)
Nx.mean(Nx.multiply(diff, diff))
end
# ReLU activation
defn relu(x) do
Nx.max(x, 0.0)
end
# L2 normalization
defn l2_normalize(x) do
norm = Nx.sqrt(Nx.sum(Nx.multiply(x, x), axes: [-1], keep_axes: true))
Nx.divide(x, norm)
end
end
# These all run on the GPU!
MyMath.sigmoid(Nx.tensor([1.0, 2.0, 3.0]))
MyMath.relu(Nx.tensor([-1.0, 0.0, 1.0]))
Per-Function Compiler Override
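defmodule MyModule do
  import Nx.Defn

  # Intended for ExBurn's GPU compiler (selected at the call site below)
  defn gpu_function(x) do
    Nx.sin(x) |> Nx.exp()
  end

  # Runs with whatever default compiler is configured
  defn cpu_function(x) do
    Nx.cos(x)
  end
end

# One way to pick the compiler for a single function is Nx.Defn.jit/2
gpu_fun = Nx.Defn.jit(&MyModule.gpu_function/1, compiler: ExBurn.Defn.Compiler)
gpu_fun.(Nx.tensor([0.0, 1.0]))
Control Flow in defn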
defmodule ControlFlow do
import Nx.Defn
defn clip_and_scale(x, min_val, max_val, scale) do
x
|> Nx.clip(min_val, max_val)
|> Nx.multiply(scale)
end
defn conditional_compute(x, threshold) do
# Use Nx.select for conditional operations
Nx.select(
Nx.greater(x, threshold), # condition
Nx.multiply(x, 2.0), # value when true
Nx.divide(x, 2.0) # value when false
)
end
endUsing BurnBridge Directly
For maximum performance, bypass Nx and talk to Burn directly:
# Create tensors directly on the GPU
t1 = ExBurn.BurnBridge.zeros([100, 100], :f32)
t2 = ExBurn.BurnBridge.ones([100, 100], :f32)
# Each operation is a single NIF call
t3 = ExBurn.BurnBridge.add(t1, t2)
t4 = ExBurn.BurnBridge.matmul(t1, t2)
t5 = ExBurn.BurnBridge.relu(t3)
# Convert back to Nx when needed
nx_tensor = ExBurn.BurnBridge.to_nx(t3)
Try It Yourself
# Implement a GPU-accelerated softmax function using defn
defmodule SoftmaxGPU do
import Nx.Defn
defn softmax(x) do
# Numerically stable softmax
shifted = x - Nx.reduce_max(x, axes: [-1], keep_axes: true)
exp_shifted = Nx.exp(shifted)
exp_shifted / Nx.sum(exp_shifted, axes: [-1], keep_axes: true)
end
end
# Test it
input = Nx.tensor([[1.0, 2.0, 3.0]])
SoftmaxGPU.softmax(input)
# Should sum to 1.0 across the last dimension
Lesson 10: Putting It All Together
Complete Example: Iris-like Classification
This example combines everything from the previous lessons:
import Nx
# ── 1. Prepare Data ────────────────────────────────────────
num_samples = 150
num_features = 4
num_classes = 3
# Synthetic data (replace with real data in practice)
{x, _key} = Nx.Random.uniform(Nx.Random.key(42), shape: {num_samples, num_features})
y = Nx.remainder(Nx.iota({num_samples}), num_classes) |> Nx.as_type({:s, 64})
# Split
{train, val} = ExBurn.Dataset.split({x, y}, val_split: 0.2, seed: 42)
{train_x, train_y} = train
{val_x, val_y} = val
# Normalize
{train_x_norm, stats} = ExBurn.Dataset.normalize(train_x, method: :standard)
val_x_norm = ExBurn.Dataset.normalize_with_stats(val_x, stats)
# ── 2. Define Model ────────────────────────────────────────
model =
Axon.input("features", shape: {nil, num_features})
|> Axon.dense(32, activation: :relu, name: "hidden1")
|> Axon.dropout(rate: 0.2)
|> Axon.dense(16, activation: :relu, name: "hidden2")
|> Axon.dropout(rate: 0.2)
|> Axon.dense(num_classes, name: "output")
# ── 3. Compile ─────────────────────────────────────────────
compiled = ExBurn.Model.compile(model,
loss: :cross_entropy,
optimizer: :adam,
learning_rate: 0.001,
weight_decay: 1.0e-4
)
IO.puts(ExBurn.Model.summary(compiled))
# ── 4. Train ───────────────────────────────────────────────
trained = ExBurn.Training.fit(compiled,
{train_x_norm, train_y},
epochs: 50,
batch_size: 16,
shuffle: true,
validation_data: {val_x_norm, val_y},
lr_schedule: {:cosine, 0.001, 1.0e-5},
clip_norm: 1.0,
accuracy: true,
callbacks: [
&ExBurn.Training.LoggingCallback.log/1,
ExBurn.Training.EarlyStoppingCallback.wait(10, 1.0e-5),
ExBurn.Training.HistoryCallback.new()
],
verbose: true
)
# ── 5. Evaluate ────────────────────────────────────────────
{loss, accuracy} = ExBurn.Training.evaluate(trained, {val_x_norm, val_y}, true)
IO.puts("Final — loss: #{Float.round(loss, 4)}, accuracy: #{Float.round(accuracy * 100, 1)}%")
# ── 6. Inference ──────────────────────────────────────────
new_sample = Nx.tensor([[5.1, 3.5, 1.4, 0.2]])
new_sample_norm = ExBurn.Dataset.normalize_with_stats(new_sample, stats)
{:ok, prediction} = ExBurn.Model.predict(trained, new_sample_norm)
predicted_class = Nx.argmax(prediction) |> Nx.to_number()
IO.puts("Predicted class: #{predicted_class}")
# ── 7. Save ───────────────────────────────────────────────
ExBurn.Model.save(trained, "iris_model.bin")
IO.puts("Model saved!")
Training Checklist
Use this checklist for every training run:
- [ ] Data split: Train/val/test split with a fixed seed
- [ ] Normalization: Fit on training data, transform all splits
- [ ] Model architecture: Appropriate depth/width for the problem
- [ ] Loss function: Matches the task (classification vs regression)
- [ ] Optimizer: Start with Adam, lr=0.001
- [ ] Regularization: Dropout + weight decay to prevent overfitting
- [ ] Early stopping: Stop when validation loss plateaus
- [ ] Gradient clipping: Enable if you see NaN loss
- [ ] Learning rate schedule: Cosine annealing for best results
- [ ] Evaluation: Check both loss and accuracy on validation set
Common Problems and Solutions
| Problem | Likely Cause | Solution |
|---|---|---|
| Loss is NaN | Exploding gradients | Enable clip_norm: 1.0, reduce learning rate |
| Loss doesn't decrease | LR too low, wrong loss | Increase LR, check loss function |
| Loss oscillates | LR too high, batch too small | Reduce LR, increase batch size or use accumulate_gradients |
| Overfitting | Model too complex | Add dropout, weight decay, early stopping |
| Training very slow | Large model with numerical gradients | Use grad_method: :numerical_batch, reduce model size |
Next Steps
- Training Models — Full API reference for training
- Training Optimization Guide — Advanced tuning techniques
- Mobile Deployment — Deploy to iOS/Android
- Architecture Deep-Dive — How ExBurn works internally