Training loop implementation for ExBurn models.
Provides a flexible training loop with support for:
- Mini-batch training with real gradient computation
- Multiple optimizers (Adam, SGD with momentum, RMSprop)
- Learning rate scheduling (step, exponential, cosine)
- Gradient clipping (by norm and by value)
- Validation on a separate dataset
- Callbacks (logging, early stopping, checkpointing)
- GPU-accelerated gradient computation via Burn
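The three optimizers listed above are usually defined by standard update rules. The sketch below shows those textbook rules applied to plain lists of floats rather than tensors, so the arithmetic is easy to follow. The OptimizerSketch module, its function names and its default hyperparameters (momentum 0.9, rho 0.9, Adam betas 0.9 and 0.999, eps 1.0e-8) are hypothetical. They are not ExBurn's API and are not necessarily its defaults.

```elixir
defmodule OptimizerSketch do
  # Hypothetical module: textbook optimizer update rules on plain lists of floats.
  # This is an illustration only, not ExBurn's tensor-based implementation.

  # SGD with momentum: v <- mu * v + g, then w <- w - lr * v
  def sgd_momentum(params, grads, velocity, lr, mu \\ 0.9) do
    velocity = Enum.zip_with(velocity, grads, fn v, g -> mu * v + g end)
    params = Enum.zip_with(params, velocity, fn w, v -> w - lr * v end)
    {params, velocity}
  end

  # RMSprop: a running average of squared gradients scales each step
  def rmsprop(params, grads, sq_avg, lr, rho \\ 0.9, eps \\ 1.0e-8) do
    sq_avg = Enum.zip_with(sq_avg, grads, fn s, g -> rho * s + (1 - rho) * g * g end)

    params =
      [params, grads, sq_avg]
      |> Enum.zip()
      |> Enum.map(fn {w, g, s} -> w - lr * g / (:math.sqrt(s) + eps) end)

    {params, sq_avg}
  end

  # Adam: bias-corrected first and second moment estimates; t counts steps taken
  def adam(params, grads, {m, v, t}, lr, b1 \\ 0.9, b2 \\ 0.999, eps \\ 1.0e-8) do
    t = t + 1
    m = Enum.zip_with(m, grads, fn mi, g -> b1 * mi + (1 - b1) * g end)
    v = Enum.zip_with(v, grads, fn vi, g -> b2 * vi + (1 - b2) * g * g end)
    # Correct for the moment estimates starting at zero
    c1 = 1 - :math.pow(b1, t)
    c2 = 1 - :math.pow(b2, t)

    params =
      [params, m, v]
      |> Enum.zip()
      |> Enum.map(fn {w, mi, vi} -> w - lr * (mi / c1) / (:math.sqrt(vi / c2) + eps) end)

    {params, {m, v, t}}
  end
end
```

Each function returns the updated parameters together with the optimizer state that must be passed into the next step, which is why an optimizer is stateful across mini-batches.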
Usage
```elixir
# axon_model, train_data and val_data are assumed to be defined already
model = ExBurn.Model.compile(axon_model, loss: :cross_entropy, optimizer: :adam)

opts = [
  epochs: 10,
  batch_size: 32,
  validation_data: val_data,
  lr_schedule: {:cosine, 0.001, 1.0e-5},
  clip_norm: 1.0,
  callbacks: [&ExBurn.Training.LoggingCallback.log/2]
]

trained_model = ExBurn.Training.fit(model, train_data, opts)
```
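The example passes lr_schedule: {:cosine, 0.001, 1.0e-5}. The sketch below assumes that tuple means {:cosine, initial_lr, min_lr} and evaluates the common cosine-annealing formula once per epoch, lr(e) = min_lr + 0.5 * (initial_lr - min_lr) * (1 + cos(pi * e / (E - 1))). That reading of the tuple, the {:step, ...} and {:exponential, ...} shapes, and the LRScheduleSketch module are all hypothetical illustrations, not documented ExBurn behaviour.

```elixir
defmodule LRScheduleSketch do
  # Hypothetical schedule shapes; only {:cosine, 0.001, 1.0e-5} appears in the docs,
  # and reading it as {:cosine, initial_lr, min_lr} is an assumption.

  @doc "Learning rate for a 0-based `epoch` out of `total_epochs`."
  def lr_at({:step, lr0, step_size, gamma}, epoch, _total_epochs) do
    # Multiply by gamma once every step_size epochs
    lr0 * :math.pow(gamma, div(epoch, step_size))
  end

  def lr_at({:exponential, lr0, gamma}, epoch, _total_epochs) do
    # Multiply by gamma every epoch
    lr0 * :math.pow(gamma, epoch)
  end

  def lr_at({:cosine, lr0, lr_min}, epoch, total_epochs) do
    # Anneal from lr0 at the first epoch down to lr_min at the last epoch
    progress = epoch / max(total_epochs - 1, 1)
    lr_min + 0.5 * (lr0 - lr_min) * (1 + :math.cos(:math.pi() * progress))
  end
end
```

For example, for e <- 0..9, do: LRScheduleSketch.lr_at({:cosine, 0.001, 1.0e-5}, e, 10) starts at 0.001 and ends at 1.0e-5.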
Summary
Functions
data_loader/2: Creates a data loader that yields mini-batches from a dataset.
evaluate/2: Evaluates a model on a dataset.
fit/3: Trains a model on the given dataset.
Types
@type dataset() :: {Nx.Tensor.t(), Nx.Tensor.t()}
@type training_opts() :: [
  epochs: pos_integer(),
  batch_size: pos_integer(),
  validation_data: dataset() | nil,
  callbacks: [callback()],
  verbose: boolean(),
  lr_schedule: lr_schedule(),
  clip_norm: float() | nil,
  clip_value: float() | nil
]
Functions
@spec data_loader(dataset(), keyword()) :: Enumerable.t()
Creates a data loader that yields mini-batches from a dataset.
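A mini-batch loader can be expressed as a lazy stream. The sketch below works on plain lists instead of Nx tensors and accepts a :batch_size option (default 32) plus a :shuffle flag; the :shuffle option and the DataLoaderSketch module are hypothetical and not part of the documented signature. When the dataset size is not a multiple of the batch size, the final batch is smaller.

```elixir
defmodule DataLoaderSketch do
  # Hypothetical loader over {inputs, targets} given as two equal-length lists.
  def batches({inputs, targets}, opts \\ []) do
    batch_size = Keyword.get(opts, :batch_size, 32)
    # Pair each input with its target so shuffling keeps them aligned
    pairs = Enum.zip(inputs, targets)
    pairs = if Keyword.get(opts, :shuffle, false), do: Enum.shuffle(pairs), else: pairs

    pairs
    |> Stream.chunk_every(batch_size)
    # Turn each chunk of pairs back into an {inputs, targets} tuple
    |> Stream.map(&Enum.unzip/1)
  end
end
```

Because the result is a stream, batches are produced only as the training loop consumes them; DataLoaderSketch.batches({[1, 2, 3, 4, 5], [11, 12, 13, 14, 15]}, batch_size: 2) |> Enum.to_list() yields [{[1, 2], [11, 12]}, {[3, 4], [13, 14]}, {[5], [15]}].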
@spec evaluate(ExBurn.Model.t(), dataset()) :: float()
Evaluates a model on a dataset.
Returns the average loss over the entire dataset.
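An average "over the entire dataset" differs from the mean of per-batch means whenever the last mini-batch is short. A minimal sketch of per-example averaging follows, assuming a model is any one-argument prediction function and the default loss is squared error; the EvaluateSketch module is hypothetical, and whether ExBurn weights batches exactly this way is an assumption.

```elixir
defmodule EvaluateSketch do
  # Hypothetical evaluator: batches are {inputs, targets} list pairs and
  # predict is a 1-arity function standing in for a model.
  def average_loss(batches, predict, loss \\ &squared_error/2) do
    {total, count} =
      Enum.reduce(batches, {0.0, 0}, fn {xs, ys}, {total, count} ->
        # Sum per-example losses so every example counts once, whatever its batch size
        batch_loss = Enum.zip_with(xs, ys, fn x, y -> loss.(predict.(x), y) end) |> Enum.sum()
        {total + batch_loss, count + length(xs)}
      end)

    if count == 0, do: 0.0, else: total / count
  end

  defp squared_error(y_hat, y), do: (y_hat - y) * (y_hat - y)
end
```

With batches [{[1.0, 2.0], [1.0, 0.0]}, {[3.0], [3.0]}] and the identity function as the model, this returns 4/3, whereas averaging the two batch means would give 1.0.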
@spec fit(ExBurn.Model.t(), dataset(), keyword()) :: ExBurn.Model.t()
Trains a model on the given dataset.
Options
- :epochs - Number of training epochs (default: 10)
- :batch_size - Mini-batch size (default: 32)
- :validation_data - Validation dataset as an {inputs, targets} tuple
- :callbacks - List of callback functions called after each epoch
- :verbose - Print training progress (default: true)
- :lr_schedule - Learning rate schedule (default: nil)
- :clip_norm - Maximum gradient norm for clipping (default: nil)
- :clip_value - Maximum absolute gradient value for clipping (default: nil)
Returns
The trained ExBurn.Model struct with updated parameters.
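The :clip_norm and :clip_value options of fit/3 name two common clipping rules: rescaling all gradients when their L2 norm exceeds a threshold, and clamping each component to a fixed range. The sketch below applies both to a flat list of gradient values. Treating the norm as a single global norm over all values (rather than per tensor) is an assumption, and the ClipSketch module is hypothetical.

```elixir
defmodule ClipSketch do
  # Clip by norm: if the L2 norm of all gradients exceeds max_norm,
  # scale every gradient by max_norm / norm so the new norm equals max_norm.
  def clip_by_norm(grads, max_norm) do
    norm = grads |> Enum.map(&(&1 * &1)) |> Enum.sum() |> :math.sqrt()

    if norm > max_norm do
      scale = max_norm / norm
      Enum.map(grads, &(&1 * scale))
    else
      grads
    end
  end

  # Clip by value: clamp each component to [-max_value, max_value].
  def clip_by_value(grads, max_value) do
    Enum.map(grads, &(&1 |> max(-max_value) |> min(max_value)))
  end
end
```

Clipping by norm preserves the direction of the gradient and only shrinks its length, while clipping by value can change the direction because each component is clamped independently.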