Training callbacks and utilities for ExBurn models in Dala.
Provides Dala-specific training callbacks and helpers that integrate with the ExBurn training loop.
Callbacks
Dala.ML.Burn.Training.LoggingCallback: Logs metrics after each epoch
Dala.ML.Burn.Training.EarlyStoppingCallback: Stops training when validation loss plateaus
Dala.ML.Burn.Training.CheckpointCallback: Saves model checkpoints
Usage
model = Dala.ML.Burn.compile(axon_model, loss: :cross_entropy, optimizer: :adam)
callbacks = [
  &Dala.ML.Burn.Training.LoggingCallback.log/1,
  Dala.ML.Burn.Training.EarlyStoppingCallback.wait(5, 1.0e-4),
  Dala.ML.Burn.Training.CheckpointCallback.every(5, "checkpoints/")
]
trained =
  Dala.ML.Burn.fit(model, {train_x, train_y},
    epochs: 50,
    batch_size: 32,
    validation_data: {val_x, val_y},
    callbacks: callbacks,
    lr_schedule: {:cosine, 0.001, 1.0e-5},
    clip_norm: 1.0
  )
Functions
@spec checkpoint_callback(pos_integer(), Path.t()) :: (map() -> map())
Creates a checkpoint callback that saves model state at intervals.
Parameters
interval: Save a checkpoint every N epochs
dir: Directory to save checkpoints in
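The interval rule can be illustrated with a minimal sketch. This is not the Dala implementation: `CheckpointSketch`, its functions, the 1-based epoch numbering, and the file name pattern are all assumptions made for illustration.

```elixir
defmodule CheckpointSketch do
  # Hypothetical helper, not part of Dala or ExBurn.
  # Assumes epochs are numbered from 1.

  # True when a checkpoint would be written at this epoch.
  def save?(epoch, interval) when is_integer(interval) and interval > 0 do
    rem(epoch, interval) == 0
  end

  # Builds a checkpoint path inside `dir`.
  # The "epoch_NNNN.ckpt" naming is an assumption, not the library's format.
  def path(dir, epoch) do
    name = "epoch_" <> String.pad_leading(Integer.to_string(epoch), 4, "0") <> ".ckpt"
    Path.join(dir, name)
  end
end
```

With an interval of 5, `Enum.filter(1..12, &CheckpointSketch.save?(&1, 5))` selects epochs 5 and 10.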
@spec early_stopping_callback(pos_integer(), float()) :: (map() -> map())
Creates an early stopping callback that stops training when validation loss plateaus.
Parameters
patience: Number of epochs to wait for improvement before stopping
min_delta: Minimum improvement to reset the patience counter
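How patience and min_delta interact can be sketched as follows. This is only an illustration of the usual early-stopping rule, not the Dala code: `EarlyStoppingSketch` and `stop_epoch/3` are hypothetical names, and the sketch assumes that "improvement" means the validation loss dropped by more than min_delta below the best value seen so far.

```elixir
defmodule EarlyStoppingSketch do
  # Hypothetical helper, not part of Dala or ExBurn.
  # Returns the 1-based epoch at which training would stop, or :completed
  # if the patience counter never reaches `patience`.
  def stop_epoch(val_losses, patience, min_delta) do
    val_losses
    |> Enum.with_index(1)
    |> Enum.reduce_while({:infinity, 0}, fn {loss, epoch}, {best, waited} ->
      if improved?(loss, best, min_delta) do
        # Improvement: remember the new best loss and reset the counter.
        {:cont, {loss, 0}}
      else
        waited = waited + 1

        if waited >= patience do
          {:halt, {:stopped, epoch}}
        else
          {:cont, {best, waited}}
        end
      end
    end)
    |> case do
      {:stopped, epoch} -> epoch
      _ -> :completed
    end
  end

  # The first epoch always counts as an improvement.
  defp improved?(_loss, :infinity, _min_delta), do: true
  defp improved?(loss, best, min_delta), do: best - loss > min_delta
end
```

For example, `EarlyStoppingSketch.stop_epoch([1.0, 0.8, 0.79999, 0.79998, 0.5], 2, 1.0e-4)` returns 4: epochs 3 and 4 improve on the best loss of 0.8 by less than min_delta, so the patience of 2 runs out before the drop at epoch 5 is seen.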
@spec fit_with_progress(ExBurn.Model.t(), {Nx.Tensor.t(), Nx.Tensor.t()}, keyword()) :: {ExBurn.Model.t(), [map()]}
Convenience function to train with a progress-reporting callback.
Automatically sets up the screen callback. Per the spec, it returns a tuple of the trained model and a list of maps (presumably the per-epoch training history).
get_history(agent)
Retrieves the full training history from a history callback agent.
history_callback()
Creates a callback that stores training history in an Agent.
Returns {callback_fn, agent_pid}. Call get_history/1 on the agent
to retrieve the full training history.
Usage
{callback, agent} = Dala.ML.Burn.Training.history_callback()
callbacks = [callback]
# After training:
history = Dala.ML.Burn.Training.get_history(agent)
# [%{epoch: 1, loss: 0.5, val_loss: 0.4}, ...]
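The Agent-backed pattern described above can be sketched in a few lines. This is a stand-in under stated assumptions, not the Dala implementation: `HistorySketch` is a hypothetical module, the callback is assumed to take and return a state map (matching the `(map() -> map())` shape of the other callbacks), and the `:epoch`, `:loss`, and `:val_loss` keys are taken from the example history shown above.

```elixir
defmodule HistorySketch do
  # Hypothetical stand-in for history_callback/get_history; not the Dala code.

  # Starts an Agent holding a list of entries and returns {callback_fn, agent_pid}.
  def history_callback do
    {:ok, agent} = Agent.start_link(fn -> [] end)

    callback = fn state ->
      # Assumed state keys; anything else in the state map is ignored.
      entry = Map.take(state, [:epoch, :loss, :val_loss])
      # Prepend for O(1) updates; get_history/1 restores epoch order.
      Agent.update(agent, fn history -> [entry | history] end)
      # Return the state unchanged so later callbacks see the same map.
      state
    end

    {callback, agent}
  end

  # Returns the recorded entries, oldest first.
  def get_history(agent), do: Agent.get(agent, &Enum.reverse/1)
end
```

Returning the agent pid alongside the closure lets the caller read the history after training without the callback itself needing any return channel.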
Creates a logging callback that prints training metrics after each epoch.
Creates a Dala-specific callback that reports training progress
to a Dala screen via handle_info.
The callback sends {:training_progress, epoch, loss, val_loss} to
the given process (typically the screen process, passed as self()),
which can handle the message in handle_info/2.
Usage
# In your screen module:
callbacks = [
  Dala.ML.Burn.Training.screen_callback(self())
]
# Handle progress updates:
def handle_info({:training_progress, epoch, loss, val_loss}, socket) do
  {:noreply, assign(socket, epoch: epoch, loss: loss, val_loss: val_loss)}
end
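The message-passing side of this callback can be sketched as a closure over the target pid. This is an illustration only, not the Dala implementation: `ScreenSketch` is a hypothetical module, and the state map keys (`:epoch`, `:loss`, `:val_loss`) are assumptions.

```elixir
defmodule ScreenSketch do
  # Hypothetical stand-in for screen_callback; not the Dala code.
  # Returns a callback that reports progress to `pid` as a plain message.
  def screen_callback(pid) when is_pid(pid) do
    fn %{epoch: epoch, loss: loss} = state ->
      # :val_loss is assumed optional; nil is sent when it is absent.
      send(pid, {:training_progress, epoch, loss, Map.get(state, :val_loss)})
      # Return the state unchanged so other callbacks can run after this one.
      state
    end
  end
end
```

Because the pid is captured when the callback is created, the message reaches that process even if training itself runs in a different process.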