ExBurn.Model (ex_burn v0.1.0)


Model definition and training orchestration for ExBurn.

This module provides a high-level API for defining, compiling, and training neural network models using Axon with the ExBurn backend.

Usage

# Define a model
model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(256)
  |> Axon.relu()
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(128)
  |> Axon.relu()
  |> Axon.dense(10)

# Compile with ExBurn backend
compiled = ExBurn.Model.compile(model, loss: :cross_entropy, optimizer: :adam)

# Train
ExBurn.Model.fit(compiled, train_data, epochs: 10, batch_size: 32)

Summary

Functions

compile(axon_model, opts \\ [])

Compiles an Axon model for training with the ExBurn backend.

compute_loss(model, pred, target)

Computes the loss between predictions and targets.

deserialize_params(binary)

Deserializes parameters from a binary.

load(model, path)

Loads model parameters from a file.

loss_function(model)

Returns the model's loss function.

optimizer(model)

Returns the model's optimizer.

parameters(model)

Returns the current model parameters.

predict(model, input)

Runs a forward pass through the model.

save(model, path)

Saves the model parameters to a file using compressed Erlang term format.

serialize_params(model)

Serializes parameters to a binary for network transfer or storage. Uses compressed Erlang term format.

summary(model)

Returns a summary of the model architecture including parameter count.

Types

t()

@type t() :: %ExBurn.Model{
  axon_model: Axon.ModelState.t(),
  compiled: boolean(),
  loss_fn: atom(),
  optimizer: atom(),
  optimizer_state: map(),
  params: map()
}

Functions

compile(axon_model, opts \\ [])

@spec compile(
  Axon.ModelState.t(),
  keyword()
) :: t()

Compiles an Axon model for training with the ExBurn backend.

Options

  • :loss — Loss function: :cross_entropy, :mse, :binary_cross_entropy (default: :cross_entropy)
  • :optimizer — Optimizer: :adam, :sgd, :rmsprop (default: :adam)
  • :learning_rate — Learning rate (default: 0.001)
  • :device — Device: :cpu or :gpu (default: :gpu)

Returns

An ExBurn.Model struct ready for training.
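The option list above implies that caller-supplied options are merged over a fixed set of defaults. A minimal sketch of that pattern follows, assuming the documented defaults and using the standard library's Keyword.validate!/2 (Elixir 1.13 or later). The module name CompileOptsSketch is hypothetical, and this is not necessarily how ExBurn resolves its options internally.

```elixir
defmodule CompileOptsSketch do
  # Defaults copied from the documented option list (hypothetical module, not ExBurn code).
  @defaults [loss: :cross_entropy, optimizer: :adam, learning_rate: 0.001, device: :gpu]

  # Merges caller options over the defaults.
  # Keyword.validate!/2 raises ArgumentError if an unknown key is passed.
  def resolve(opts \\ []) do
    Keyword.validate!(opts, @defaults)
  end
end

opts = CompileOptsSketch.resolve(loss: :mse, device: :cpu)

# Overridden keys take the caller's value; the rest fall back to the defaults.
IO.inspect(Keyword.fetch!(opts, :loss))
IO.inspect(Keyword.fetch!(opts, :optimizer))
```

Rejecting unknown keys up front catches typos such as :learning_rat before training starts, rather than silently falling back to a default.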

compute_loss(model, pred, target)

@spec compute_loss(t(), Nx.Tensor.t(), Nx.Tensor.t()) ::
  {:ok, Nx.Tensor.t()} | {:error, String.t()}

Computes the loss between predictions and targets.

Supports :cross_entropy, computed through log-softmax for numerical stability, and :mse.
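To illustrate why the log-softmax formulation matters, here is a small sketch of both losses for a single sample. It works on plain Elixir lists instead of Nx tensors so that it runs without dependencies. The module LossSketch and its function names are hypothetical and are not part of ExBurn.

```elixir
defmodule LossSketch do
  # Log-softmax over a list of logits using the log-sum-exp trick:
  # subtracting the maximum keeps :math.exp/1 from overflowing.
  def log_softmax(logits) do
    max = Enum.max(logits)
    log_sum = :math.log(Enum.sum(Enum.map(logits, &:math.exp(&1 - max))))
    Enum.map(logits, &(&1 - max - log_sum))
  end

  # Cross-entropy for one sample: -sum(target_i * log_softmax(logits)_i).
  def cross_entropy(logits, one_hot_target) do
    logits
    |> log_softmax()
    |> Enum.zip(one_hot_target)
    |> Enum.reduce(0.0, fn {log_p, t}, acc -> acc - t * log_p end)
  end

  # Mean squared error for one sample.
  def mse(pred, target) do
    pred
    |> Enum.zip(target)
    |> Enum.map(fn {p, t} -> (p - t) * (p - t) end)
    |> Enum.sum()
    |> Kernel./(length(pred))
  end
end

# Logits this large would overflow a naive exp/1; the shifted version stays finite.
loss = LossSketch.cross_entropy([1000.0, 1001.0, 1002.0], [0, 0, 1])
IO.inspect(loss)
```

Shifting by the maximum does not change the result, because softmax is invariant to adding a constant to every logit.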

deserialize_params(binary)

@spec deserialize_params(binary()) :: {:ok, map()} | {:error, String.t()}

Deserializes parameters from a binary.

load(model, path)

@spec load(t(), Path.t()) :: {:ok, t()} | {:error, String.t()}

Loads model parameters from a file.
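A loading step with this return shape can be sketched as reading the file, decoding the term, and replacing the params field. The example below is an assumption-laden illustration, not ExBurn's implementation: LoadSketch is a hypothetical module, a plain map stands in for the %ExBurn.Model{} struct, and nested lists stand in for tensors.

```elixir
defmodule LoadSketch do
  # `model` is a plain map standing in for the %ExBurn.Model{} struct.
  def load(model, path) do
    with {:ok, binary} <- File.read(path),
         {:ok, params} <- decode(binary) do
      {:ok, %{model | params: params}}
    else
      # File.read/1 reports POSIX reasons as atoms such as :enoent.
      {:error, reason} when is_atom(reason) -> {:error, "could not read #{path}: #{reason}"}
      {:error, message} -> {:error, message}
    end
  end

  # The [:safe] option refuses to create new atoms, which matters for untrusted files.
  defp decode(binary) do
    case :erlang.binary_to_term(binary, [:safe]) do
      params when is_map(params) -> {:ok, params}
      _other -> {:error, "file does not contain a parameter map"}
    end
  rescue
    ArgumentError -> {:error, "file could not be decoded as an Erlang term"}
  end
end

# Write a small parameter file, load it back, then clean up.
path = Path.join(System.tmp_dir!(), "ex_burn_load_sketch.bin")
File.write!(path, :erlang.term_to_binary(%{"dense_0" => %{"bias" => [0.0, 0.1]}}, [:compressed]))

model = %{params: %{}, loss_fn: :cross_entropy}
{:ok, loaded} = LoadSketch.load(model, path)
IO.inspect(loaded.params)
File.rm!(path)
```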

loss_function(model)

@spec loss_function(t()) :: atom()

Returns the model's loss function.

optimizer(model)

@spec optimizer(t()) :: atom()

Returns the model's optimizer.

parameters(model)

@spec parameters(t()) :: map()

Returns the current model parameters.

predict(model, input)

@spec predict(t(), Nx.Tensor.t()) :: {:ok, Nx.Tensor.t()} | {:error, String.t()}

Runs a forward pass through the model.

On success, returns {:ok, output}, where output is the model's output as an Nx tensor.

save(model, path)

@spec save(t(), Path.t()) :: :ok | {:error, String.t()}

Saves the model parameters to a file using compressed Erlang term format.

serialize_params(model)

@spec serialize_params(t()) :: binary()

Serializes parameters to a binary for network transfer or storage. Uses compressed Erlang term format.
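Compressed Erlang term format is what :erlang.term_to_binary/2 produces when given the :compressed option. The sketch below shows the effect on a hypothetical parameter map. Plain lists stand in for tensors, and the layer and key names are illustrative only.

```elixir
# Hypothetical parameter map; real parameters would be tensors, not lists.
params = %{
  "dense_0" => %{"kernel" => List.duplicate(0.0, 1_000), "bias" => List.duplicate(0.0, 10)}
}

plain = :erlang.term_to_binary(params)
compressed = :erlang.term_to_binary(params, [:compressed])

# The compressed binary is smaller for repetitive data like this.
IO.inspect({byte_size(plain), byte_size(compressed)})

# binary_to_term/1 detects the compressed form and decodes it transparently.
^params = :erlang.binary_to_term(compressed)
```

How much compression helps depends on the data. Trained weights are far less repetitive than the zeros used here, so the size reduction in practice is usually smaller.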

summary(model)

@spec summary(t()) :: String.t()

Returns a summary of the model architecture including parameter count.
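The parameter count in such a summary is the sum of the element counts of every parameter tensor. As a worked example, the sketch below counts the parameters of the network from the Usage section, given only the tensor shapes. The module ParamCountSketch and the layer and key names ("dense_0", "kernel", "bias") are assumptions for illustration and may not match what summary/1 reports.

```elixir
defmodule ParamCountSketch do
  # Sums the element counts of all parameter shapes in a two-level map
  # of layer name => parameter name => shape tuple.
  def count(shapes) do
    shapes
    |> Map.values()
    |> Enum.flat_map(&Map.values/1)
    |> Enum.map(&Tuple.product/1)
    |> Enum.sum()
  end
end

# Shapes for the 784 -> 256 -> 128 -> 10 dense network from the Usage section.
shapes = %{
  "dense_0" => %{"kernel" => {784, 256}, "bias" => {256}},
  "dense_1" => %{"kernel" => {256, 128}, "bias" => {128}},
  "dense_2" => %{"kernel" => {128, 10}, "bias" => {10}}
}

# 200_960 + 32_896 + 1_290 = 235_146
IO.inspect(ParamCountSketch.count(shapes))
```

The relu and dropout layers contribute nothing to the total, since they have no trainable parameters.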