Dala integration for the Burn deep learning framework.
ExBurn provides an Nx.Backend implementation that delegates tensor operations
to Burn via Rust NIFs, enabling GPU-accelerated ML/DL on mobile and desktop.
Architecture
```text
Axon model
↓
Nx.Defn graph
↓
ExBurn.Backend (Nx.Backend behaviour)
↓
ExBurn.Nif (Rustler NIF) ←→ ExCubecl (GPU buffers, kernels, pipelines)
↓
Burn Autodiff<CubeCL> (Rust)
↓
CubeCL kernels
↓
Metal (iOS) / Vulkan (Android) / CUDA → GPU
```
Quick Start
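```elixir
# Set ExBurn as the default Nx backend
Dala.ML.Burn.configure!()

# Create and manipulate tensors
t = Nx.tensor([1.0, 2.0, 3.0])
Nx.add(t, t) |> Nx.to_list()
# => [2.0, 4.0, 6.0]

# Define a model with Axon
model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(256, activation: :relu)
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(10)

# Placeholder training data: 1,000 random inputs with one-hot labels over 10 classes
key = Nx.Random.key(42)
{train_x, key} = Nx.Random.uniform(key, shape: {1000, 784})
{labels, _key} = Nx.Random.randint(key, 0, 10, shape: {1000})
train_y = Nx.equal(Nx.new_axis(labels, -1), Nx.iota({1, 10})) |> Nx.as_type(:f32)

# Compile for training
compiled =
  Dala.ML.Burn.compile(model,
    loss: :cross_entropy,
    optimizer: :adam,
    learning_rate: 0.001
  )

# Train; fit/3 returns the trained model
trained =
  Dala.ML.Burn.fit(compiled, {train_x, train_y},
    epochs: 10,
    batch_size: 32
  )
```
Platform GPU Backends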
| Platform | Backend | Status |
|---|---|---|
| iOS | Metal | ✅ |
| Android | Vulkan | ✅ |
| macOS | Metal | ✅ |
| Linux | Vulkan | ✅ |
| NVIDIA | CUDA | 🔜 |
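The table works as a lookup from platform to GPU backend, with CUDA listed as planned rather than available. Here is a minimal sketch of that lookup in plain Elixir. The module name BackendTable and the platform atoms are hypothetical. ExBurn's real detection, which the gpu?/0 docs say goes through ExCubecl, is not reproduced here.

```elixir
defmodule BackendTable do
  # Hypothetical mapping that mirrors the platform table; not ExBurn's real detection code.
  @table %{
    ios: {:metal, :supported},
    android: {:vulkan, :supported},
    macos: {:metal, :supported},
    linux: {:vulkan, :supported},
    nvidia: {:cuda, :planned}
  }

  @doc "Returns {:ok, backend} for supported platforms and {:error, reason} otherwise."
  def backend_for(platform) do
    case Map.fetch(@table, platform) do
      {:ok, {backend, :supported}} -> {:ok, backend}
      # CUDA is marked as coming soon in the table, so it is not returned as usable yet
      {:ok, {backend, :planned}} -> {:error, {:not_yet_supported, backend}}
      :error -> {:error, :unknown_platform}
    end
  end
end
```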
Integration with Dala.ML
This module complements the existing Dala ML backends:
- Dala.ML.EMLX: MLX backend for Apple Silicon (recommended on iOS)
- Dala.ML.CoreML: iOS-native CoreML (Neural Engine)
- Dala.ML.ONNX: cross-platform ONNX Runtime
- Dala.ML.Burn: Burn framework via ExBurn (this module)
Use Dala.ML.available_backends/0 to see all available backends,
and Dala.ML.Burn.available?/0 specifically for Burn support.
Training on Mobile — Caveats
Burn's Autodiff backend is memory-intensive. On iOS/Android with limited RAM:
- Fine-tuning small models (< 10M parameters) is feasible on modern devices
- Full training of large models is not recommended on mobile
- Inference is the primary use case for mobile deployment
- Minimum recommended: 4GB RAM, A12+ chip (iOS) / Snapdragon 700+ (Android)
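A quick parameter count helps place a model against the 10M guideline. A dense layer with n inputs and m outputs has n·m weights plus m biases. The Quick Start model therefore has 784·256 + 256 + 256·10 + 10 = 203,530 parameters, and dropout adds none. The sketch below also estimates training memory under two assumptions: values are stored as float32 (4 bytes), and Adam keeps four values per parameter (weights, gradients, and two moment estimates), for about 16 bytes per parameter. It ignores activations and runtime overhead, so treat the result as a lower bound. TrainingBudget is a hypothetical name.

```elixir
defmodule TrainingBudget do
  # Assumption: float32 storage (4 bytes per value).
  @bytes_per_value 4
  # Assumption: Adam keeps weights, gradients, first and second moments per parameter.
  @copies_with_adam 4

  # Parameters of one dense layer: weight matrix plus bias vector.
  def dense_params(inputs, outputs), do: inputs * outputs + outputs

  # Total parameters for a stack of dense layers given as {inputs, outputs} pairs.
  def model_params(layers) do
    layers
    |> Enum.map(fn {i, o} -> dense_params(i, o) end)
    |> Enum.sum()
  end

  # Lower-bound training memory in bytes; excludes activations and runtime overhead.
  def adam_training_bytes(params), do: params * @bytes_per_value * @copies_with_adam
end

quick_start_params = TrainingBudget.model_params([{784, 256}, {256, 10}])
IO.puts("Quick Start model parameters: #{quick_start_params}")
IO.puts("Estimated Adam training bytes at 10M params: #{TrainingBudget.adam_training_bytes(10_000_000)}")
```

Under these assumptions, a 10M-parameter model needs roughly 160 MB before activations are counted. That figure fits within the 4 GB minimum, although activations, the OS, and the rest of the app share the same memory.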
Summary
Functions
- available?/0: Checks whether ExBurn is available (loaded and functional).
- available_backends/0: Returns a list of available GPU backends on this system.
- compile/2: Compiles an Axon model for training with the ExBurn backend.
- compute_loss/3: Computes the loss between predictions and targets.
- configure!/0: Sets ExBurn as the default Nx backend.
- configure!/1: Configures ExBurn for the current platform with Dala-specific defaults.
- data_loader/2: Creates a data loader that yields mini-batches from a dataset.
- default_device/0: Returns the default device for tensor operations.
- evaluate/2: Evaluates a model on a dataset.
- fit/3: Trains a model on the given dataset.
- free/1: Frees a Burn tensor's underlying GPU/CPU memory.
- from_nx/1: Converts an Nx tensor to a Burn tensor.
- gpu?/0: Checks whether a GPU device is available for Burn operations.
- load/2: Loads model parameters from a file.
- ones/2: Creates a tensor filled with ones via Burn.
- parameters/1: Returns the current model parameters.
- predict/2: Runs a forward pass through the model.
- save/2: Saves the model parameters to a file.
- summary/1: Returns a summary of the model architecture including parameter count.
- to_nx/1: Converts a Burn tensor to an Nx tensor.
- version/0: Returns the current ExBurn version.
- zeros/2: Creates a tensor filled with zeros via Burn.
Functions
@spec available?() :: boolean()
Checks whether ExBurn is available (loaded and functional).
@spec available_backends() :: [atom()]
Returns a list of available GPU backends on this system.
@spec compile(Axon.ModelState.t(), keyword()) :: ExBurn.Model.t()
Compiles an Axon model for training with the ExBurn backend.
Options
- :loss - Loss function: :cross_entropy, :mse, or :binary_cross_entropy (default: :cross_entropy)
- :optimizer - Optimizer: :adam, :sgd, or :rmsprop (default: :adam)
- :learning_rate - Learning rate (default: 0.001)
- :device - Device: :cpu or :gpu (default: auto-detected)
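One way to apply the documented defaults and reject unknown keys is Elixir's Keyword.validate!/2 (available since Elixir 1.13), followed by a membership check on each enumerated option. CompileOpts is a hypothetical module written for illustration. It is not ExBurn's compile/2 implementation, and a :device of nil stands in for "auto-detected".

```elixir
defmodule CompileOpts do
  # Hypothetical validator that mirrors the documented compile/2 options.
  @losses [:cross_entropy, :mse, :binary_cross_entropy]
  @optimizers [:adam, :sgd, :rmsprop]

  def normalize!(opts) do
    # Raises ArgumentError on unknown keys and fills in the documented defaults.
    opts =
      Keyword.validate!(opts,
        loss: :cross_entropy,
        optimizer: :adam,
        learning_rate: 0.001,
        # nil stands in for auto-detection (assumption)
        device: nil
      )

    check!(opts, :loss, @losses)
    check!(opts, :optimizer, @optimizers)
    check!(opts, :device, [nil, :cpu, :gpu])
    opts
  end

  # Raises unless the value for key is one of the allowed atoms.
  defp check!(opts, key, allowed) do
    value = Keyword.fetch!(opts, key)

    unless value in allowed do
      raise ArgumentError,
            "invalid #{inspect(key)}: #{inspect(value)}, expected one of #{inspect(allowed)}"
    end
  end
end
```

With this approach, a typo such as :optimiser fails immediately instead of being silently ignored.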
@spec compute_loss(ExBurn.Model.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: {:ok, Nx.Tensor.t()} | {:error, term()}
Computes the loss between predictions and targets.
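Only the signature documents compute_loss/3. For reference, mean squared error is the average of (p - t)² over all values. Categorical cross-entropy is the per-sample average of -Σ t·log(p). The plain-Elixir sketch below computes both on lists. LossSketch is a hypothetical name. The sketch assumes predictions are already probabilities and targets are one-hot, which may not match what ExBurn's :cross_entropy expects, since many frameworks take raw logits instead.

```elixir
defmodule LossSketch do
  # Mean squared error over flat lists of floats.
  def mse(preds, targets) do
    preds
    |> Enum.zip(targets)
    |> Enum.map(fn {p, t} -> (p - t) * (p - t) end)
    |> mean()
  end

  # Categorical cross-entropy averaged over samples.
  # Each prediction row holds probabilities; each target row is one-hot.
  # eps clamps probabilities away from zero so :math.log/1 never sees 0.
  def cross_entropy(pred_rows, target_rows, eps \\ 1.0e-7) do
    pred_rows
    |> Enum.zip(target_rows)
    |> Enum.map(fn {ps, ts} ->
      ps
      |> Enum.zip(ts)
      |> Enum.map(fn {p, t} -> -t * :math.log(max(p, eps)) end)
      |> Enum.sum()
    end)
    |> mean()
  end

  defp mean(values), do: Enum.sum(values) / length(values)
end
```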
@spec configure!() :: :ok
Sets ExBurn as the default Nx backend.
After calling this, Nx operations that run on the default backend are executed via Burn.
For Dala apps, prefer Dala.ML.Burn.configure!/1, which also
handles platform-specific GPU setup.
@spec configure!(keyword()) :: :ok
Configures ExBurn for the current platform with Dala-specific defaults.
Options
- :device - Override the device (:cpu or :gpu). Auto-detected by default.
- :backend - Override the GPU backend (:metal, :vulkan, or :cuda). Auto-detected by default.
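The options follow an "explicit override wins, otherwise auto-detect" pattern. The sketch below shows that resolution, plus a rough backend guess from the standard :os.type/0. ConfigureSketch is hypothetical and is not how ExBurn detects hardware. :os.type/0 reports iOS as Darwin and Android as Linux, so it cannot tell those platforms apart from macOS and Linux. For the backends in the platform table, the answer happens to come out the same.

```elixir
defmodule ConfigureSketch do
  # Hypothetical: explicit options take precedence over auto-detected values.
  def resolve(opts, detected) do
    %{
      device: Keyword.get(opts, :device, detected.device),
      backend: Keyword.get(opts, :backend, detected.backend)
    }
  end

  # Rough guess from the OS family; nil means no GPU backend is assumed.
  def guess_backend do
    case :os.type() do
      {:unix, :darwin} -> :metal
      {:unix, _other} -> :vulkan
      _ -> nil
    end
  end
end
```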
@spec data_loader({Nx.Tensor.t(), Nx.Tensor.t()}, keyword()) :: Enumerable.t()
Creates a data loader that yields mini-batches from a dataset.
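To illustrate what yielding mini-batches means, here is a lazy batcher over a pair of equal-length lists. ExBurn's data_loader/2 takes Nx tensors, and its options are not documented here. The module name BatchSketch and the :batch_size and :shuffle options are hypothetical. Stream.chunk_every/2 keeps the final partial batch, so every sample is seen once per pass.

```elixir
defmodule BatchSketch do
  # Lazily yields {inputs_batch, targets_batch} tuples.
  # Assumption: the dataset is a pair of equal-length lists (Enum.zip/2 truncates to the shorter one).
  def batches({inputs, targets}, opts \\ []) do
    batch_size = Keyword.get(opts, :batch_size, 32)
    shuffle? = Keyword.get(opts, :shuffle, false)

    pairs = Enum.zip(inputs, targets)
    pairs = if shuffle?, do: Enum.shuffle(pairs), else: pairs

    pairs
    |> Stream.chunk_every(batch_size)
    |> Stream.map(&Enum.unzip/1)
  end
end
```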
@spec default_device() :: :cpu | :gpu
Returns the default device for tensor operations.
@spec evaluate(ExBurn.Model.t(), {Nx.Tensor.t(), Nx.Tensor.t()}) :: float()
Evaluates a model on a dataset.
Returns the average loss over the entire dataset.
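"Average loss over the entire dataset" does not say how per-batch results are combined. When the last batch is shorter, the mean of batch means differs from the per-sample mean. Here is a sketch of the per-sample version, assuming each batch reports its mean loss and its size. EvalSketch is a hypothetical name, and this is not necessarily how evaluate/2 computes its result.

```elixir
defmodule EvalSketch do
  # Each entry is {mean_loss_of_batch, batch_size}.
  # Weighting by batch size keeps a short final batch from skewing the result.
  def average_loss([]), do: raise(ArgumentError, "cannot average an empty dataset")

  def average_loss(batch_results) do
    {weighted_sum, count} =
      Enum.reduce(batch_results, {0.0, 0}, fn {loss, n}, {sum, total} ->
        {sum + loss * n, total + n}
      end)

    weighted_sum / count
  end
end
```

For batches [{1.0, 2}, {4.0, 1}], the per-sample average is 2.0, while naively averaging the two batch means would give 2.5.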
@spec fit(ExBurn.Model.t(), {Nx.Tensor.t(), Nx.Tensor.t()}, keyword()) :: ExBurn.Model.t()
Trains a model on the given dataset.
Options
- :epochs - Number of training epochs (default: 10)
- :batch_size - Mini-batch size (default: 32)
- :validation_data - Validation dataset as an {inputs, targets} tuple
- :callbacks - List of callback functions called after each epoch
- :verbose - Print training progress (default: true)
- :lr_schedule - Learning rate schedule (default: nil)
- :clip_norm - Maximum gradient norm for clipping (default: nil)
- :clip_value - Maximum absolute gradient value for clipping (default: nil)
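The two clipping options usually mean different things. Clipping by value clamps each gradient component independently. Clipping by norm rescales the whole gradient vector when its L2 norm exceeds the limit, which preserves its direction. The sketch below shows both on a flat list of floats, plus a step-decay schedule as one common shape for a learning rate schedule. ClipSketch and step_decay/4 are hypothetical. How ExBurn applies these options, and what value :lr_schedule accepts, is not documented here.

```elixir
defmodule ClipSketch do
  # Clamp every component into [-limit, limit]; direction may change.
  def clip_by_value(grads, limit) do
    Enum.map(grads, fn g -> g |> min(limit) |> max(-limit) end)
  end

  # If the L2 norm exceeds max_norm, scale all components by max_norm / norm.
  def clip_by_norm(grads, max_norm) do
    norm = grads |> Enum.map(&(&1 * &1)) |> Enum.sum() |> :math.sqrt()

    if norm > max_norm do
      scale = max_norm / norm
      Enum.map(grads, &(&1 * scale))
    else
      grads
    end
  end

  # Hypothetical step decay: multiply the rate by factor every step_size epochs.
  def step_decay(initial_lr, epoch, factor \\ 0.5, step_size \\ 10) do
    initial_lr * :math.pow(factor, div(epoch, step_size))
  end
end
```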
@spec free(ExBurn.Tensor.t()) :: :ok
Frees a Burn tensor's underlying GPU/CPU memory.
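Because free/1 releases memory explicitly, pairing each acquisition with a release that runs even on errors avoids leaks. The generic try/after helper below sketches that pattern. ResourceSketch is a hypothetical name. The documentation does not say whether Burn tensors are also released when garbage-collected.

```elixir
defmodule ResourceSketch do
  # Acquire a resource, pass it to fun, and always call release afterwards,
  # including when fun raises. Returns fun's result.
  def with_resource(acquire, release, fun) do
    resource = acquire.()

    try do
      fun.(resource)
    after
      release.(resource)
    end
  end
end
```

With the functions documented on this page, acquire could unwrap the {:ok, tensor} returned by from_nx/1, and release could be &Dala.ML.Burn.free/1.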
@spec from_nx(Nx.Tensor.t()) :: {:ok, ExBurn.Tensor.t()} | {:error, term()}
Converts an Nx tensor to a Burn tensor.
@spec gpu?() :: boolean()
Checks whether a GPU device is available for Burn operations.
Delegates to ExBurn's GPU detection, which checks ExCubecl availability.
@spec load(ExBurn.Model.t(), Path.t()) :: {:ok, ExBurn.Model.t()} | {:error, term()}
Loads model parameters from a file.
@spec ones([non_neg_integer()], atom()) :: ExBurn.Tensor.t()
Creates a tensor filled with ones via Burn.
@spec parameters(ExBurn.Model.t()) :: map()
Returns the current model parameters.
@spec predict(ExBurn.Model.t(), Nx.Tensor.t()) :: {:ok, Nx.Tensor.t()} | {:error, term()}
Runs a forward pass through the model.
Returns {:ok, output_tensor} or {:error, reason}.
@spec save(ExBurn.Model.t(), Path.t()) :: :ok | {:error, term()}
Saves the model parameters to a file.
@spec summary(ExBurn.Model.t()) :: String.t()
Returns a summary of the model architecture including parameter count.
@spec to_nx(ExBurn.Tensor.t()) :: {:ok, Nx.Tensor.t()} | {:error, term()}
Converts a Burn tensor to an Nx tensor.
@spec version() :: String.t()
Returns the current ExBurn version.
@spec zeros([non_neg_integer()], atom()) :: ExBurn.Tensor.t()
Creates a tensor filled with zeros via Burn.
For performance-critical paths, use the Burn bridge directly instead of going through Nx.