Dala.ML.Burn.Training (dala v0.6.1)


Training callbacks and utilities for ExBurn models in Dala.

Provides Dala-specific training callbacks and helpers that integrate with the ExBurn training loop.

Callbacks

  • Dala.ML.Burn.Training.LoggingCallback — Logs metrics after each epoch
  • Dala.ML.Burn.Training.EarlyStoppingCallback — Stops training when val loss plateaus
  • Dala.ML.Burn.Training.CheckpointCallback — Saves model checkpoints
  • Dala.ML.Burn.Training.WarmupCallback — Learning rate warmup
  • Dala.ML.Burn.Training.ReduceLROnPlateauCallback — Reduce LR on plateau
  • Dala.ML.Burn.Training.HistoryCallback — Records all metrics for analysis

Usage

model = Dala.ML.Burn.compile(axon_model, loss: :cross_entropy, optimizer: :adam)

callbacks = [
  &Dala.ML.Burn.Training.LoggingCallback.log/1,
  Dala.ML.Burn.Training.EarlyStoppingCallback.wait(5, 1.0e-4),
  Dala.ML.Burn.Training.CheckpointCallback.every(5, "checkpoints/"),
  Dala.ML.Burn.Training.WarmupCallback.linear(3, 1.0e-5, 0.001),
  Dala.ML.Burn.Training.ReduceLROnPlateauCallback.new(patience: 3, factor: 0.5)
]

trained = Dala.ML.Burn.fit(model, {train_x, train_y},
  epochs: 50,
  batch_size: 32,
  validation_data: {val_x, val_y},
  callbacks: callbacks,
  lr_schedule: {:cosine, 0.001, 1.0e-5},
  clip_norm: 1.0,
  accuracy: true
)
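Every callback has the type (map() -> map()). It receives a map describing the current training state and returns a map. The sketch below shows one plausible way a loop could pass that map through a callback list after each epoch. It is not Dala's or ExBurn's implementation. The :epoch and :loss keys, and the :stop flag used to end training early, are assumptions for illustration.

```elixir
defmodule CallbackPipelineSketch do
  # Hypothetical illustration: run every callback on the epoch state, in order.
  # Each callback receives the map returned by the previous one.
  def run_callbacks(state, callbacks) when is_map(state) and is_list(callbacks) do
    Enum.reduce(callbacks, state, fn callback, acc -> callback.(acc) end)
  end

  # Hypothetical epoch loop: halts as soon as any callback sets :stop to true.
  def train(epochs, callbacks, loss_fn) do
    Enum.reduce_while(1..epochs, %{}, fn epoch, state ->
      state = Map.merge(state, %{epoch: epoch, loss: loss_fn.(epoch)})
      state = run_callbacks(state, callbacks)

      if Map.get(state, :stop, false), do: {:halt, state}, else: {:cont, state}
    end)
  end
end

# Usage: a callback that requests a stop once epoch 3 is reached.
stop_after_3 = fn state -> if state.epoch >= 3, do: Map.put(state, :stop, true), else: state end
final = CallbackPipelineSketch.train(10, [stop_after_3], fn epoch -> 1.0 / epoch end)
```

Threading one map through all callbacks keeps them composable. A callback that only observes, such as logging, returns the map unchanged. A callback that acts, such as early stopping, returns a modified copy.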

Summary

Functions

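agent_history_callback()
Creates a callback that stores training history in an Agent.

checkpoint_callback(interval, dir)
Creates a checkpoint callback that saves model state at intervals.

early_stopping_callback(patience, min_delta \\ 0.0001)
Creates an early stopping callback.

fit_with_accuracy(model, data, opts \\ [])
Convenience function to train with accuracy tracking.

fit_with_progress(model, data, opts \\ [])
Convenience function to train with a progress-reporting callback.

get_history(pid)
Retrieves the full training history from a history callback agent.

history_callback()
Creates a history callback that records all training metrics.

logging_callback()
Creates a logging callback that prints training metrics after each epoch.

reduce_on_plateau_callback(opts \\ [])
Creates a reduce-on-plateau callback.

screen_callback(screen_pid)
Creates a Dala-specific callback that reports training progress to a screen process via handle_info/2.

standard_callbacks(opts \\ [])
Creates a standard set of callbacks for common training scenarios.

warmup_callback(warmup_epochs, start_lr, target_lr)
Creates a learning rate warmup callback.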

Functions

agent_history_callback()

@spec agent_history_callback() :: {(map() -> map()), pid()}

Creates a callback that stores training history in an Agent.

Returns {callback_fn, agent_pid}. Add callback_fn to the callbacks passed to training. Afterwards, pass agent_pid to get_history/1 to retrieve the full training history.

Usage

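{callback, agent} = Dala.ML.Burn.Training.agent_history_callback()

trained = Dala.ML.Burn.fit(model, {train_x, train_y},
  epochs: 10,
  validation_data: {val_x, val_y},
  callbacks: [callback]
)

# After training:
history = Dala.ML.Burn.Training.get_history(agent)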
# [%{epoch: 1, loss: 0.5, val_loss: 0.4}, ...]
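agent_history_callback/0 pairs a callback with an Agent, so the history outlives the training loop. Below is a minimal pure-Elixir sketch of that pattern using the standard Agent module. The module name AgentHistorySketch is hypothetical. The sketch assumes the callback receives a plain map of per-epoch metrics.

```elixir
defmodule AgentHistorySketch do
  # Hypothetical sketch: returns {callback_fn, agent_pid}, mirroring the
  # documented return shape of agent_history_callback/0.
  def agent_history_callback do
    {:ok, agent} = Agent.start_link(fn -> [] end)

    callback = fn metrics ->
      # Prepend for O(1) inserts; get_history/1 reverses into epoch order.
      Agent.update(agent, fn history -> [metrics | history] end)
      # Pass the state through unchanged so later callbacks still see it.
      metrics
    end

    {callback, agent}
  end

  def get_history(agent), do: agent |> Agent.get(& &1) |> Enum.reverse()
end

{callback, agent} = AgentHistorySketch.agent_history_callback()
callback.(%{epoch: 1, loss: 0.5, val_loss: 0.4})
callback.(%{epoch: 2, loss: 0.3, val_loss: 0.35})
history = AgentHistorySketch.get_history(agent)
```

Keeping the history in a separate process means the callback itself stays a plain (map() -> map()) function, with no state stored in the training loop.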

checkpoint_callback(interval, dir)

@spec checkpoint_callback(pos_integer(), Path.t()) :: (map() -> map())

Creates a checkpoint callback that saves model state at intervals.

Parameters

  • interval — Save a checkpoint every N epochs
  • dir — Directory to save checkpoints in
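checkpoint_callback/2 saves a checkpoint every interval epochs. The sketch below covers only the scheduling decision and assumes 1-based epoch numbers. The epoch_N.ckpt file name is hypothetical, and nothing is written to disk.

```elixir
defmodule CheckpointSketch do
  # Hypothetical: decide whether this epoch is a checkpoint epoch and where to save.
  # The "epoch_N.ckpt" naming is an assumption for illustration only.
  def checkpoint_path(epoch, interval, dir)
      when is_integer(epoch) and is_integer(interval) and interval > 0 do
    if rem(epoch, interval) == 0 do
      {:save, Path.join(dir, "epoch_#{epoch}.ckpt")}
    else
      :skip
    end
  end
end

# Epochs 1..12 with interval 5: only epochs 5 and 10 produce a path.
saved =
  for epoch <- 1..12,
      {:save, path} <- [CheckpointSketch.checkpoint_path(epoch, 5, "checkpoints/")],
      do: path
```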

early_stopping_callback(patience, min_delta \\ 0.0001)

@spec early_stopping_callback(pos_integer(), float()) :: (map() -> map())

Creates an early stopping callback.

Parameters

  • patience — Number of epochs to wait for improvement before stopping
  • min_delta — Minimum decrease in validation loss that counts as an improvement and resets the patience counter
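The following sketch shows how patience and min_delta might interact. An epoch counts as an improvement only when validation loss falls more than min_delta below the best value seen so far. The state keys (:val_loss, :es_best, :es_wait, :stop) are assumptions. Dala's callback may keep its bookkeeping elsewhere.

```elixir
defmodule EarlyStoppingSketch do
  # Hypothetical sketch: bookkeeping lives in the state map under :es_best and
  # :es_wait, and :stop tells the loop to halt. All key names are assumptions.
  def early_stopping_callback(patience, min_delta \\ 1.0e-4) do
    fn %{val_loss: val_loss} = state ->
      best = Map.get(state, :es_best)

      cond do
        # First epoch, or a drop larger than min_delta: record it and reset the counter.
        best == nil or best - val_loss > min_delta ->
          Map.merge(state, %{es_best: val_loss, es_wait: 0})

        # Not enough improvement: count the epoch and stop once patience runs out.
        true ->
          wait = Map.get(state, :es_wait, 0) + 1
          Map.merge(state, %{es_wait: wait, stop: wait >= patience})
      end
    end
  end
end

callback = EarlyStoppingSketch.early_stopping_callback(2)

final =
  Enum.reduce_while([0.50, 0.40, 0.39995, 0.41], %{}, fn val_loss, state ->
    state = callback.(Map.put(state, :val_loss, val_loss))
    if Map.get(state, :stop, false), do: {:halt, state}, else: {:cont, state}
  end)
```

In this run, the drop from 0.40 to 0.39995 is smaller than min_delta, so it counts against patience just like the later rise to 0.41. Training stops after the fourth epoch.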

fit_with_accuracy(model, data, opts \\ [])

@spec fit_with_accuracy(ExBurn.Model.t(), {Nx.Tensor.t(), Nx.Tensor.t()}, keyword()) ::
  ExBurn.Model.t()

Convenience function to train with accuracy tracking.

Shorthand for Dala.ML.Burn.fit/3 with accuracy: true.

fit_with_progress(model, data, opts \\ [])

@spec fit_with_progress(ExBurn.Model.t(), {Nx.Tensor.t(), Nx.Tensor.t()}, keyword()) ::
  {ExBurn.Model.t(), [map()]}

Convenience function to train with a progress-reporting callback.

Automatically adds a screen_callback/1 and returns a {model, metrics} tuple: the trained model and a list of metric maps.

get_history(pid)

@spec get_history(pid()) :: [map()]

Retrieves the full training history from a history callback agent.

history_callback()

@spec history_callback() :: (map() -> map())

Creates a history callback that records all training metrics.

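Returns a callback function. Recorded history is read with Dala.ML.Burn.Training.get_history/1, which takes the pid of the history agent. Because this function returns only the callback, use agent_history_callback/0 when you need that pid.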

logging_callback()

@spec logging_callback() :: (map() -> map())

Creates a logging callback that prints training metrics after each epoch.
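The output format of logging_callback/0 is not specified here. The following is a hypothetical sketch of a pass-through logging callback. It assumes the state map always carries :epoch, and carries float :loss and :val_loss values when available. The exact line format is invented for illustration.

```elixir
defmodule LoggingSketch do
  # Hypothetical sketch: assumes :epoch is always present; :loss and :val_loss
  # (floats) are printed only when present in the state map.
  def format_metrics(state) do
    parts =
      for key <- [:loss, :val_loss], Map.has_key?(state, key) do
        "#{key}=#{:erlang.float_to_binary(state[key], decimals: 4)}"
      end

    Enum.join(["Epoch #{state.epoch}" | parts], " ")
  end

  def logging_callback do
    fn state ->
      IO.puts(format_metrics(state))
      # Return the state unchanged so the next callback receives it.
      state
    end
  end
end

line = LoggingSketch.format_metrics(%{epoch: 3, loss: 0.41234, val_loss: 0.3987})
```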

reduce_on_plateau_callback(opts \\ [])

@spec reduce_on_plateau_callback(keyword()) :: (map() -> map())

Creates a reduce-on-plateau callback.

Reduces the learning rate when validation loss stops improving.

Options

  • :patience — Number of epochs without improvement to wait before reducing the learning rate (default: 5)
  • :factor — Multiplicative factor for LR reduction (default: 0.5)
  • :min_lr — Minimum learning rate floor (default: 1.0e-6)
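Below is a sketch of the reduce-on-plateau rule with the documented option names and defaults. Details are assumptions: the current rate is read from a hypothetical :lr key, plateau bookkeeping uses :rop_best and :rop_wait, improvement means any strict decrease in :val_loss, and the wait counter resets after each reduction.

```elixir
defmodule ReduceOnPlateauSketch do
  # Hypothetical sketch: :lr, :rop_best and :rop_wait are assumed key names.
  def reduce_on_plateau_callback(opts \\ []) do
    patience = Keyword.get(opts, :patience, 5)
    factor = Keyword.get(opts, :factor, 0.5)
    min_lr = Keyword.get(opts, :min_lr, 1.0e-6)

    fn %{val_loss: val_loss, lr: lr} = state ->
      best = Map.get(state, :rop_best)

      if best == nil or val_loss < best do
        # Improvement: remember the new best and reset the counter.
        Map.merge(state, %{rop_best: val_loss, rop_wait: 0})
      else
        wait = Map.get(state, :rop_wait, 0) + 1

        if wait >= patience do
          # Plateau: scale the rate down, never below min_lr, and start counting again.
          Map.merge(state, %{lr: max(lr * factor, min_lr), rop_wait: 0})
        else
          Map.put(state, :rop_wait, wait)
        end
      end
    end
  end
end

callback = ReduceOnPlateauSketch.reduce_on_plateau_callback(patience: 2, factor: 0.5, min_lr: 2.0e-4)

# Seven epochs with a completely flat validation loss.
final =
  Enum.reduce(List.duplicate(0.5, 7), %{lr: 1.0e-3}, fn val_loss, state ->
    callback.(Map.put(state, :val_loss, val_loss))
  end)
```

With a flat loss, the rate halves at epochs 3 and 5 (1.0e-3, then 5.0e-4, then 2.5e-4). At epoch 7 the halved value 1.25e-4 would fall below min_lr, so it is clamped to 2.0e-4.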

screen_callback(screen_pid)

@spec screen_callback(pid()) :: (map() -> map())

Creates a Dala-specific callback that reports training progress to a screen process via handle_info/2.

The callback sends {:training_progress, epoch, loss, val_loss} to screen_pid after each epoch, where it can be handled in handle_info/2.

Usage

# In your screen module:
callbacks = [
  Dala.ML.Burn.Training.screen_callback(self())
]

# Handle progress updates:
def handle_info({:training_progress, epoch, loss, val_loss}, socket) do
  {:noreply, assign(socket, epoch: epoch, loss: loss, val_loss: val_loss)}
end
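The following sketches the mechanism behind screen_callback/1. It assumes the callback reads :epoch, :loss and :val_loss from the state map it receives, and that val_loss is nil without validation data (also an assumption). The pid is captured when the callback is built, so self() must be evaluated in the screen process, as in the usage above, not in the training process.

```elixir
defmodule ScreenCallbackSketch do
  # Hypothetical sketch: sends the documented {:training_progress, epoch, loss, val_loss}
  # message to screen_pid. The state keys it reads are assumptions.
  def screen_callback(screen_pid) when is_pid(screen_pid) do
    fn state ->
      send(screen_pid, {:training_progress, state.epoch, state.loss, Map.get(state, :val_loss)})
      state
    end
  end
end

# Training usually runs in another process; simulate that with a Task.
parent = self()
callback = ScreenCallbackSketch.screen_callback(parent)

Task.async(fn -> callback.(%{epoch: 1, loss: 0.5, val_loss: 0.4}) end) |> Task.await()

# Stand-in for handle_info/2: pull the progress message from the mailbox.
received =
  receive do
    {:training_progress, epoch, loss, val_loss} -> {epoch, loss, val_loss}
  after
    1_000 -> :timeout
  end
```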

standard_callbacks(opts \\ [])

@spec standard_callbacks(keyword()) :: [(map() -> map())]

Creates a standard set of callbacks for common training scenarios.

Includes logging, early stopping, and checkpointing, plus learning rate warmup when :warmup_epochs is greater than 0.

Options

  • :early_stopping_patience — Patience for early stopping (default: 5)
  • :checkpoint_interval — Checkpoint every N epochs (default: 10)
  • :checkpoint_dir — Directory for checkpoints (default: "checkpoints")
  • :warmup_epochs — Number of warmup epochs (default: 0, disabled)
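The sketch below shows how the documented options and defaults could map onto a callback list. Real callbacks are replaced by descriptive tuples. Placing warmup last is an assumption, because the actual ordering inside standard_callbacks/1 is not documented here.

```elixir
defmodule StandardCallbacksSketch do
  # Hypothetical sketch of option handling only: each entry is a descriptive
  # tuple standing in for a real Dala callback function.
  def standard_callbacks(opts \\ []) do
    patience = Keyword.get(opts, :early_stopping_patience, 5)
    interval = Keyword.get(opts, :checkpoint_interval, 10)
    dir = Keyword.get(opts, :checkpoint_dir, "checkpoints")
    warmup_epochs = Keyword.get(opts, :warmup_epochs, 0)

    base = [:logging, {:early_stopping, patience}, {:checkpoint, interval, dir}]

    # Warmup is included only when enabled (warmup_epochs > 0).
    if warmup_epochs > 0, do: base ++ [{:warmup, warmup_epochs}], else: base
  end
end

defaults = StandardCallbacksSketch.standard_callbacks()
with_warmup = StandardCallbacksSketch.standard_callbacks(warmup_epochs: 3, checkpoint_dir: "ckpt")
```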

warmup_callback(warmup_epochs, start_lr, target_lr)

@spec warmup_callback(pos_integer(), float(), float()) :: (map() -> map())

Creates a learning rate warmup callback.

Gradually increases the learning rate from start_lr to target_lr over warmup_epochs epochs.

Parameters

  • warmup_epochs — Number of warmup epochs
  • start_lr — Initial learning rate
  • target_lr — Target learning rate after warmup
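The following sketches the warmup arithmetic. It assumes linear interpolation (the module overview builds a warmup with WarmupCallback.linear/3) and 1-based epochs. Under those assumptions, epoch 1 runs at start_lr and each warmup epoch adds (target_lr - start_lr) / warmup_epochs. From epoch warmup_epochs + 1 onward, the rate stays at target_lr. Other conventions might reach target_lr one epoch earlier.

```elixir
defmodule WarmupSketch do
  # Hypothetical sketch: linear warmup over 1-based epochs.
  def warmup_lr(epoch, warmup_epochs, start_lr, target_lr)
      when is_integer(epoch) and epoch >= 1 and is_integer(warmup_epochs) and warmup_epochs > 0 do
    if epoch > warmup_epochs do
      # Warmup finished: hold the target rate.
      target_lr
    else
      # Epoch 1 uses start_lr; each later warmup epoch adds one equal step toward target_lr.
      start_lr + (target_lr - start_lr) * (epoch - 1) / warmup_epochs
    end
  end
end

schedule = Enum.map(1..5, &WarmupSketch.warmup_lr(&1, 3, 1.0e-5, 1.0e-3))
```

With warmup_epochs = 3, start_lr = 1.0e-5 and target_lr = 0.001 (the values from the usage example at the top), epochs 1 to 3 run at about 1.0e-5, 3.4e-4 and 6.7e-4. Epoch 4 onward runs at 0.001.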