Training callbacks and utilities for ExBurn models in Dala.
Provides Dala-specific training callbacks and helpers that integrate with the ExBurn training loop.
Callbacks
- Dala.ML.Burn.Training.LoggingCallback - logs metrics after each epoch
- Dala.ML.Burn.Training.EarlyStoppingCallback - stops training when validation loss plateaus
- Dala.ML.Burn.Training.CheckpointCallback - saves model checkpoints
- Dala.ML.Burn.Training.WarmupCallback - learning rate warmup
- Dala.ML.Burn.Training.ReduceLROnPlateauCallback - reduces the learning rate on plateau
- Dala.ML.Burn.Training.HistoryCallback - records all metrics for analysis
Usage
```elixir
model = Dala.ML.Burn.compile(axon_model, loss: :cross_entropy, optimizer: :adam)

callbacks = [
  &Dala.ML.Burn.Training.LoggingCallback.log/1,
  Dala.ML.Burn.Training.EarlyStoppingCallback.wait(5, 1.0e-4),
  Dala.ML.Burn.Training.CheckpointCallback.every(5, "checkpoints/"),
  Dala.ML.Burn.Training.WarmupCallback.linear(3, 1.0e-5, 0.001),
  Dala.ML.Burn.Training.ReduceLROnPlateauCallback.new(patience: 3, factor: 0.5)
]

trained =
  Dala.ML.Burn.fit(model, {train_x, train_y},
    epochs: 50,
    batch_size: 32,
    validation_data: {val_x, val_y},
    callbacks: callbacks,
    lr_schedule: {:cosine, 0.001, 1.0e-5},
    clip_norm: 1.0,
    accuracy: true
  )
```
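The @spec entries further down type each callback as (map() -> map()). Assuming that contract, a training loop could apply a callback list once per epoch by threading a metrics map through it, as in this minimal sketch. CallbackChainSketch and the toy callbacks are hypothetical illustrations, not part of Dala or ExBurn, and the real loop may apply callbacks differently.

```elixir
defmodule CallbackChainSketch do
  # Hypothetical helper (not Dala/ExBurn API): applies each callback in
  # list order, passing the map returned by one callback to the next.
  @spec run_callbacks([(map() -> map())], map()) :: map()
  def run_callbacks(callbacks, metrics) do
    Enum.reduce(callbacks, metrics, fn callback, acc -> callback.(acc) end)
  end
end

# Two toy callbacks standing in for real ones.
tag_epoch = fn metrics -> Map.put(metrics, :seen, true) end
halve_lr = fn metrics -> Map.update!(metrics, :lr, &(&1 / 2)) end

CallbackChainSketch.run_callbacks([tag_epoch, halve_lr], %{epoch: 1, loss: 0.5, lr: 0.001})
```

Because each callback returns the map it received (possibly updated), order matters under this assumption: a callback that changes :lr affects every callback after it in the list.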
Summary
Functions
- Creates a callback that stores training history in an Agent.
- Creates a checkpoint callback that saves model state at intervals.
- Creates an early stopping callback.
- Convenience function to train with accuracy tracking.
- Convenience function to train with a progress-reporting callback.
- Retrieves the full training history from a history callback agent.
- Creates a history callback that records all training metrics.
- Creates a logging callback that prints training metrics after each epoch.
- Creates a callback that reduces the learning rate on plateau.
- Creates a Dala-specific callback that reports training progress to a given process, to be handled in handle_info/2.
- Creates a standard set of callbacks for common training scenarios.
- Creates a learning rate warmup callback.
Functions
Creates a callback that stores training history in an Agent.
Returns {callback_fn, agent_pid}. After training, pass agent_pid to
get_history/1 to retrieve the full training history.
Usage
```elixir
{callback, agent} = Dala.ML.Burn.Training.agent_history_callback()
callbacks = [callback]

# After training:
history = Dala.ML.Burn.Training.get_history(agent)
# [%{epoch: 1, loss: 0.5, val_loss: 0.4}, ...]
```
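A minimal sketch of how an Agent-backed history callback can work, using only Elixir's standard Agent module. HistorySketch is hypothetical; it assumes the callback receives a per-epoch metrics map like the one in the comment above and must return that map unchanged.

```elixir
defmodule HistorySketch do
  # Hypothetical sketch: an Agent holds a list of per-epoch metric maps.
  # Entries are prepended (cheap) and reversed once on read.
  def agent_history_callback do
    {:ok, agent} = Agent.start_link(fn -> [] end)

    callback = fn metrics ->
      Agent.update(agent, fn history -> [metrics | history] end)
      # Return the metrics unchanged so later callbacks still see them.
      metrics
    end

    {callback, agent}
  end

  # Returns the history in epoch order (oldest first).
  def get_history(agent), do: agent |> Agent.get(& &1) |> Enum.reverse()
end

{callback, agent} = HistorySketch.agent_history_callback()
callback.(%{epoch: 1, loss: 0.5, val_loss: 0.4})
callback.(%{epoch: 2, loss: 0.3, val_loss: 0.35})
HistorySketch.get_history(agent)
```

Keeping the history in a separate process means it survives across epochs even though each callback invocation is a plain, stateless function call.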
@spec checkpoint_callback(pos_integer(), Path.t()) :: (map() -> map())
Creates a checkpoint callback that saves model state at intervals.
Parameters
- interval - save a checkpoint every interval epochs
- dir - directory to save checkpoints in
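The interval logic can be sketched as follows. CheckpointSketch, the save_fun argument, the :epoch key, and the epoch_N.ckpt file naming are all assumptions for illustration; the real callback saves the model itself, in whatever format ExBurn uses.

```elixir
defmodule CheckpointSketch do
  # Hypothetical sketch. `save_fun` stands in for whatever actually
  # persists the model; it receives the metrics map and a target path.
  # Assumes the metrics map carries a 1-based :epoch key.
  def checkpoint_callback(interval, dir, save_fun)
      when is_integer(interval) and interval > 0 do
    fn %{epoch: epoch} = metrics ->
      # Only every `interval`-th epoch triggers a save.
      if rem(epoch, interval) == 0 do
        save_fun.(metrics, Path.join(dir, "epoch_#{epoch}.ckpt"))
      end

      metrics
    end
  end
end

parent = self()

callback =
  CheckpointSketch.checkpoint_callback(5, "checkpoints", fn _metrics, path ->
    send(parent, {:saved, path})
  end)

Enum.each(1..10, fn epoch -> callback.(%{epoch: epoch, loss: 1.0 / epoch}) end)
# The mailbox now holds {:saved, "checkpoints/epoch_5.ckpt"}
# followed by {:saved, "checkpoints/epoch_10.ckpt"}.
```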
@spec early_stopping_callback(pos_integer(), float()) :: (map() -> map())
Creates an early stopping callback.
Parameters
- patience - number of epochs to wait for an improvement before stopping
- min_delta - minimum improvement required to reset the patience counter
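A minimal sketch of the patience/min_delta rule as plain data, assuming an epoch counts as an improvement only when validation loss falls below the best value so far by more than min_delta. EarlyStoppingSketch and its tracker map are hypothetical, not the Dala implementation.

```elixir
defmodule EarlyStoppingSketch do
  # Hypothetical tracker: best validation loss seen so far and the number
  # of consecutive epochs without a sufficient improvement.
  def new(patience, min_delta),
    do: %{patience: patience, min_delta: min_delta, best: nil, wait: 0}

  # Returns {:continue | :stop, tracker}. The first epoch sets the baseline.
  def step(%{best: nil} = t, val_loss), do: {:continue, %{t | best: val_loss, wait: 0}}

  def step(t, val_loss) do
    if val_loss < t.best - t.min_delta do
      # Improved by more than min_delta: reset the patience counter.
      {:continue, %{t | best: val_loss, wait: 0}}
    else
      wait = t.wait + 1
      decision = if wait >= t.patience, do: :stop, else: :continue
      {decision, %{t | wait: wait}}
    end
  end
end

t = EarlyStoppingSketch.new(2, 0.01)
{:continue, t} = EarlyStoppingSketch.step(t, 1.0)
{:continue, t} = EarlyStoppingSketch.step(t, 0.9)    # improved by 0.1 > 0.01
{:continue, t} = EarlyStoppingSketch.step(t, 0.895)  # only 0.005 better: wait = 1
{:stop, _t} = EarlyStoppingSketch.step(t, 0.899)     # wait = 2 reaches patience
```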
@spec fit_with_accuracy(ExBurn.Model.t(), {Nx.Tensor.t(), Nx.Tensor.t()}, keyword()) :: ExBurn.Model.t()
Convenience function to train with accuracy tracking.
Shorthand for Dala.ML.Burn.fit/3 with accuracy: true.
@spec fit_with_progress(ExBurn.Model.t(), {Nx.Tensor.t(), Nx.Tensor.t()}, keyword()) :: {ExBurn.Model.t(), [map()]}
Convenience function to train with a progress-reporting callback.
Sets up a screen callback automatically and returns a tuple of the trained model and a list of maps.
Retrieves the full training history from a history callback agent.
Creates a history callback that records all training metrics.
Returns a callback function. Access history via
Dala.ML.Burn.Training.get_history(pid).
Creates a logging callback that prints training metrics after each epoch.
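A sketch of what a per-epoch logging callback might print. LoggingSketch and its output format are hypothetical; the real callback's format may differ. It only relies on Erlang's :erlang.float_to_binary/2 with the decimals option.

```elixir
defmodule LoggingSketch do
  # Hypothetical sketch: builds one line from :epoch plus whichever of
  # :loss / :val_loss are present, prints it, and returns the metrics
  # unchanged so later callbacks still receive them.
  def format(%{epoch: epoch} = metrics) do
    parts =
      for key <- [:loss, :val_loss], Map.has_key?(metrics, key) do
        # Dividing by 1 turns an integer loss into a float for float_to_binary/2.
        "#{key}: #{:erlang.float_to_binary(metrics[key] / 1, decimals: 4)}"
      end

    Enum.join(["Epoch #{epoch}" | parts], " | ")
  end

  def log(metrics) do
    IO.puts(format(metrics))
    metrics
  end
end

# Prints: Epoch 1 | loss: 0.5000 | val_loss: 0.4000
LoggingSketch.log(%{epoch: 1, loss: 0.5, val_loss: 0.4})
```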
Creates a reduce-on-plateau callback.
Reduces learning rate when validation loss stops improving.
Options
- :patience - epochs to wait before reducing (default: 5)
- :factor - multiplicative factor for the learning rate reduction (default: 0.5)
- :min_lr - minimum learning rate floor (default: 1.0e-6)
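The reduce-on-plateau rule can be sketched as a pure state update. ReduceLROnPlateauSketch is hypothetical; its defaults mirror the options above, and it assumes that any decrease in validation loss counts as an improvement and that the wait counter restarts after each reduction.

```elixir
defmodule ReduceLROnPlateauSketch do
  # Hypothetical sketch; option names and defaults mirror the list above.
  def new(lr, opts \\ []) do
    %{
      lr: lr,
      patience: Keyword.get(opts, :patience, 5),
      factor: Keyword.get(opts, :factor, 0.5),
      min_lr: Keyword.get(opts, :min_lr, 1.0e-6),
      best: nil,
      wait: 0
    }
  end

  # First observation just sets the baseline.
  def step(%{best: nil} = state, val_loss), do: %{state | best: val_loss}

  # Any decrease counts as an improvement and resets the wait counter.
  def step(%{best: best} = state, val_loss) when val_loss < best,
    do: %{state | best: val_loss, wait: 0}

  # No improvement: count it, and cut the LR once patience is exhausted.
  def step(state, _val_loss) do
    wait = state.wait + 1

    if wait >= state.patience do
      %{state | lr: max(state.lr * state.factor, state.min_lr), wait: 0}
    else
      %{state | wait: wait}
    end
  end
end

state =
  Enum.reduce([1.0, 1.1, 1.2], ReduceLROnPlateauSketch.new(0.001, patience: 2), fn loss, s ->
    ReduceLROnPlateauSketch.step(s, loss)
  end)

# Two non-improving epochs after the baseline halve the LR to about 0.0005.
state.lr
```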
Creates a Dala-specific callback that reports training progress
to a given process, typically a screen.
The callback sends {:training_progress, epoch, loss, val_loss} to
that process, which can handle it in handle_info/2.
Usage
```elixir
# In your screen module:
callbacks = [
  Dala.ML.Burn.Training.screen_callback(self())
]

# Handle progress updates:
def handle_info({:training_progress, epoch, loss, val_loss}, socket) do
  {:noreply, assign(socket, epoch: epoch, loss: loss, val_loss: val_loss)}
end
```
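A minimal sketch of such a callback using only send/2. ScreenCallbackSketch is hypothetical, and sending nil for val_loss when the metrics map has no :val_loss key (for example, without validation data) is an assumption.

```elixir
defmodule ScreenCallbackSketch do
  # Hypothetical sketch: captures the target pid in a closure and sends
  # the documented {:training_progress, epoch, loss, val_loss} tuple.
  def screen_callback(pid) when is_pid(pid) do
    fn metrics ->
      # Map.get/2 returns nil for a missing :val_loss (assumption).
      send(pid, {:training_progress, metrics[:epoch], metrics[:loss], Map.get(metrics, :val_loss)})
      metrics
    end
  end
end

callback = ScreenCallbackSketch.screen_callback(self())
callback.(%{epoch: 1, loss: 0.5, val_loss: 0.4})

receive do
  {:training_progress, epoch, loss, val_loss} ->
    IO.puts("epoch #{epoch}: loss=#{loss} val_loss=#{val_loss}")
after
  1_000 -> IO.puts("no progress message")
end
```

send/2 delivers the tuple to whichever pid was captured, so the pid should be taken with self() in the screen process before training starts. If training runs synchronously inside that same process, the messages queue up and are only handled after training returns; running training in a separate process (for example with Task.start/1) lets updates arrive while it runs.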
Creates a standard set of callbacks for common training scenarios.
Includes logging, early stopping, and checkpointing, plus learning rate warmup when :warmup_epochs is greater than 0.
Options
- :early_stopping_patience - patience for early stopping (default: 5)
- :checkpoint_interval - checkpoint every N epochs (default: 10)
- :checkpoint_dir - directory for checkpoints (default: "checkpoints")
- :warmup_epochs - number of warmup epochs (default: 0, which disables warmup)
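The option handling can be sketched as follows. StandardCallbacksSketch and standard_callbacks/1 are hypothetical names, and each callback is represented by a descriptive tuple rather than a real callback function; only the option names and defaults come from the list above. Placing the warmup entry last is also an assumption.

```elixir
defmodule StandardCallbacksSketch do
  # Hypothetical sketch of option handling only: tuples describe which
  # callbacks would be built, instead of calling Dala's constructors.
  def standard_callbacks(opts \\ []) do
    patience = Keyword.get(opts, :early_stopping_patience, 5)
    interval = Keyword.get(opts, :checkpoint_interval, 10)
    dir = Keyword.get(opts, :checkpoint_dir, "checkpoints")
    warmup = Keyword.get(opts, :warmup_epochs, 0)

    base = [:logging, {:early_stopping, patience}, {:checkpoint, interval, dir}]

    # warmup_epochs: 0 (the default) leaves warmup out entirely.
    if warmup > 0, do: base ++ [{:warmup, warmup}], else: base
  end
end

StandardCallbacksSketch.standard_callbacks(checkpoint_interval: 5, warmup_epochs: 3)
```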
@spec warmup_callback(pos_integer(), float(), float()) :: (map() -> map())
Creates a learning rate warmup callback.
Gradually increases the learning rate from start_lr to target_lr
over warmup_epochs epochs. Helps stabilize early training.
Parameters
- warmup_epochs - number of warmup epochs
- start_lr - initial learning rate
- target_lr - learning rate reached at the end of warmup
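A sketch of the schedule, assuming a linear ramp (as the WarmupCallback.linear/3 call in the usage example suggests) and 1-based epochs, where epoch 1 uses start_lr and every epoch after warmup_epochs uses target_lr. WarmupSketch.lr_at/4 is hypothetical.

```elixir
defmodule WarmupSketch do
  # Hypothetical sketch: linear warmup over epochs 1..warmup_epochs.
  # Epoch 1 -> start_lr; epochs after warmup_epochs -> target_lr.
  def lr_at(epoch, warmup_epochs, start_lr, target_lr)
      when is_integer(epoch) and epoch >= 1 and is_integer(warmup_epochs) and
             warmup_epochs > 0 do
    if epoch > warmup_epochs do
      target_lr
    else
      start_lr + (target_lr - start_lr) * (epoch - 1) / warmup_epochs
    end
  end
end

# Learning rates for epochs 1..4 with 3 warmup epochs.
Enum.map(1..4, &WarmupSketch.lr_at(&1, 3, 1.0e-5, 0.001))
```

With warmup_epochs set to 3, epochs 1 to 3 step up in equal increments of (target_lr - start_lr) / 3, and epoch 4 onward uses target_lr.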