Dala integrates ExBurn, a bridge between Nx and the Burn deep learning framework (Rust). This enables GPU-accelerated ML/DL training and inference on iOS, Android, and desktop.
## Status

Early alpha: basic ops and inference work. Training currently uses numerical gradients; integration with Burn's autodiff is planned for a future release.
## Architecture

```text
Axon model
    ↓
Nx.Defn graph
    ↓
ExBurn.Backend (Nx.Backend behaviour)
    ↓
ExBurn.Nif (Rustler NIF)  ←→  ExCubecl (GPU buffers, kernels, pipelines)
    ↓
Burn Autodiff<CubeCL> (Rust)
    ↓
CubeCL kernels
    ↓
Metal (iOS) / Vulkan (Android) / CUDA → GPU
```

## GPU Backends
| Platform | Backend | Status |
|---|---|---|
| iOS | Metal | ✅ |
| Android | Vulkan | ✅ |
| macOS | Metal | ✅ |
| Linux | Vulkan | ✅ |
| NVIDIA | CUDA | 🔜 |
## Quick Start

### 1. Check Availability

```elixir
# Is ExBurn loaded?
Dala.ML.Burn.available?()
# true

# Is a GPU available?
Dala.ML.Burn.gpu?()
# true on iOS/Android with GPU support

# Which device will be used?
Dala.ML.Burn.default_device()
# :gpu or :cpu
```

### 2. Configure
```elixir
# Set ExBurn as the default Nx backend
Dala.ML.Burn.configure!()

# Or with options
Dala.ML.Burn.configure!(device: :gpu)
```

`Dala.ML.setup/0` auto-configures Burn when it is available, so manual setup is unnecessary in most cases.

### 3. Tensors via Burn
```elixir
# Once ExBurn is the default backend, Nx operations run through Burn
t = Nx.tensor([1.0, 2.0, 3.0])
Nx.add(t, t) |> Nx.to_list()
# [2.0, 4.0, 6.0]

# Direct Burn tensor creation (bypasses Nx for performance)
zeros = Dala.ML.Burn.zeros([3, 3], :f32)
ones = Dala.ML.Burn.ones([2, 4], :f32)

# Convert between Nx and Burn
{:ok, bt} = Dala.ML.Burn.from_nx(t)
{:ok, back} = Dala.ML.Burn.to_nx(bt)
```

### 4. Define and Compile a Model
```elixir
model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(256, activation: :relu)
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10)

compiled =
  Dala.ML.Burn.compile(model,
    loss: :cross_entropy,
    optimizer: :adam,
    learning_rate: 0.001
  )
```

### 5. Train
```elixir
trained =
  Dala.ML.Burn.fit(compiled, {train_x, train_y},
    epochs: 10,
    batch_size: 32,
    validation_data: {val_x, val_y}
  )
```

### 6. Inference
```elixir
{:ok, predictions} = Dala.ML.Burn.predict(trained, input_tensor)
```

### 7. Save / Load
```elixir
:ok = Dala.ML.Burn.save(trained, "my_model.model")
{:ok, loaded} = Dala.ML.Burn.load(trained, "my_model.model")
```

## Training
### Basic Training

```elixir
model =
  Dala.ML.Burn.compile(axon_model,
    loss: :cross_entropy,
    optimizer: :adam,
    learning_rate: 0.001
  )

trained =
  Dala.ML.Burn.fit(model, {inputs, targets},
    epochs: 10,
    batch_size: 32
  )
```

### Training with Validation
```elixir
trained =
  Dala.ML.Burn.fit(model, {train_x, train_y},
    epochs: 50,
    batch_size: 64,
    validation_data: {val_x, val_y},
    verbose: true
  )
```

### Training with Callbacks
```elixir
callbacks = [
  # Log metrics after each epoch
  Dala.ML.Burn.Training.logging_callback(),
  # Stop if val loss doesn't improve for 5 epochs
  Dala.ML.Burn.Training.early_stopping_callback(5, 1.0e-4),
  # Save a checkpoint every 10 epochs
  Dala.ML.Burn.Training.checkpoint_callback(10, "checkpoints/"),
  # Report progress to a LiveView screen via handle_info
  Dala.ML.Burn.Training.screen_callback(self())
]

trained =
  Dala.ML.Burn.fit(model, {train_x, train_y},
    epochs: 100,
    batch_size: 32,
    validation_data: {val_x, val_y},
    callbacks: callbacks
  )
```

`screen_callback/1` reports progress to the given process as `{:training_progress, epoch, loss, val_loss}` messages. Handle them in your LiveView as shown here; in a GenServer, match the same message in `handle_info/2` and update the server state instead of socket assigns:
```elixir
def handle_info({:training_progress, epoch, loss, val_loss}, socket) do
  {:noreply,
   assign(socket,
     epoch: epoch,
     loss: loss,
     val_loss: val_loss
   )}
end
```

### Training with History Tracking
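`fit_with_progress/3` returns the per-epoch history as a list of maps with `:epoch`, `:loss` and `:val_loss` keys. To make the `early_stopping_callback(5, 1.0e-4)` rule from the callbacks example concrete, the sketch below replays such a history and reports where training would stop. It is a hypothetical stdlib-only module (`EarlyStoppingSketch`), not ExBurn code, and it assumes the second argument is a minimum improvement (`min_delta`) that `val_loss` must beat.

```elixir
defmodule EarlyStoppingSketch do
  # Hypothetical illustration, not ExBurn's implementation. Assumes an epoch
  # counts as an improvement only when val_loss < best - min_delta.

  # Returns {:stop, epoch} for the epoch at which training would halt,
  # or :continue if the patience budget is never exhausted.
  def check(history, patience, min_delta) do
    result =
      Enum.reduce_while(history, %{best: nil, waited: 0}, fn entry, acc ->
        %{epoch: epoch, val_loss: val_loss} = entry

        cond do
          # First epoch, or a large enough improvement: record it, reset the wait
          acc.best == nil or val_loss < acc.best - min_delta ->
            {:cont, %{acc | best: val_loss, waited: 0}}

          # This non-improving epoch exhausts the patience budget
          acc.waited + 1 >= patience ->
            {:halt, {:stop, epoch}}

          # Non-improving epoch, but patience remains
          true ->
            {:cont, %{acc | waited: acc.waited + 1}}
        end
      end)

    case result do
      {:stop, _epoch} -> result
      _acc -> :continue
    end
  end
end
```

With a patience of 5, the fifth consecutive epoch without improvement triggers the stop; an epoch that beats the best loss by more than `min_delta` resets the count.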
```elixir
{trained, history} =
  Dala.ML.Burn.Training.fit_with_progress(
    model,
    {train_x, train_y},
    epochs: 50,
    batch_size: 32,
    validation_data: {val_x, val_y}
  )

# history => [%{epoch: 1, loss: 0.5, val_loss: 0.4}, ...]
```

### Learning Rate Schedules
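Each schedule is a tuple. The step form `{:step, initial_lr, step_size, factor}` follows from the "halve LR every 10 epochs" comment in the examples; reading the others as `{:exponential, initial_lr, gamma}` and `{:cosine, max_lr, min_lr}`, and taking the cosine period from the `epochs:` option, are assumptions. Under those readings, the conventional per-epoch formulas look like this (hypothetical `LrScheduleSketch`, epochs counted from 0):

```elixir
defmodule LrScheduleSketch do
  # Hypothetical: learning rate for a 0-based `epoch` under each schedule tuple.
  # `total_epochs` is only used by cosine annealing.

  # Step decay: multiply by `factor` once every `step_size` epochs.
  def lr({:step, initial, step_size, factor}, epoch, _total_epochs) do
    initial * :math.pow(factor, div(epoch, step_size))
  end

  # Exponential decay: multiply by `gamma` every epoch.
  def lr({:exponential, initial, gamma}, epoch, _total_epochs) do
    initial * :math.pow(gamma, epoch)
  end

  # Cosine annealing: max_lr at epoch 0, min_lr at epoch total_epochs.
  def lr({:cosine, max_lr, min_lr}, epoch, total_epochs) do
    progress = epoch / total_epochs
    min_lr + 0.5 * (max_lr - min_lr) * (1 + :math.cos(:math.pi() * progress))
  end
end
```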
```elixir
# Step decay: halve LR every 10 epochs
Dala.ML.Burn.fit(model, data,
  lr_schedule: {:step, 0.001, 10, 0.5}
)

# Exponential decay
Dala.ML.Burn.fit(model, data,
  lr_schedule: {:exponential, 0.001, 0.95}
)

# Cosine annealing
Dala.ML.Burn.fit(model, data,
  lr_schedule: {:cosine, 0.001, 1.0e-5}
)
```

### Gradient Clipping
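Assuming `clip_norm` and `clip_value` follow their conventional definitions, clipping by norm rescales the whole gradient when its L2 norm exceeds the threshold (keeping its direction), while clipping by value clamps each element on its own. A stdlib-only sketch on a plain list of gradient values, using a hypothetical `ClipSketch` module:

```elixir
defmodule ClipSketch do
  # Rescale the whole gradient so its L2 norm is at most max_norm.
  def clip_by_norm(grads, max_norm) do
    norm = grads |> Enum.map(&(&1 * &1)) |> Enum.sum() |> :math.sqrt()

    if norm > max_norm do
      scale = max_norm / norm
      Enum.map(grads, &(&1 * scale))
    else
      grads
    end
  end

  # Clamp every element into [-max_value, max_value].
  def clip_by_value(grads, max_value) do
    Enum.map(grads, &(&1 |> max(-max_value) |> min(max_value)))
  end
end
```

Clipping by norm leaves small gradients untouched and never changes the update direction; clipping by value can change the direction when only some elements are clamped.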
```elixir
Dala.ML.Burn.fit(model, data,
  clip_norm: 1.0,  # Clip by max norm
  clip_value: 0.5  # Clip by max absolute value
)
```

### Loss Functions
Supported loss functions:
| Loss | Description |
|---|---|
| `:cross_entropy` | Categorical cross-entropy (with log-softmax stability) |
| `:mse` | Mean squared error |
| `:binary_cross_entropy` | Binary cross-entropy (with numerical clamping) |
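The stability notes in the table refer to standard techniques: computing log-softmax via the log-sum-exp trick (subtracting the largest logit before exponentiating) and clamping probabilities away from 0 and 1 before taking logarithms. The sketch below applies them to a single sample given as plain lists; `LossSketch` and its clamping epsilon of `1.0e-7` are illustrative assumptions, not ExBurn's values.

```elixir
defmodule LossSketch do
  # Hypothetical clamping epsilon for binary cross-entropy.
  @eps 1.0e-7

  # Categorical cross-entropy from raw logits and a one-hot target.
  # log_softmax(x) = x - logsumexp(x); subtracting the max logit keeps
  # :math.exp/1 from overflowing on large inputs.
  def cross_entropy(logits, one_hot) do
    max_logit = Enum.max(logits)
    sum_exp = logits |> Enum.map(fn x -> :math.exp(x - max_logit) end) |> Enum.sum()
    log_sum_exp = max_logit + :math.log(sum_exp)

    logits
    |> Enum.zip(one_hot)
    |> Enum.map(fn {x, t} -> -t * (x - log_sum_exp) end)
    |> Enum.sum()
  end

  # Mean squared error.
  def mse(preds, targets) do
    preds
    |> Enum.zip(targets)
    |> Enum.map(fn {p, t} -> (p - t) * (p - t) end)
    |> Enum.sum()
    |> Kernel./(length(preds))
  end

  # Binary cross-entropy on probabilities; clamping avoids :math.log(0.0).
  def binary_cross_entropy(probs, targets) do
    probs
    |> Enum.zip(targets)
    |> Enum.map(fn {p, t} ->
      p = p |> max(@eps) |> min(1.0 - @eps)
      -(t * :math.log(p) + (1 - t) * :math.log(1 - p))
    end)
    |> Enum.sum()
    |> Kernel./(length(probs))
  end
end
```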
### Optimizers

| Optimizer | Options |
|---|---|
| `:adam` | `beta1: 0.9, beta2: 0.999, epsilon: 1.0e-8` |
| `:sgd` | `momentum: 0.9` |
| `:rmsprop` | `decay: 0.9, epsilon: 1.0e-8` |
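For reference, the textbook update rules these options parameterize are sketched below for one scalar parameter. SGD momentum is written in the `v = momentum * v + g` form, RMSprop treats `decay` as the moving-average coefficient, and Adam applies bias correction; whether ExBurn uses exactly these variants is an assumption, and `OptimizerSketch` is hypothetical.

```elixir
defmodule OptimizerSketch do
  # Each function takes the parameter p, its gradient g, the optimizer state
  # and the learning rate, and returns {new_param, new_state}.

  # SGD with momentum: v <- momentum * v + g; p <- p - lr * v
  def sgd(p, g, v, lr, momentum \\ 0.9) do
    v = momentum * v + g
    {p - lr * v, v}
  end

  # RMSprop: s <- decay * s + (1 - decay) * g^2; p <- p - lr * g / (sqrt(s) + eps)
  def rmsprop(p, g, s, lr, decay \\ 0.9, eps \\ 1.0e-8) do
    s = decay * s + (1 - decay) * g * g
    {p - lr * g / (:math.sqrt(s) + eps), s}
  end

  # Adam with bias correction; t counts completed steps (start at 0).
  def adam(p, g, {m, v, t}, lr, beta1 \\ 0.9, beta2 \\ 0.999, eps \\ 1.0e-8) do
    t = t + 1
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - :math.pow(beta1, t))
    v_hat = v / (1 - :math.pow(beta2, t))
    {p - lr * m_hat / (:math.sqrt(v_hat) + eps), {m, v, t}}
  end
end
```

One property worth knowing: Adam's first step moves the parameter by roughly `learning_rate` whatever the gradient's scale, because after bias correction the update divides the gradient by its own magnitude.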
### Evaluation

```elixir
avg_loss = Dala.ML.Burn.evaluate(model, {test_x, test_y})
# 0.234
```

### Model Summary
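The totals in the summary below match the Quick Start model (784 → 256 → 128 → 10). A dense layer with `n` inputs and `m` units holds `n * m` weights plus `m` biases, and dropout contributes no parameters, so the count can be checked by hand:

```elixir
# Dense-layer parameter count: weights (inputs * units) plus biases (units).
# The dropout layer in the Quick Start model adds no parameters.
layer_sizes = [784, 256, 128, 10]

total =
  layer_sizes
  |> Enum.chunk_every(2, 1, :discard)
  |> Enum.map(fn [inputs, units] -> inputs * units + units end)
  |> Enum.sum()

IO.puts(total)
# prints 235146
```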
```elixir
IO.puts(Dala.ML.Burn.summary(model))
# ╔══════════════════════════════════════════════════════════╗
# ║ ExBurn Model Summary ║
# ╠══════════════════════════════════════════════════════════╣
# ║ Total params: 235146 ║
# ║ Trainable params: 235146 ║
# ║ Non-trainable: 0 ║
# ║ Formatted: 235.1K ║
# ╠══════════════════════════════════════════════════════════╣
```

## Serving (Production Inference)
For production use, wrap your model in an `Nx.Serving` for batched, concurrent inference:

```elixir
# Build a serving
serving =
  Dala.ML.Burn.Serving.build(trained_model,
    batch_size: 16,
    batch_timeout: 100
  )

# Run a single inference
output = Dala.ML.Burn.Serving.run(serving, input_tensor)

# Or supervise it in your application tree
children = [
  {Nx.Serving,
   serving: Dala.ML.Burn.Serving.build(trained_model, batch_size: 32),
   name: :my_model_serving}
]

# Then call the supervised serving by name from anywhere
output = Nx.Serving.batched_run(:my_model_serving, input_tensor)
```

## Unified API
`Dala.ML.predict/2` dispatches on the type of its first argument; when given an `ExBurn.Model` struct, it runs the model through Burn:
```elixir
# CoreML model (string identifier on iOS)
Dala.ML.predict("my_model", %{"input" => [1.0, 2.0]})

# ONNX session (integer session ID)
Dala.ML.predict(session_id, input_binary)

# Axon model ({model, params} tuple)
Dala.ML.predict({axon_model, params}, input_tensor)

# ExBurn model (ExBurn.Model struct)
Dala.ML.predict(exburn_model, input_tensor)
```

## Benchmarking
```elixir
# Benchmark the current backend
Dala.ML.benchmark(size: 100, iterations: 10)
# %{
#   time_ms: 1.234,
#   gflops: 0.857,
#   backend: {EMLX.Backend, [device: :gpu]},
#   burn: %{time_ms: 0.567, gflops: 1.234}  # if ExBurn available
# }
```

## Platform Notes
### iOS
- Uses Metal GPU backend via Burn's CubeCL
- No JIT required (unlike EMLX on devices)
- Training small models (< 10M params) is feasible
- Inference is the primary use case
### Android
- Uses Vulkan GPU backend via Burn's CubeCL
- Same training/inference capabilities as iOS
### Desktop (Development)
- Uses Metal (macOS) or Vulkan (Linux)
- CUDA support planned
## Training on Mobile: Caveats
Burn's Autodiff backend is memory-intensive. On iOS/Android with limited RAM:
- Fine-tuning small models (< 10M parameters) is feasible on modern devices
- Full training of large models is not recommended on mobile
- Inference is the primary use case for mobile deployment
- Minimum recommended: 4GB RAM, A12+ chip (iOS) / Snapdragon 700+ (Android)
The training loop in ExBurn currently uses numerical gradients. Burn's autodiff integration is planned for a future release.
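Numerical gradients estimate each partial derivative by nudging one parameter at a time and re-evaluating the loss. The sketch below uses central differences on a plain list of parameters; `FiniteDiffSketch` and the step size `h = 1.0e-5` are illustrative assumptions rather than ExBurn's internals. Because each parameter costs two loss evaluations under this scheme, the 235,146-parameter Quick Start model would need 470,292 forward passes per training step, versus roughly one forward and one backward pass with autodiff.

```elixir
defmodule FiniteDiffSketch do
  # Central-difference gradient of loss_fn at params (a list of floats):
  #   dL/dp_i ≈ (L(p + h * e_i) - L(p - h * e_i)) / (2 * h)
  # loss_fn is evaluated twice per parameter.
  def gradient(loss_fn, params, h \\ 1.0e-5) do
    params
    |> Enum.with_index()
    |> Enum.map(fn {_value, i} ->
      plus = List.update_at(params, i, &(&1 + h))
      minus = List.update_at(params, i, &(&1 - h))
      (loss_fn.(plus) - loss_fn.(minus)) / (2 * h)
    end)
  end
end
```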
## Comparison with Other Dala ML Backends
| Backend | Best For | GPU | Training | iOS | Android |
|---|---|---|---|---|---|
| EMLX | iOS inference | Metal (MLX) | ✅ | ✅ | ❌ |
| CoreML | iOS Neural Engine | ANE | ❌ | ✅ | ❌ |
| ONNX | Cross-platform | NNAPI/CoreML | ❌ | ✅ | ✅ |
| GPU Compute | Custom kernels | CubeCL | N/A | ✅ | ✅ |
| ExBurn | Training + inference | CubeCL | ✅ | ✅ | ✅ |
## Error Handling

Failed operations raise `ExBurn.Error` with structured context, for example:
```elixir
raise ExBurn.Error,
  op: :matmul,
  reason: "shape mismatch",
  details: %{lhs: [3, 4], rhs: [5, 6]}
```

## Troubleshooting
### `available?()` returns false

- Ensure `ex_burn` is in your deps: `{:ex_burn, "~> 0.1"}`
- Run `mix deps.get && mix compile`
- The Rust NIF is compiled automatically via Rustler
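The dependency from the first bullet goes in the `deps/0` function of your project's `mix.exs`, next to any entries already there:

```elixir
# mix.exs (excerpt): add ExBurn alongside your existing dependencies
defp deps do
  [
    {:ex_burn, "~> 0.1"}
  ]
end
```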
### `gpu?()` returns false

- ExBurn checks ExCubecl availability for GPU detection
- On iOS/Android, ensure the GPU compute libraries are linked
- On desktop, a GPU may not be available; ExBurn then falls back to the CPU
### Training is slow

- Current training uses numerical gradients (finite differences)
- For faster training, train with EMLX or in the cloud, then deploy the trained model to the device
- Reduce batch size and model size for mobile
### Out of memory during training

- Reduce batch size
- Use smaller models
- Call `Dala.ML.Burn.free/1` on intermediate tensors when you are done with them