Dala.ML.Burn (dala v0.6.0)

Dala integration for the Burn deep learning framework.

ExBurn provides an Nx.Backend implementation that delegates tensor operations to Burn via Rust NIFs, enabling GPU-accelerated ML/DL on mobile and desktop.

Architecture

Axon model
   ↓
Nx.Defn graph
   ↓
ExBurn.Backend (Nx.Backend behaviour)
   ↓
ExBurn.Nif (Rustler NIF) + ExCubecl (GPU buffers, kernels, pipelines)
   ↓
Burn Autodiff<CubeCL> (Rust)
   ↓
CubeCL kernels
   ↓
Metal (iOS) / Vulkan (Android) / CUDA
   ↓
GPU

Quick Start

# Set ExBurn as the default Nx backend
Dala.ML.Burn.configure!()

# Create and manipulate tensors
t = Nx.tensor([1.0, 2.0, 3.0])
Nx.add(t, t) |> Nx.to_list()
# => [2.0, 4.0, 6.0]

# Define a model with Axon
model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(256, activation: :relu)
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(10)

# Compile for training
compiled = Dala.ML.Burn.compile(model,
  loss: :cross_entropy,
  optimizer: :adam,
  learning_rate: 0.001
)

# Train (train_x and train_y are your training inputs and targets as Nx tensors)
Dala.ML.Burn.fit(compiled, {train_x, train_y},
  epochs: 10,
  batch_size: 32
)

Platform GPU Backends

Platform   Backend   Status
iOS        Metal
Android    Vulkan
macOS      Metal
Linux      Vulkan
NVIDIA     CUDA      🔜 (coming soon)

Integration with Dala.ML

This module complements the existing Dala ML backends.

Use Dala.ML.available_backends/0 to see all available backends, and Dala.ML.Burn.available?/0 to check for Burn support specifically.

Training on Mobile — Caveats

Burn's Autodiff backend is memory-intensive. On iOS/Android with limited RAM:

  • Fine-tuning small models (< 10M parameters) is feasible on modern devices
  • Full training of large models is not recommended on mobile
  • Inference is the primary use case for mobile deployment
  • Minimum recommended: 4 GB RAM, A12+ chip (iOS) / Snapdragon 700+ (Android)
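
To see why only small models are practical to fine-tune on a phone, here is a back-of-envelope sketch. It assumes f32 parameters, one gradient buffer per parameter, and the optimizer state that Adam (two moment buffers) and RMSprop (one squared-gradient average) keep per parameter. It deliberately ignores activations and Burn's autodiff graph, which add more on top. The module name TrainingMemory is hypothetical and not part of Dala or ExBurn.

```elixir
defmodule TrainingMemory do
  # Hypothetical estimator: counts only parameter-shaped buffers.
  # Activations, the autodiff graph, and runtime overhead are NOT included.

  # Bytes per element for common float types.
  @bytes_per_elem %{f32: 4, f16: 2, bf16: 2, f64: 8}

  # Extra per-parameter optimizer buffers: Adam keeps m and v,
  # RMSprop keeps a running squared average, plain SGD keeps none.
  @optimizer_buffers %{adam: 2, rmsprop: 1, sgd: 0}

  @spec estimate_bytes(non_neg_integer(), atom(), atom()) :: non_neg_integer()
  def estimate_bytes(param_count, optimizer \\ :adam, type \\ :f32) do
    # Weights + gradients (2 buffers) + optimizer state.
    buffers = 2 + Map.fetch!(@optimizer_buffers, optimizer)
    param_count * buffers * Map.fetch!(@bytes_per_elem, type)
  end

  @spec inference_bytes(non_neg_integer(), atom()) :: non_neg_integer()
  def inference_bytes(param_count, type \\ :f32) do
    # Inference needs only the weights themselves.
    param_count * Map.fetch!(@bytes_per_elem, type)
  end
end
```

For a 10M-parameter model this gives 160,000,000 bytes (about 160 MB) of parameter-shaped buffers when training with Adam in f32, versus 40 MB of weights for inference, before any activations are counted.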

Summary

Functions

available?()
Checks whether ExBurn is available (loaded and functional).

available_backends()
Returns a list of available GPU backends on this system.

compile(model, opts \\ [])
Compiles an Axon model for training with the ExBurn backend.

compute_loss(model, pred, target)
Computes the loss between predictions and targets.

configure!()
Sets ExBurn as the default Nx backend.

configure!(opts \\ [])
Configures ExBurn for the current platform with Dala-specific defaults.

data_loader(arg, opts \\ [])
Creates a data loader that yields mini-batches from a dataset.

default_device()
Returns the default device for tensor operations.

evaluate(model, arg)
Evaluates a model on a dataset.

fit(model, arg, opts \\ [])
Trains a model on the given dataset.

free(bt)
Frees a Burn tensor's underlying GPU/CPU memory.

from_nx(tensor)
Converts an Nx tensor to a Burn tensor.

gpu?()
Checks whether a GPU device is available for Burn operations.

load(model, path)
Loads model parameters from a file.

ones(shape, type \\ :f32)
Creates a tensor filled with ones via Burn.

parameters(model)
Returns the current model parameters.

predict(model, input)
Runs a forward pass through the model.

save(model, path)
Saves the model parameters to a file.

summary(model)
Returns a summary of the model architecture including parameter count.

to_nx(bt)
Converts a Burn tensor to an Nx tensor.

version()
Returns the current ExBurn version.

zeros(shape, type \\ :f32)
Creates a tensor filled with zeros via Burn.

Functions

available?()

@spec available?() :: boolean()

Checks whether ExBurn is available (loaded and functional).

available_backends()

@spec available_backends() :: [atom()]

Returns a list of available GPU backends on this system.

compile(model, opts \\ [])

@spec compile(
  Axon.ModelState.t(),
  keyword()
) :: ExBurn.Model.t()

Compiles an Axon model for training with the ExBurn backend.

Options

  • :loss — Loss function: :cross_entropy, :mse, :binary_cross_entropy (default: :cross_entropy)
  • :optimizer — Optimizer: :adam, :sgd, :rmsprop (default: :adam)
  • :learning_rate — Learning rate (default: 0.001)
  • :device — Device: :cpu or :gpu (default: auto-detected)
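
The documented defaults can be made concrete with a small option-resolution sketch. It uses Elixir's Keyword.validate!/2 to fill in the defaults listed above and reject unknown keys. The module CompileOpts and the use of nil to mean "auto-detect the device" are assumptions for illustration; this is not how Dala.ML.Burn.compile/2 is implemented.

```elixir
defmodule CompileOpts do
  # Hypothetical helper mirroring the options and defaults documented for compile/2.
  @losses [:cross_entropy, :mse, :binary_cross_entropy]
  @optimizers [:adam, :sgd, :rmsprop]

  @spec resolve(keyword()) :: keyword()
  def resolve(opts) do
    # Keyword.validate!/2 fills missing keys with defaults and
    # raises ArgumentError if opts contains an unknown key.
    opts =
      Keyword.validate!(opts,
        loss: :cross_entropy,
        optimizer: :adam,
        learning_rate: 0.001,
        # nil stands in for "auto-detected" in this sketch.
        device: nil
      )

    loss = Keyword.fetch!(opts, :loss)
    optimizer = Keyword.fetch!(opts, :optimizer)

    unless loss in @losses do
      raise ArgumentError, "unsupported loss: #{inspect(loss)}"
    end

    unless optimizer in @optimizers do
      raise ArgumentError, "unsupported optimizer: #{inspect(optimizer)}"
    end

    opts
  end
end
```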

compute_loss(model, pred, target)

@spec compute_loss(ExBurn.Model.t(), Nx.Tensor.t(), Nx.Tensor.t()) ::
  {:ok, Nx.Tensor.t()} | {:error, term()}

Computes the loss between predictions and targets.
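
For reference, the three losses accepted by compile/2 can be written out on plain lists. This is a sketch of the standard formulas, not ExBurn's kernels, and it assumes :cross_entropy receives raw logits (as produced by the Quick Start model's final Axon.dense(10) without an activation) with one-hot targets, while :binary_cross_entropy receives probabilities. The module name LossSketch is hypothetical.

```elixir
defmodule LossSketch do
  # Mean squared error: mean of (pred - target)^2.
  def mse(pred, target) do
    pred
    |> Enum.zip(target)
    |> Enum.map(fn {p, t} -> (p - t) * (p - t) end)
    |> mean()
  end

  # Categorical cross-entropy from logits, using a numerically
  # stable log-softmax: log p_i = z_i - logsumexp(z).
  def cross_entropy(logits, one_hot) do
    top = Enum.max(logits)
    log_sum_exp = top + :math.log(Enum.sum(Enum.map(logits, &:math.exp(&1 - top))))

    logits
    |> Enum.zip(one_hot)
    |> Enum.map(fn {z, y} -> -y * (z - log_sum_exp) end)
    |> Enum.sum()
  end

  # Binary cross-entropy on probabilities, clamped away from 0 and 1
  # to avoid log(0), averaged over elements.
  def binary_cross_entropy(probs, target) do
    eps = 1.0e-7

    probs
    |> Enum.zip(target)
    |> Enum.map(fn {p, t} ->
      p = min(max(p, eps), 1.0 - eps)
      -(t * :math.log(p) + (1.0 - t) * :math.log(1.0 - p))
    end)
    |> mean()
  end

  defp mean(xs), do: Enum.sum(xs) / length(xs)
end
```

With equal logits over two classes, the cross-entropy is ln 2 regardless of which class is correct, which makes it a handy sanity check.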

configure!()

@spec configure!() :: :ok

Sets ExBurn as the default Nx backend.

After calling this, Nx operations are executed via Burn by default. For Dala apps, prefer Dala.ML.Burn.configure!/1, which also handles platform-specific GPU setup.

configure!(opts \\ [])

@spec configure!(keyword()) :: :ok

Configures ExBurn for the current platform with Dala-specific defaults.

Options

  • :device — Override device (:cpu or :gpu). Auto-detected by default.
  • :backend — Override GPU backend (:metal, :vulkan, :cuda). Auto-detected.
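
One way to picture the auto-detection is as a lookup from platform to the backends listed in the Platform GPU Backends table, with CUDA left out because it is marked as coming soon. This is a hypothetical sketch (the module BackendGuess is not part of Dala); configure!/1's actual detection logic is not documented here. It also shows a limit of host detection: Erlang's :os.type/0 reports the kernel, so an iOS device typically looks like {:unix, :darwin} and an Android device like {:unix, :linux}, and a mobile app would need another signal to tell them apart from desktops.

```elixir
defmodule BackendGuess do
  # Platform -> default GPU backend, taken from the platform table.
  @defaults %{ios: :metal, macos: :metal, android: :vulkan, linux: :vulkan}

  @spec backend_for(atom()) :: {:ok, atom()} | :error
  def backend_for(platform), do: Map.fetch(@defaults, platform)

  # Desktop-only guess based on the kernel reported by :os.type/0.
  @spec desktop_platform() :: :macos | :linux | :unknown
  def desktop_platform do
    case :os.type() do
      {:unix, :darwin} -> :macos
      {:unix, :linux} -> :linux
      _other -> :unknown
    end
  end
end
```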

data_loader(arg, opts \\ [])

@spec data_loader({Nx.Tensor.t(), Nx.Tensor.t()}, keyword()) :: Enumerable.t()

Creates a data loader that yields mini-batches from a dataset.
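
Conceptually, a mini-batch loader pairs each input with its target and groups the pairs into chunks of batch_size. The sketch below shows that idea lazily over plain lists using Stream; it is a hypothetical stand-in (MiniBatch is not part of Dala), and whether data_loader/2 shuffles or drops a short final batch is not documented, so this sketch does neither.

```elixir
defmodule MiniBatch do
  # Lazily yields {inputs_batch, targets_batch} tuples.
  # The final batch may be shorter than batch_size.
  @spec stream([term()], [term()], pos_integer()) :: Enumerable.t()
  def stream(inputs, targets, batch_size) when batch_size > 0 do
    inputs
    |> Stream.zip(targets)
    |> Stream.chunk_every(batch_size)
    |> Stream.map(&Enum.unzip/1)
  end
end
```

With 1,000 samples and batch_size: 32, this yields 31 full batches followed by one batch of 8.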

default_device()

@spec default_device() :: :cpu | :gpu

Returns the default device for tensor operations.

evaluate(model, arg)

@spec evaluate(
  ExBurn.Model.t(),
  {Nx.Tensor.t(), Nx.Tensor.t()}
) :: float()

Evaluates a model on a dataset.

Returns the average loss over the entire dataset.
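
"Average loss over the entire dataset" is subtle when the last mini-batch is smaller than the rest: averaging the per-batch means overweights that short batch. The sketch below computes the per-sample average by weighting each batch mean by its size. Which of the two ExBurn computes is not stated here; DatasetLoss is a hypothetical helper.

```elixir
defmodule DatasetLoss do
  # Takes {batch_mean_loss, batch_size} pairs and returns the
  # per-sample mean over the whole dataset.
  @spec average([{float(), pos_integer()}]) :: float()
  def average(batches) do
    {weighted_sum, count} =
      Enum.reduce(batches, {0.0, 0}, fn {mean_loss, size}, {sum, n} ->
        {sum + mean_loss * size, n + size}
      end)

    weighted_sum / count
  end
end
```

For batches [{1.0, 2}, {4.0, 1}], the per-sample mean is (2.0 + 4.0) / 3 = 2.0, whereas naively averaging the two batch means would give 2.5.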

fit(model, arg, opts \\ [])

Trains a model on the given dataset.

Options

  • :epochs — Number of training epochs (default: 10)
  • :batch_size — Mini-batch size (default: 32)
  • :validation_data — Validation dataset as {inputs, targets} tuple
  • :callbacks — List of callback functions called after each epoch
  • :verbose — Print training progress (default: true)
  • :lr_schedule — Learning rate schedule (default: nil)
  • :clip_norm — Max gradient norm for clipping (default: nil)
  • :clip_value — Max absolute gradient value for clipping (default: nil)
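
The two clipping options correspond to two standard techniques, sketched below on a flat list of gradient values. Clipping by norm rescales the whole gradient so its L2 norm is at most :clip_norm, preserving direction. Clipping by value clamps each component into [-clip_value, clip_value], which can change direction. How fit/3 applies them (per tensor or across all parameters, and in which order when both are set) is not documented, so GradClip is a hypothetical illustration only.

```elixir
defmodule GradClip do
  # Rescale so the L2 norm is at most max_norm; direction is preserved.
  @spec by_norm([float()], float()) :: [float()]
  def by_norm(grads, max_norm) do
    norm = grads |> Enum.map(&(&1 * &1)) |> Enum.sum() |> :math.sqrt()

    if norm > max_norm do
      scale = max_norm / norm
      Enum.map(grads, &(&1 * scale))
    else
      grads
    end
  end

  # Clamp each component into [-max_value, max_value].
  @spec by_value([float()], float()) :: [float()]
  def by_value(grads, max_value) do
    Enum.map(grads, fn g -> g |> max(-max_value) |> min(max_value) end)
  end
end
```

For example, the gradient [3.0, 4.0] has norm 5.0, so by_norm with a limit of 1.0 scales it to roughly [0.6, 0.8], while by_value with a limit of 3.5 turns [3.0, -4.0] into [3.0, -3.5].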

free(bt)

@spec free(ExBurn.Tensor.t()) :: :ok

Frees a Burn tensor's underlying GPU/CPU memory.

from_nx(tensor)

@spec from_nx(Nx.Tensor.t()) :: {:ok, ExBurn.Tensor.t()} | {:error, term()}

Converts an Nx tensor to a Burn tensor.

gpu?()

@spec gpu?() :: boolean()

Checks whether a GPU device is available for Burn operations.

Delegates to ExBurn's GPU detection, which checks ExCubecl availability.

load(model, path)

@spec load(ExBurn.Model.t(), Path.t()) :: {:ok, ExBurn.Model.t()} | {:error, term()}

Loads model parameters from a file.

ones(shape, type \\ :f32)

@spec ones([non_neg_integer()], atom()) :: ExBurn.Tensor.t()

Creates a tensor filled with ones via Burn.

parameters(model)

@spec parameters(ExBurn.Model.t()) :: map()

Returns the current model parameters.

predict(model, input)

@spec predict(ExBurn.Model.t(), Nx.Tensor.t()) ::
  {:ok, Nx.Tensor.t()} | {:error, term()}

Runs a forward pass through the model.

Returns {:ok, output_tensor} or {:error, reason}.

save(model, path)

@spec save(ExBurn.Model.t(), Path.t()) :: :ok | {:error, term()}

Saves the model parameters to a file.

summary(model)

@spec summary(ExBurn.Model.t()) :: String.t()

Returns a summary of the model architecture including parameter count.
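
As a worked example of a parameter count, consider the Quick Start model. A dense layer mapping n inputs to m units has n × m weights plus m biases, and dropout has no trainable parameters. The helper ParamCount below is hypothetical; whether summary/1 reports exactly this figure depends on how it counts, which is not specified here.

```elixir
defmodule ParamCount do
  # Weights (in_features * out_features) plus one bias per output unit.
  @spec dense(pos_integer(), pos_integer()) :: pos_integer()
  def dense(in_features, out_features), do: in_features * out_features + out_features

  # Quick Start model: dense(256) over 784 inputs, dropout (no params), dense(10).
  @spec quick_start_model() :: pos_integer()
  def quick_start_model do
    dense(784, 256) + dense(256, 10)
  end
end
```

That is 200,960 parameters for the first layer and 2,570 for the second, or 203,530 in total, comfortably inside the "< 10M parameters" range suggested for on-device fine-tuning.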

to_nx(bt)

@spec to_nx(ExBurn.Tensor.t()) :: {:ok, Nx.Tensor.t()} | {:error, term()}

Converts a Burn tensor to an Nx tensor.

version()

@spec version() :: String.t()

Returns the current ExBurn version.

zeros(shape, type \\ :f32)

@spec zeros([non_neg_integer()], atom()) :: ExBurn.Tensor.t()

Creates a tensor filled with zeros via Burn.

For performance-critical paths, use the Burn bridge directly instead of going through Nx.