ExBurn (Burn) Integration Guide


Dala integrates ExBurn, a bridge between Nx and the Burn deep learning framework (Rust). This enables GPU-accelerated ML/DL training and inference on iOS, Android, and desktop.

Status

v0.4.0 — Full Nx backend, defn compiler, training loop, serving, model management. Key improvements:

  • Autodiff gradients: Burn Autodiff backward pass for exact gradient computation (with numerical fallback)
  • GPU forward pass: Nx.Defn.jit_apply + ExBurn.Defn.Compiler for GPU-accelerated inference
  • Glorot/Xavier initialization: Proper weight initialization for all model parameters
  • Model summary: Keras/PyTorch-style layer-by-layer summary
  • Layer freeze/unfreeze: For fine-tuning workflows
  • Weight decay: L2 regularization support
  • Gradient accumulation: Effective larger batch sizes
  • Nesterov momentum: For SGD optimizer
  • Accuracy tracking: Classification accuracy during training
  • Profile step: Detailed timing for forward/backward/optimizer phases
  • New callbacks: Warmup, ReduceLROnPlateau, History
  • Improved numerical gradients: :numerical_batch method (~2x faster)
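To make the initialization item above concrete, here is a minimal sketch of Glorot/Xavier uniform initialization in plain Elixir. It is not ExBurn's implementation; the module name `GlorotSketch` is hypothetical, and the list-of-lists weight layout is an assumption chosen for readability.

```elixir
defmodule GlorotSketch do
  # Glorot/Xavier uniform bound: sqrt(6 / (fan_in + fan_out)).
  def limit(fan_in, fan_out), do: :math.sqrt(6.0 / (fan_in + fan_out))

  # Draw a fan_in x fan_out weight matrix (list of rows) from U(-limit, limit).
  def init(fan_in, fan_out) do
    l = limit(fan_in, fan_out)

    for _ <- 1..fan_in do
      for _ <- 1..fan_out, do: (:rand.uniform() * 2.0 - 1.0) * l
    end
  end
end

weights = GlorotSketch.init(784, 256)
IO.inspect({length(weights), length(hd(weights))})
IO.inspect(GlorotSketch.limit(784, 256))
```

The bound shrinks as layers get wider, which keeps activation variance roughly stable from layer to layer.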

Architecture

Axon model
    ↓
Nx.Defn graph
    ↓
ExBurn.Defn.Compiler (Nx.Defn.Compiler behaviour)
    ↓
ExBurn.Backend (Nx.Backend behaviour)
    ↓
ExBurn.Nif (Rustler NIF) ↔ ExCubecl (GPU buffers, kernels, pipelines)
    ↓
Burn Autodiff<CubeCL> (Rust)
    ↓
CubeCL kernels
    ↓
Metal (iOS) / Vulkan (Android) / CUDA → GPU

GPU Backends

| Platform | Backend | Status |
|----------|---------|--------|
| iOS      | Metal   |        |
| Android  | Vulkan  |        |
| macOS    | Metal   |        |
| Linux    | Vulkan  |        |
| NVIDIA   | CUDA    |        |

Quick Start

1. Check Availability

# Is ExBurn loaded?
Dala.ML.Burn.available?()
# true

# Is the NIF library responding?
Dala.ML.Burn.nif_loaded?()
# true

# Is a GPU available?
Dala.ML.Burn.gpu?()
# true on iOS/Android with GPU support

# What device will be used?
Dala.ML.Burn.default_device()
# :gpu or :cpu

# Device name
Dala.ML.Burn.device_name()
# "Metal (Apple GPU)" | "CUDA (NVIDIA GPU)" | "NdArray (CPU)"

# Available backends
Dala.ML.Burn.available_backends()
# [:metal] | [:cuda] | [:vulkan]

# Quick smoke test
Dala.ML.Burn.smoke_test()
# :ok

# Full environment summary
IO.puts(Dala.ML.Burn.summary())
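
The checks above can be combined into one guarded helper. This is a sketch, not part of Dala: `BurnSetupSketch` and `pick_device/0` are hypothetical names, and `apply/3` is used so the snippet compiles and returns `:unavailable` even when the library is not installed.

```elixir
defmodule BurnSetupSketch do
  @burn Dala.ML.Burn

  # Returns :gpu, :cpu, or :unavailable without crashing when the library is absent.
  def pick_device do
    cond do
      not Code.ensure_loaded?(@burn) -> :unavailable
      not apply(@burn, :available?, []) -> :unavailable
      apply(@burn, :gpu?, []) -> :gpu
      true -> :cpu
    end
  end
end

IO.inspect(BurnSetupSketch.pick_device())
```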

2. Configure

# Set ExBurn as the default Nx backend
Dala.ML.Burn.configure!()

# Or with options
Dala.ML.Burn.configure!(device: :gpu)

Dala.ML.setup/0 auto-configures Burn when available — no manual setup needed in most cases.

3. Tensors via Burn

# All Nx operations now run through Burn
t = Nx.tensor([1.0, 2.0, 3.0])
Nx.add(t, t) |> Nx.to_list()
# [2.0, 4.0, 6.0]

# Direct Burn tensor creation (bypasses Nx for performance)
bt = Dala.ML.Burn.zeros([3, 3], :f32)
bt = Dala.ML.Burn.ones([2, 4], :f32)
bt = Dala.ML.Burn.rand([2, 4], :f32, 0.0, 1.0)

# Convert between Nx and Burn
{:ok, bt} = Dala.ML.Burn.from_nx(tensor)
{:ok, tensor} = Dala.ML.Burn.to_nx(bt)

# Batch convert
{:ok, bts} = Dala.ML.Burn.from_nx_batch([t1, t2, t3])
{:ok, tensors} = Dala.ML.Burn.to_nx_batch(bts)

# Tensor inspection
Dala.ML.Burn.tensor_shape(bt)   # [3, 3]
Dala.ML.Burn.tensor_type(bt)    # :f32
Dala.ML.Burn.tensor_numel(bt)   # 9
Dala.ML.Burn.tensor_rank(bt)    # 2

# Direct Burn tensor operations (no Nx overhead)
bt2 = Dala.ML.Burn.add(bt, bt)
bt2 = Dala.ML.Burn.matmul(bt, bt)
bt2 = Dala.ML.Burn.relu(bt)
bt2 = Dala.ML.Burn.softmax(bt)

# Device transfer
bt_gpu = Dala.ML.Burn.to_gpu(bt)
bt_cpu = Dala.ML.Burn.to_cpu(bt_gpu)

4. Define and Compile a Model

model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(256, activation: :relu)
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10)

compiled = Dala.ML.Burn.compile(model,
  loss: :cross_entropy,
  optimizer: :adam,
  learning_rate: 0.001
)

5. Train

trained = Dala.ML.Burn.fit(compiled, {train_x, train_y},
  epochs: 10,
  batch_size: 32,
  validation_data: {val_x, val_y}
)
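
As a rough illustration of what `epochs` and `batch_size` imply, the sketch below splits a dataset into mini-batches in plain Elixir. Whether ExBurn keeps or drops a final partial batch is not stated in this guide; keeping it is an assumption here, and `BatchSketch` is a hypothetical module.

```elixir
defmodule BatchSketch do
  # Split samples into mini-batches; the last batch may be smaller.
  def batches(samples, batch_size), do: Enum.chunk_every(samples, batch_size)

  # Optimizer steps per epoch = ceil(n / batch_size).
  def steps_per_epoch(n, batch_size), do: div(n + batch_size - 1, batch_size)
end

samples = Enum.to_list(1..100)
IO.inspect(BatchSketch.steps_per_epoch(length(samples), 32))
# 4
IO.inspect(Enum.map(BatchSketch.batches(samples, 32), &length/1))
# [32, 32, 32, 4]
```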

6. Inference

{:ok, predictions} = Dala.ML.Burn.predict(trained, input_tensor)

7. Save / Load

:ok = Dala.ML.Burn.save(trained, "my_model.model")
{:ok, loaded} = Dala.ML.Burn.load(trained, "my_model.model")

Training

Basic Training

model = Dala.ML.Burn.compile(axon_model,
  loss: :cross_entropy,
  optimizer: :adam,
  learning_rate: 0.001
)

trained = Dala.ML.Burn.fit(model, {inputs, targets},
  epochs: 10,
  batch_size: 32
)

Training with Validation

trained = Dala.ML.Burn.fit(model, {train_x, train_y},
  epochs: 50,
  batch_size: 64,
  validation_data: {val_x, val_y},
  verbose: true
)

Training with Callbacks

callbacks = [
  # Log metrics after each epoch
  Dala.ML.Burn.Training.logging_callback(),

  # Stop if val loss doesn't improve for 5 epochs
  Dala.ML.Burn.Training.early_stopping_callback(5, 1.0e-4),

  # Save checkpoint every 10 epochs
  Dala.ML.Burn.Training.checkpoint_callback(10, "checkpoints/"),

  # Learning rate warmup over 3 epochs
  Dala.ML.Burn.Training.warmup_callback(3, 1.0e-5, 0.001),

  # Reduce LR on plateau
  Dala.ML.Burn.Training.reduce_on_plateau_callback(patience: 3, factor: 0.5),

  # Report progress to a LiveView screen via handle_info
  Dala.ML.Burn.Training.screen_callback(self())
]

trained = Dala.ML.Burn.fit(model, {train_x, train_y},
  epochs: 100,
  batch_size: 32,
  validation_data: {val_x, val_y},
  callbacks: callbacks,
  accuracy: true
)
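
The early-stopping rule ("stop if val loss doesn't improve for N epochs") can be sketched in plain Elixir as follows. This is one common interpretation, assuming the two arguments are patience and minimum improvement; it is not the library's code, and `EarlyStopSketch` is a hypothetical module.

```elixir
defmodule EarlyStopSketch do
  # Returns {:stop, epoch} at the first epoch where validation loss has failed to
  # improve by more than min_delta for `patience` consecutive epochs, else :completed.
  def run(val_losses, patience, min_delta) do
    val_losses
    |> Enum.with_index(1)
    |> Enum.reduce_while({:infinity, 0}, fn {loss, epoch}, {best, wait} ->
      cond do
        best == :infinity or loss < best - min_delta -> {:cont, {loss, 0}}
        wait + 1 >= patience -> {:halt, {:stop, epoch}}
        true -> {:cont, {best, wait + 1}}
      end
    end)
    |> case do
      {:stop, epoch} -> {:stop, epoch}
      _ -> :completed
    end
  end
end

IO.inspect(EarlyStopSketch.run([1.0, 0.8, 0.81, 0.82, 0.5], 2, 1.0e-4))
# {:stop, 4}
```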

Standard Callbacks Helper

# Quick setup with sensible defaults
callbacks = Dala.ML.Burn.Training.standard_callbacks(
  early_stopping_patience: 5,
  checkpoint_interval: 10,
  checkpoint_dir: "checkpoints",
  warmup_epochs: 3,
  learning_rate: 0.001
)

trained = Dala.ML.Burn.fit(model, {train_x, train_y},
  epochs: 50,
  batch_size: 32,
  validation_data: {val_x, val_y},
  callbacks: callbacks
)

Handle screen progress updates in your LiveView or GenServer:

def handle_info({:training_progress, epoch, loss, val_loss}, socket) do
  {:noreply, assign(socket,
    epoch: epoch,
    loss: loss,
    val_loss: val_loss
  )}
end
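
Outside a LiveView or GenServer, the same messages can be consumed with a plain `receive`. The sketch below simulates the sender with `send/2`; the message shape is taken from the handler above, while `ProgressSketch` and the one-second timeout are assumptions.

```elixir
defmodule ProgressSketch do
  # Collect up to `count` progress messages from the mailbox, oldest first.
  def collect(count, acc \\ [])
  def collect(0, acc), do: Enum.reverse(acc)

  def collect(count, acc) do
    receive do
      {:training_progress, epoch, loss, val_loss} ->
        collect(count - 1, [%{epoch: epoch, loss: loss, val_loss: val_loss} | acc])
    after
      1_000 -> Enum.reverse(acc)
    end
  end
end

# Simulate what a screen callback would send to self().
send(self(), {:training_progress, 1, 0.9, 0.8})
send(self(), {:training_progress, 2, 0.7, 0.65})
IO.inspect(ProgressSketch.collect(2))
```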

Training with History Tracking

{trained, history} = Dala.ML.Burn.Training.fit_with_progress(
  model, {train_x, train_y},
  epochs: 50,
  batch_size: 32,
  validation_data: {val_x, val_y}
)

# history => [%{epoch: 1, loss: 0.5, val_loss: 0.4}, ...]
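
Because the history is a plain list of maps, it can be inspected with `Enum`. The values below are made up for illustration; only the map keys come from the shape shown above.

```elixir
history = [
  %{epoch: 1, loss: 0.50, val_loss: 0.40},
  %{epoch: 2, loss: 0.35, val_loss: 0.31},
  %{epoch: 3, loss: 0.28, val_loss: 0.33}
]

# Epoch with the lowest validation loss.
best = Enum.min_by(history, & &1.val_loss)
IO.inspect(best.epoch)
# 2

# A growing gap between val_loss and loss is a common overfitting signal.
gaps = Enum.map(history, &{&1.epoch, Float.round(&1.val_loss - &1.loss, 2)})
IO.inspect(gaps)
```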

Learning Rate Schedules

# Step decay: halve LR every 10 epochs
Dala.ML.Burn.fit(model, data,
  lr_schedule: {:step, 0.001, 10, 0.5}
)

# Exponential decay
Dala.ML.Burn.fit(model, data,
  lr_schedule: {:exponential, 0.001, 0.95}
)

# Cosine annealing
Dala.ML.Burn.fit(model, data,
  lr_schedule: {:cosine, 0.001, 1.0e-5}
)
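
The three schedules correspond to standard formulas, sketched below in plain Elixir. The tuple field meanings (`{:step, initial, step_size, factor}`, `{:exponential, initial, rate}`, `{:cosine, initial, min_lr}`) are inferred from the comments above, and the 0-based epoch and explicit `total` argument are assumptions; `LrScheduleSketch` is hypothetical.

```elixir
defmodule LrScheduleSketch do
  # Step decay: multiply by `factor` once every `step_size` epochs (epoch is 0-based).
  def lr({:step, initial, step_size, factor}, epoch, _total),
    do: initial * :math.pow(factor, div(epoch, step_size))

  # Exponential decay: multiply by `rate` every epoch.
  def lr({:exponential, initial, rate}, epoch, _total),
    do: initial * :math.pow(rate, epoch)

  # Cosine annealing from `initial` down to `min_lr` over `total` epochs.
  def lr({:cosine, initial, min_lr}, epoch, total),
    do: min_lr + 0.5 * (initial - min_lr) * (1.0 + :math.cos(:math.pi() * epoch / total))
end

for epoch <- [0, 10, 20] do
  IO.inspect({epoch, LrScheduleSketch.lr({:step, 0.001, 10, 0.5}, epoch, 50)})
end
```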

Gradient Clipping

Dala.ML.Burn.fit(model, data,
  clip_norm: 1.0,    # Clip by max norm
  clip_value: 0.5    # Clip by max absolute value
)
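
The two clipping modes differ in what they preserve: norm clipping rescales the whole gradient and keeps its direction, while value clipping clamps each component independently. A minimal sketch on flat lists follows; how ExBurn combines the two when both options are set is not covered here, and `ClipSketch` is a hypothetical module.

```elixir
defmodule ClipSketch do
  # Scale the whole gradient so its L2 norm is at most max_norm.
  def by_norm(grads, max_norm) do
    norm = :math.sqrt(Enum.reduce(grads, 0.0, fn g, acc -> acc + g * g end))
    if norm > max_norm, do: Enum.map(grads, &(&1 * max_norm / norm)), else: grads
  end

  # Clamp each component independently into [-max_value, max_value].
  def by_value(grads, max_value),
    do: Enum.map(grads, &(&1 |> max(-max_value) |> min(max_value)))
end

# Norm is 5.0, so the result is scaled to roughly [0.6, 0.8].
IO.inspect(ClipSketch.by_norm([3.0, 4.0], 1.0))
IO.inspect(ClipSketch.by_value([3.0, -4.0, 0.2], 0.5))
# [0.5, -0.5, 0.2]
```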

Loss Functions

Supported loss functions:

| Loss | Description |
|------|-------------|
| :cross_entropy | Categorical cross-entropy (with log-softmax stability) |
| :mse | Mean squared error |
| :binary_cross_entropy | Binary cross-entropy (with numerical clamping) |
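
The stability notes in the table refer to well-known tricks, sketched here on plain lists. This is illustrative only: `LossSketch`, the single-sample signatures, and the `1.0e-7` clamp are assumptions, not ExBurn's actual code or constants.

```elixir
defmodule LossSketch do
  # Cross-entropy from raw logits via log-softmax: subtracting the max logit
  # before exponentiating keeps :math.exp/1 from overflowing.
  def cross_entropy(logits, target_index) do
    m = Enum.max(logits)
    sum_exp = Enum.reduce(logits, 0.0, fn z, acc -> acc + :math.exp(z - m) end)
    m + :math.log(sum_exp) - Enum.at(logits, target_index)
  end

  # Mean squared error over paired predictions and targets.
  def mse(preds, targets) do
    sum =
      preds
      |> Enum.zip(targets)
      |> Enum.reduce(0.0, fn {p, t}, acc -> acc + (p - t) * (p - t) end)

    sum / length(preds)
  end

  # Binary cross-entropy with the probability clamped away from 0 and 1,
  # so :math.log/1 never receives zero.
  def binary_cross_entropy(p, y, eps \\ 1.0e-7) do
    p = p |> max(eps) |> min(1.0 - eps)
    -(y * :math.log(p) + (1 - y) * :math.log(1.0 - p))
  end
end

# A naive exp(1000.0) would raise; the shifted version returns 0.0.
IO.inspect(LossSketch.cross_entropy([1000.0, 0.0], 0))
IO.inspect(LossSketch.mse([1.0, 2.0], [1.0, 4.0]))
# 2.0
```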

Optimizers

| Optimizer | Options |
|-----------|---------|
| :adam | beta1: 0.9, beta2: 0.999, epsilon: 1.0e-8 |
| :sgd | momentum: 0.9 |
| :rmsprop | decay: 0.9, epsilon: 1.0e-8 |
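
To show where the Adam options enter the update, here is a single-parameter sketch of the standard Adam rule with bias correction, using the defaults from the table. `AdamSketch` and its `{m, v, t}` state tuple are hypothetical; the library's internal state layout is not documented here.

```elixir
defmodule AdamSketch do
  # One Adam update for a single scalar parameter.
  # state is {m, v, t}: first moment, second moment, step count.
  def step(param, grad, {m, v, t}, opts \\ []) do
    lr = Keyword.get(opts, :learning_rate, 0.001)
    b1 = Keyword.get(opts, :beta1, 0.9)
    b2 = Keyword.get(opts, :beta2, 0.999)
    eps = Keyword.get(opts, :epsilon, 1.0e-8)

    t = t + 1
    m = b1 * m + (1.0 - b1) * grad
    v = b2 * v + (1.0 - b2) * grad * grad

    # Bias correction compensates for the zero-initialized moments.
    m_hat = m / (1.0 - :math.pow(b1, t))
    v_hat = v / (1.0 - :math.pow(b2, t))

    {param - lr * m_hat / (:math.sqrt(v_hat) + eps), {m, v, t}}
  end
end

{p1, _state} = AdamSketch.step(1.0, 0.5, {0.0, 0.0, 0})
# The first step moves by about one learning rate: p1 is close to 0.999.
IO.inspect(p1)
```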

Evaluation

avg_loss = Dala.ML.Burn.evaluate(model, {test_x, test_y})
# 0.234

Model Summary

IO.puts(Dala.ML.Burn.summary(model))
# ╔══════════════════════════════════════════════════════════╗
# ║                   ExBurn Model Summary                  ║
# ╠══════════════════════════════════════════════════════════╣
# ║  Total params:                                    235146 ║
# ║  Trainable params:                                235146 ║
# ║  Non-trainable:                                        0 ║
# ║  Formatted:                                      235.1K ║
# ╠══════════════════════════════════════════════════════════╣
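
The total of 235146 matches the Quick Start network (784 → 256 → 128 → 10), assuming the summary describes that model. The count can be checked by hand: each dense layer has `in * out` weights plus `out` biases. `ParamCountSketch` is a hypothetical helper.

```elixir
defmodule ParamCountSketch do
  # A dense layer holds in * out weights plus out biases.
  def dense(in_units, out_units), do: in_units * out_units + out_units

  # Sum parameters over consecutive layer widths, e.g. [784, 256, 128, 10].
  def total(widths) do
    widths
    |> Enum.chunk_every(2, 1, :discard)
    |> Enum.reduce(0, fn [i, o], acc -> acc + dense(i, o) end)
  end
end

IO.inspect(ParamCountSketch.total([784, 256, 128, 10]))
# 235146
```

Dropout adds no parameters, so it does not appear in the widths list.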

Model Management

Quantization

Reduce model size and speed up inference by quantizing to lower precision:

# Quantize to f16 (half precision)
quantized = Dala.ML.Burn.quantize(model, :f16)

# Quantize to bf16 (brain float 16)
quantized = Dala.ML.Burn.quantize(model, :bf16)
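
Both 16-bit formats halve storage relative to f32 at the cost of precision. The sketch below estimates sizes and emulates bf16 by truncating an f32 bit pattern; `QuantSketch` is hypothetical, and truncation is a simplification (converters usually round to nearest instead).

```elixir
defmodule QuantSketch do
  # Bytes needed to store `count` parameters at a given precision.
  def bytes(count, :f32), do: count * 4
  def bytes(count, type) when type in [:f16, :bf16], do: count * 2

  # Emulate bf16 by keeping only the top 16 bits of the f32 encoding.
  def to_bf16(x) do
    <<hi::16, _lo::16>> = <<x::float-32>>
    <<y::float-32>> = <<hi::16, 0::16>>
    y
  end
end

IO.inspect(QuantSketch.bytes(235_146, :f32))
# 940584
IO.inspect(QuantSketch.bytes(235_146, :f16))
# 470292
IO.inspect(QuantSketch.to_bf16(3.14159))
# 3.140625
```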

Export / Import

Export models to portable formats:

# Compressed Erlang term format (default, portable)
Dala.ML.Burn.export(model, "model.etf")
{:ok, model} = Dala.ML.Burn.import_params(model, "model.etf")

# JSON format (human-readable, larger)
Dala.ML.Burn.export(model, "model.json", format: :json)
{:ok, model} = Dala.ML.Burn.import_params(model, "model.json", format: :json)
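
For orientation, a compressed Erlang-term round trip looks like the following in plain Elixir. This only illustrates the format; the parameter map layout and the file name are hypothetical, and the library's actual on-disk structure may differ.

```elixir
params = %{"dense_0" => %{"kernel" => [[0.1, -0.2], [0.3, 0.4]], "bias" => [0.0, 0.0]}}

# Serialize with zlib compression enabled.
binary = :erlang.term_to_binary(params, [:compressed])
path = Path.join(System.tmp_dir!(), "exburn_sketch.etf")
File.write!(path, binary)

# [:safe] refuses to create new atoms when decoding untrusted files.
restored = path |> File.read!() |> :erlang.binary_to_term([:safe])
IO.inspect(restored == params)
# true
File.rm!(path)
```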

Layer Freezing

Freeze layers for fine-tuning:

# Freeze specific layers
frozen = Dala.ML.Burn.freeze(model, ["dense_0", "dense_1"])

# Check if a layer is frozen
Dala.ML.Burn.frozen?(frozen, "dense_0")  # true

# Unfreeze layers
unfrozen = Dala.ML.Burn.unfreeze(frozen, ["dense_0"])
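
Conceptually, a frozen layer is one whose parameters are skipped during the optimizer step. The sketch below shows that idea with plain maps and SGD; it is not how ExBurn stores parameters, and `FreezeSketch` plus the layer names used are illustrative.

```elixir
defmodule FreezeSketch do
  # Apply one SGD step, leaving layers listed in `frozen` untouched.
  def sgd_step(params, grads, frozen, lr) do
    Map.new(params, fn {layer, weights} ->
      if MapSet.member?(frozen, layer) do
        {layer, weights}
      else
        updated = Enum.zip_with(weights, grads[layer], fn w, g -> w - lr * g end)
        {layer, updated}
      end
    end)
  end
end

params = %{"dense_0" => [1.0, 2.0], "dense_2" => [3.0]}
grads = %{"dense_0" => [0.5, 0.5], "dense_2" => [1.0]}
frozen = MapSet.new(["dense_0"])

IO.inspect(FreezeSketch.sgd_step(params, grads, frozen, 0.5))
# %{"dense_0" => [1.0, 2.0], "dense_2" => [2.5]}
```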

Model Info & Benchmarking

# Detailed model information
info = Dala.ML.Burn.info(model)
# %{total_params: 235146, layer_count: 4, device: :gpu, estimated_memory_mb: 0.89, ...}

# Benchmark forward pass
result = Dala.ML.Burn.benchmark(model, input, warmup: 5, runs: 20)
# %{avg_ms: 1.234, min_ms: 1.100, max_ms: 1.500, median_ms: 1.200, std_ms: 0.089}

# Clone a model
snapshot = Dala.ML.Burn.clone(model)
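
The benchmark result fields can be reproduced for any function with `:timer.tc/1`. The sketch below is a generic timing harness, assuming `warmup` runs are discarded and `std_ms` is the population standard deviation; `BenchSketch` is hypothetical.

```elixir
defmodule BenchSketch do
  # Time `fun` after `warmup` untimed calls; report statistics in milliseconds.
  def run(fun, opts \\ []) do
    warmup = Keyword.get(opts, :warmup, 5)
    runs = Keyword.get(opts, :runs, 20)

    for _ <- 1..warmup//1, do: fun.()
    times = for _ <- 1..runs//1, do: elem(:timer.tc(fun), 0) / 1000.0
    stats(times)
  end

  def stats(times) do
    n = length(times)
    sorted = Enum.sort(times)
    avg = Enum.sum(times) / n
    variance = Enum.reduce(times, 0.0, fn t, acc -> acc + (t - avg) * (t - avg) end) / n
    mid = div(n, 2)

    median =
      if rem(n, 2) == 1,
        do: Enum.at(sorted, mid),
        else: (Enum.at(sorted, mid - 1) + Enum.at(sorted, mid)) / 2

    %{
      avg_ms: avg,
      min_ms: hd(sorted),
      max_ms: List.last(sorted),
      median_ms: median,
      std_ms: :math.sqrt(variance)
    }
  end
end

IO.inspect(BenchSketch.run(fn -> Enum.sum(1..10_000) end, warmup: 5, runs: 20))
```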

GPU-Accelerated Defn

Enable the ExBurn defn compiler for custom GPU kernels via Nx.Defn:

Dala.ML.Burn.enable_defn_compiler!()

defmodule MyKernels do
  import Nx.Defn

  defn add_and_scale(x, y, scale) do
    x |> Nx.add(y) |> Nx.multiply(scale)
  end
end

# Runs on GPU via Burn
MyKernels.add_and_scale(Nx.tensor([1.0]), Nx.tensor([2.0]), Nx.tensor(3.0))

Error Handling

# Create error structs
err = Dala.ML.Burn.error(op: :forward, reason: "shape mismatch")

# Wrap error tuples
err = Dala.ML.Burn.error_from_tuple({:error, "failed"}, op: :predict)

# Format for logging
Dala.ML.Burn.format_error(err)
# "ExBurn.forward: shape mismatch"

Dala.ML.Burn.error_to_log_string(err)
# "[ExBurn:forward] shape mismatch"

Serving (Production Inference)

For production use, wrap your model in an Nx.Serving for batched, concurrent inference:

# Build a serving
serving = Dala.ML.Burn.Serving.build(trained_model,
  batch_size: 16,
  batch_timeout: 100
)

# Run single inference
output = Dala.ML.Burn.Serving.run(serving, input_tensor)

# Or supervise it in your app tree
children = [
  {Nx.Serving,
   serving: Dala.ML.Burn.Serving.build(trained_model, batch_size: 32),
   name: :my_model_serving}
]

# Or use the convenience helper
{:ok, _pid} = Dala.ML.Burn.Serving.supervise(trained_model,
  name: :my_model_serving,
  supervisor: MyApp.DynamicSupervisor
)

# Then use it from anywhere
output = Nx.Serving.run(:my_model_serving, input_tensor)

Unified API

Dala.ML.predict/2 dispatches to Burn when given an ExBurn.Model:

# CoreML model (string identifier on iOS)
Dala.ML.predict("my_model", %{"input" => [1.0, 2.0]})

# ONNX session (integer session ID)
Dala.ML.predict(session_id, input_binary)

# Axon model ({model, params} tuple)
Dala.ML.predict({axon_model, params}, input_tensor)

# ExBurn model (ExBurn.Model struct)
Dala.ML.predict(exburn_model, input_tensor)
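
This kind of dispatch maps naturally onto multi-clause functions. The sketch below shows the idea only; `FakeBurnModel` is a hypothetical stand-in for the `ExBurn.Model` struct, and `backend_for/1` is not a Dala function.

```elixir
defmodule FakeBurnModel do
  # Hypothetical stand-in for the ExBurn.Model struct.
  defstruct [:params]
end

defmodule DispatchSketch do
  # Choose a backend from the shape of the first argument.
  def backend_for(%FakeBurnModel{}), do: :burn
  def backend_for(name) when is_binary(name), do: :coreml
  def backend_for(session_id) when is_integer(session_id), do: :onnx
  def backend_for({_model, _params}), do: :axon
end

IO.inspect(DispatchSketch.backend_for("my_model"))
# :coreml
IO.inspect(DispatchSketch.backend_for(%FakeBurnModel{}))
# :burn
```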

Benchmarking

# Benchmark current backend
Dala.ML.benchmark(size: 100, iterations: 10)
# %{
#   time_ms: 1.234,
#   gflops: 0.857,
#   backend: {EMLX.Backend, [device: :gpu]},
#   burn: %{time_ms: 0.567, gflops: 1.234}  # if ExBurn available
# }

Platform Notes

iOS

  • Uses Metal GPU backend via Burn's CubeCL
  • No JIT required (unlike EMLX on devices)
  • Training small models (< 10M params) is feasible
  • Inference is the primary use case

Android

  • Uses Vulkan GPU backend via Burn's CubeCL
  • Same training/inference capabilities as iOS

Desktop (Development)

  • Uses Metal (macOS) or Vulkan (Linux)
  • CUDA support available on NVIDIA hardware

Training on Mobile — Caveats

Burn's Autodiff backend is memory-intensive. On iOS/Android with limited RAM:

  • Fine-tuning small models (< 10M parameters) is feasible on modern devices
  • Full training of large models is not recommended on mobile
  • Inference is the primary use case for mobile deployment
  • Minimum recommended: 4GB RAM, A12+ chip (iOS) / Snapdragon 700+ (Android)
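
A back-of-the-envelope estimate helps explain the 10M-parameter guideline. The sketch assumes f32 storage and Adam, which keeps weights, gradients, and two moment buffers (four copies per parameter); activations and framework overhead are ignored, so the real footprint is higher. `TrainMemSketch` is hypothetical.

```elixir
defmodule TrainMemSketch do
  @bytes_per_f32 4

  # Weights + gradients + two Adam moment buffers = 4 f32 copies per parameter.
  # Result is in megabytes (10^6 bytes) and should be read as a lower bound.
  def adam_training_mb(param_count) do
    param_count * @bytes_per_f32 * 4 / 1_000_000
  end
end

IO.inspect(TrainMemSketch.adam_training_mb(10_000_000))
# 160.0
```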

The training loop in ExBurn uses autodiff gradients (Burn Autodiff backward pass) with numerical fallback. The autodiff path provides exact gradient computation and is significantly faster than numerical methods.
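
To see why a numerical fallback is slower, consider a central-difference gradient: it needs two evaluations of the loss per parameter, whereas a backward pass yields all gradients at once. The sketch below is a generic finite-difference routine, not ExBurn's `:numerical_batch` method; `NumGradSketch` and the step size are assumptions.

```elixir
defmodule NumGradSketch do
  # Central-difference estimate of df/dx_i for every coordinate of `xs`.
  # Costs two evaluations of `f` per parameter, which is why it scales poorly.
  def gradient(f, xs, h \\ 1.0e-5) do
    for i <- 0..(length(xs) - 1) do
      plus = List.update_at(xs, i, &(&1 + h))
      minus = List.update_at(xs, i, &(&1 - h))
      (f.(plus) - f.(minus)) / (2.0 * h)
    end
  end
end

# f(x, y) = x^2 + 3y has the exact gradient [2x, 3].
f = fn [x, y] -> x * x + 3.0 * y end
# Prints values close to [4.0, 3.0].
IO.inspect(NumGradSketch.gradient(f, [2.0, 1.0]))
```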

Comparison with Other Dala ML Backends

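| Backend | Best For | GPU | Training | iOS | Android |
|---------|----------|-----|----------|-----|---------|
| EMLX | iOS inference | Metal (MLX) | | | |
| CoreML | iOS Neural Engine | ANE | | | |
| ONNX | Cross-platform | NNAPI/CoreML | | | |
| GPU Compute | Custom kernels | CubeCL | N/A | | |
| ExBurn | Training + inference | CubeCL | | | |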

Error Handling

All operations raise ExBurn.Error with structured context:

raise ExBurn.Error,
  op: :matmul,
  reason: "shape mismatch",
  details: %{lhs: [3, 4], rhs: [5, 6]}
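
Calling code can rescue such an exception and turn it into a tagged tuple. The sketch uses a hypothetical `SketchError` with the same three fields so that it runs without the library; the message format mirrors the `format_error` output shown earlier, but the real `ExBurn.Error` definition may differ.

```elixir
defmodule SketchError do
  # Same fields as the raise above; stands in for ExBurn.Error.
  defexception [:op, :reason, details: %{}]

  @impl true
  def message(%{op: op, reason: reason}), do: "ExBurn.#{op}: #{reason}"
end

result =
  try do
    raise SketchError,
      op: :matmul,
      reason: "shape mismatch",
      details: %{lhs: [3, 4], rhs: [5, 6]}
  rescue
    e in SketchError -> {:error, e.op, Exception.message(e)}
  end

IO.inspect(result)
# {:error, :matmul, "ExBurn.matmul: shape mismatch"}
```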

Troubleshooting

available?() returns false

  • Ensure ex_burn is in your deps: {:ex_burn, "~> 0.4"}
  • Run mix deps.get && mix compile
  • The Rust NIF will be compiled automatically via Rustler

gpu?() returns false

  • ExBurn checks ExCubecl availability for GPU detection
  • On iOS/Android, ensure the GPU compute libraries are linked
  • On desktop, GPU may not be available — falls back to CPU

Training is slow

  • Training uses autodiff gradients by default (fast, exact)
  • If autodiff falls back to numerical gradients, performance will be slower
  • For faster training, use EMLX or cloud training and deploy to device
  • Reduce batch size and model size for mobile
  • Use profile_step/3 to identify bottlenecks (forward/backward/optimizer timing)

Out of memory during training

See Also