Model definition and training orchestration for ExBurn.
This module provides a high-level API for defining, compiling, and training neural network models using Axon with the ExBurn backend.
Usage
# Define a model
model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(256)
  |> Axon.relu()
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(128)
  |> Axon.relu()
  |> Axon.dense(10)

# Compile with ExBurn backend
compiled = ExBurn.Model.compile(model, loss: :cross_entropy, optimizer: :adam)

# Train
ExBurn.Model.fit(compiled, train_data, epochs: 10, batch_size: 32)
Summary
Functions
Compiles an Axon model for training with the ExBurn backend.
Computes the loss between predictions and targets.
Deserializes parameters from a binary.
Loads model parameters from a file.
Returns the model's loss function.
Returns the model's optimizer.
Returns the current model parameters.
Runs a forward pass through the model.
Saves the model parameters to a file using compressed Erlang term format.
Serializes parameters to a binary for network transfer or storage. Uses compressed Erlang term format.
Returns a summary of the model architecture including parameter count.
Types
Functions
Compiles an Axon model for training with the ExBurn backend.
Options
- :loss - Loss function: :cross_entropy, :mse, or :binary_cross_entropy (default: :cross_entropy)
- :optimizer - Optimizer: :adam, :sgd, or :rmsprop (default: :adam)
- :learning_rate - Learning rate (default: 0.001)
- :device - Device: :cpu or :gpu (default: :gpu)
Returns
An ExBurn.Model struct ready for training.
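The documented option defaults can be pictured as a small validation step that runs before anything is compiled. The following is a plain-Elixir sketch that assumes options arrive as a keyword list. The module CompileOptsSketch and its resolve/1 function are hypothetical and not part of ExBurn, and the positive learning-rate check is an assumption. The sketch relies on Keyword.validate!/2, which requires Elixir 1.13 or later.

```elixir
defmodule CompileOptsSketch do
  # Hypothetical helper (not ExBurn's API): resolves the options documented
  # for compile/2, filling in the documented defaults.
  @defaults [loss: :cross_entropy, optimizer: :adam, learning_rate: 0.001, device: :gpu]

  # Allowed values per option, taken from the Options list above.
  @allowed %{
    loss: [:cross_entropy, :mse, :binary_cross_entropy],
    optimizer: [:adam, :sgd, :rmsprop],
    device: [:cpu, :gpu]
  }

  def resolve(opts) do
    # Raises ArgumentError on unknown keys and fills in missing defaults.
    opts = Keyword.validate!(opts, @defaults)

    # Reject values outside the documented choices.
    Enum.each(@allowed, fn {key, allowed} ->
      value = Keyword.fetch!(opts, key)

      unless value in allowed do
        raise ArgumentError,
              "invalid #{key}: #{inspect(value)}, expected one of #{inspect(allowed)}"
      end
    end)

    # Assumption: the learning rate must be a positive number.
    lr = Keyword.fetch!(opts, :learning_rate)

    unless is_number(lr) and lr > 0 do
      raise ArgumentError, "learning_rate must be a positive number, got: #{inspect(lr)}"
    end

    opts
  end
end
```

For example, resolve(optimizer: :sgd) keeps :sgd and fills in :cross_entropy, 0.001, and :gpu for the other options. An unknown key such as :momentum raises ArgumentError.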
@spec compute_loss(t(), Nx.Tensor.t(), Nx.Tensor.t()) :: {:ok, Nx.Tensor.t()} | {:error, String.t()}
Computes the loss between predictions and targets.
Supports :cross_entropy (computed via log-softmax for numerical stability) and :mse.
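The following sketch shows why the log-softmax formulation matters. It implements both losses in plain Elixir on lists of floats for a single example. It is not ExBurn's implementation, which operates on Nx tensors, and LossSketch is a hypothetical module. Subtracting the maximum logit before exponentiating keeps :math.exp/1 in range. A naive softmax over a logit of 1000.0 would raise an ArithmeticError in Erlang float arithmetic, while this version returns a loss of 0.0 for a correct, confident prediction.

```elixir
defmodule LossSketch do
  # Plain-Elixir illustration of the two losses on lists of floats.

  # log_softmax(x)_i = x_i - m - log(sum_j exp(x_j - m)), where m = max_j x_j.
  # Shifting by m means every exponent is <= 0, so exp/1 cannot overflow.
  def log_softmax(logits) do
    m = Enum.max(logits)
    sum = Enum.reduce(logits, 0.0, fn x, acc -> acc + :math.exp(x - m) end)
    log_sum = :math.log(sum)
    Enum.map(logits, fn x -> x - m - log_sum end)
  end

  # Cross-entropy for one example with a one-hot (or probability) target:
  # -sum_i target_i * log_softmax(logits)_i
  def cross_entropy(logits, target) do
    logits
    |> log_softmax()
    |> Enum.zip(target)
    |> Enum.reduce(0.0, fn {log_p, t}, acc -> acc - t * log_p end)
  end

  # Mean squared error: the mean of (pred_i - target_i)^2.
  def mse(pred, target) do
    squared = Enum.zip_with(pred, target, fn p, t -> (p - t) * (p - t) end)
    Enum.sum(squared) / length(squared)
  end
end
```

With uniform logits such as [0.0, 0.0] and target [1.0, 0.0], cross_entropy/2 returns log(2), the loss of a model that is maximally unsure between two classes.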
Deserializes parameters from a binary.
Loads model parameters from a file.
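Loading parameters comes down to reading the stored bytes and decoding them. The sketch below assumes parameters are stored in Erlang external term format, as the save and serialize descriptions state. The module name and the {:ok, params} | {:error, reason} return shapes are assumptions, not ExBurn's documented API. Passing [:safe] to :erlang.binary_to_term/2 stops decoding from creating new atoms, which is worth doing when the binary arrives over the network.

```elixir
defmodule ParamsLoadSketch do
  # Hypothetical deserialize/load helpers; names and return shapes are
  # assumptions, not ExBurn's API.

  # Decode a binary in Erlang external term format. [:safe] rejects data
  # that would create new atoms; malformed input raises ArgumentError,
  # which is turned into an error tuple here.
  def deserialize(binary) when is_binary(binary) do
    {:ok, :erlang.binary_to_term(binary, [:safe])}
  rescue
    ArgumentError -> {:error, "invalid or unsafe parameter binary"}
  end

  # Read the file, then decode its contents.
  def load(path) do
    case File.read(path) do
      {:ok, binary} -> deserialize(binary)
      {:error, reason} -> {:error, "could not read #{path}: #{inspect(reason)}"}
    end
  end
end
```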
Returns the model's loss function.
Returns the model's optimizer.
Returns the current model parameters.
@spec predict(t(), Nx.Tensor.t()) :: {:ok, Nx.Tensor.t()} | {:error, String.t()}
Runs a forward pass through the model.
Returns {:ok, output}, where output is an Nx tensor, or {:error, reason} on failure.
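For a model built from dense and relu layers, as in the Usage section, the forward pass is a chain of affine maps and element-wise activations. The sketch below computes one input row in plain Elixir. The ForwardSketch module and its {:dense, kernel, bias} layer encoding are hypothetical; ExBurn itself runs the forward pass on Nx tensors through Axon. Dropout is omitted because it acts as the identity at inference time.

```elixir
defmodule ForwardSketch do
  # Hypothetical forward pass over plain lists for a single input row.

  # Dense layer: y_j = sum_i x_i * W[i][j] + b_j, where the kernel W is a
  # list of `in` rows, each holding `out` weights.
  def dense(x, kernel, bias) do
    kernel
    |> Enum.zip(x)
    |> Enum.reduce(bias, fn {row, xi}, acc ->
      # Add xi times row i of the kernel to the running output.
      Enum.zip_with(acc, row, fn a, w -> a + xi * w end)
    end)
  end

  # Element-wise ReLU.
  def relu(x), do: Enum.map(x, &max(&1, 0.0))

  # Apply the layers in order, feeding each output into the next layer.
  def predict(x, layers) do
    Enum.reduce(layers, x, fn
      {:dense, kernel, bias}, acc -> dense(acc, kernel, bias)
      :relu, acc -> relu(acc)
    end)
  end
end
```

For instance, the input [1.0, 2.0] through a dense layer with kernel [[1.0, -1.0], [0.5, 0.5]] and bias [0.0, -1.0] gives [2.0, -1.0], and relu then clamps it to [2.0, 0.0].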
Saves the model parameters to a file using compressed Erlang term format.
Serializes parameters to a binary for network transfer or storage. Uses compressed Erlang term format.
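Because save and serialize both use compressed Erlang term format, the core of either operation is a single call to :erlang.term_to_binary/2 with the :compressed option. This minimal sketch assumes parameters are plain Elixir data, such as nested maps of lists. The module and function names are hypothetical, and the sketch does not cover how ExBurn handles tensors held by a native backend.

```elixir
defmodule ParamsSerializeSketch do
  # Hypothetical serialize/save helpers, assuming params are plain Elixir data.

  # [:compressed] zlib-compresses the external term format.
  def serialize(params), do: :erlang.term_to_binary(params, [:compressed])

  # Write the serialized binary to disk; returns :ok or {:error, reason}.
  def save(params, path), do: File.write(path, serialize(params))
end
```

On the reading side, :erlang.binary_to_term/1 recognizes the compressed form and decompresses it transparently, so no extra option is needed.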
Returns a summary of the model architecture including parameter count.
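A parameter count like the one in this summary follows directly from the layer shapes. A dense layer from n inputs to u units has an n * u kernel and a u-element bias, so it holds n * u + u parameters, while relu and dropout contribute none. The sketch below shows that arithmetic, assuming the summary counts trainable kernels and biases. ParamCountSketch and its layer encoding are hypothetical.

```elixir
defmodule ParamCountSketch do
  # Parameters of a dense layer: an inputs x units kernel plus a units-sized bias.
  def dense_params(inputs, units), do: inputs * units + units

  # Walk the layers, tracking the current feature width; layers other than
  # {:dense, units} (e.g. :relu, {:dropout, rate}) add no parameters.
  def count(input_size, layers) do
    {_features, total} =
      Enum.reduce(layers, {input_size, 0}, fn
        {:dense, units}, {features, acc} -> {units, acc + dense_params(features, units)}
        _other, state -> state
      end)

    total
  end
end
```

For the model in the Usage section, this gives 784 * 256 + 256 = 200,960, then 256 * 128 + 128 = 32,896, then 128 * 10 + 10 = 1,290, for 235,146 parameters in total.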