Dala.ML.Training (dala v0.3.1)


On-device training support for Dala ML.

Provides on-device fine-tuning of pre-trained Axon models, with progress callbacks. Training runs on a dirty CPU scheduler so that long-running numerical work does not block the BEAM's regular schedulers.

Usage

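model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)

# predict_fn is not needed for training, so it is ignored here
{init_fn, _predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 784}, :f32), %{})

# Fine-tune on local data. train_data is assumed to be an {n, 784} :f32 tensor
# and train_labels an {n, 10} one-hot tensor, matching the softmax output and
# the default :categorical_cross_entropy loss.
{:ok, tuned_params} =
  Dala.ML.Training.fine_tune(
    model, params, {train_data, train_labels},
    epochs: 5, batch_size: 32, learning_rate: 0.001,
    progress: fn epoch, loss -> IO.puts("Epoch #{epoch}: loss=#{loss}") end
  )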

Summary

Functions

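evaluate(model, params, arg, opts \\ [])
Evaluates a model on validation data.

fine_tune(model, params, arg, opts \\ [])
Fine-tunes a model on-device with progress reporting.

load_params(path)
Loads model parameters from a file.

save_params(params, path)
Saves model parameters to a file for later loading.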

Functions

evaluate(model, params, arg, opts \\ [])

@spec evaluate(term(), term(), {term(), term()}, keyword()) ::
  {:ok, map()} | {:error, term()}

Evaluates a model, using the given params, on validation data passed as a {data, labels} tuple.

Returns {:ok, %{metrics: map()}} on success or {:error, reason} on failure.
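The contents of the metrics map are not specified on this page. As an illustration only, the following pure-Elixir sketch computes one common classification metric, accuracy, by comparing the argmax of each softmax output row with the argmax of its one-hot label. The module name MetricsSketch and the :accuracy key are hypothetical and are not part of Dala's API.

```elixir
# Hypothetical sketch, NOT Dala's implementation: accuracy from softmax
# outputs and one-hot labels, both given as plain lists of rows.
defmodule MetricsSketch do
  # Index of the largest value in a row (argmax).
  def argmax(row) do
    row
    |> Enum.with_index()
    |> Enum.max_by(fn {value, _idx} -> value end)
    |> elem(1)
  end

  # Fraction of rows whose predicted class matches the labelled class.
  # Expects two non-empty lists of equal length.
  def accuracy(predictions, labels) when length(predictions) == length(labels) do
    correct =
      Enum.zip(predictions, labels)
      |> Enum.count(fn {pred, label} -> argmax(pred) == argmax(label) end)

    correct / length(labels)
  end
end

preds = [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1], [0.3, 0.3, 0.4]]
labels = [[0, 1, 0], [0, 0, 1], [0, 0, 1]]

metrics = %{accuracy: MetricsSketch.accuracy(preds, labels)}
IO.inspect(metrics)
```

Here the first and third predictions match their labels and the second does not, so the accuracy is 2/3.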

fine_tune(model, params, arg, opts \\ [])

@spec fine_tune(term(), term(), {term(), term()}, keyword()) ::
  {:ok, term()} | {:error, term()}

Fine-tunes a model on-device with progress reporting.

Options

  • :epochs — Number of training epochs (default: 5)
  • :batch_size — Mini-batch size (default: 32)
  • :learning_rate — Optimizer learning rate (default: 0.001)
  • :optimizer — Optimizer function (default: Polaris.Optimizers.adam/1)
  • :loss — Loss function (default: :categorical_cross_entropy)
  • :progress — Callback fn epoch, loss -> ... end, called with the epoch number and its loss; should return :ok (default: no-op)
  • :validation_data — {val_data, val_labels} tuple used for evaluation
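To show how :epochs, :batch_size and :progress typically interact, here is a minimal pure-Elixir sketch of a mini-batch loop. It is not Dala's implementation: LoopSketch and its step_fn argument (a stand-in for one optimizer step returning {new_state, batch_loss}) are hypothetical, and the per-epoch loss is assumed to be the mean of the batch losses. Keyword.validate!/2 requires Elixir 1.13 or later.

```elixir
# Hypothetical sketch, NOT Dala's implementation: how :epochs, :batch_size
# and :progress usually combine in a mini-batch training loop.
defmodule LoopSketch do
  # Defaults mirror the documented :epochs and :batch_size defaults.
  @defaults [epochs: 5, batch_size: 32, progress: nil]

  # step_fn stands in for one optimizer step: (state, batch) -> {state, loss}.
  def run(state, samples, step_fn, opts \\ []) when samples != [] do
    # Raises on unknown keys and fills in defaults for missing ones.
    opts = Keyword.validate!(opts, @defaults)
    progress = opts[:progress] || fn _epoch, _loss -> :ok end
    batches = Enum.chunk_every(samples, opts[:batch_size])

    # The `//1` step makes the range empty when :epochs is 0.
    Enum.reduce(1..opts[:epochs]//1, state, fn epoch, state ->
      {losses, state} =
        Enum.map_reduce(batches, state, fn batch, acc ->
          {acc, loss} = step_fn.(acc, batch)
          {loss, acc}
        end)

      # Report the mean batch loss once per epoch.
      progress.(epoch, Enum.sum(losses) / length(losses))
      state
    end)
  end
end

# Toy usage: the "state" is a step counter and the loss shrinks every step.
samples = Enum.to_list(1..100)
step_fn = fn steps, _batch -> {steps + 1, 1.0 / (steps + 1)} end

progress = fn epoch, loss ->
  IO.puts("Epoch #{epoch}: loss=#{Float.round(loss, 4)}")
  send(self(), {:progress, epoch, loss})
  :ok
end

final_steps =
  LoopSketch.run(0, samples, step_fn, epochs: 2, batch_size: 32, progress: progress)
```

With 100 samples and a batch size of 32, each epoch runs four steps (the last batch holds only 4 samples), so two epochs perform eight steps and the callback fires twice.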

load_params(path)

@spec load_params(String.t()) :: {:ok, term()} | {:error, term()}

Loads model parameters from a file.

save_params(params, path)

@spec save_params(term(), String.t()) :: :ok | {:error, term()}

Saves model parameters to a file for later loading.
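The on-disk format used by save_params/2 and load_params/1 is not documented here. As an illustration only, the sketch below shows one common BEAM approach, :erlang.term_to_binary/1 with File.write/2, while mirroring the documented return shapes (:ok | {:error, term()} and {:ok, term()} | {:error, term()}). ParamsFileSketch and the :invalid_format error atom are hypothetical. If the params hold Nx tensors, Nx.serialize/2 and Nx.deserialize/2 are the usual way to turn them into binaries first.

```elixir
# Hypothetical sketch, NOT Dala's implementation or file format.
defmodule ParamsFileSketch do
  # Returns :ok or {:error, posix_reason}, like File.write/2.
  def save_params(params, path) do
    File.write(path, :erlang.term_to_binary(params))
  end

  # Returns {:ok, params} or {:error, reason}.
  def load_params(path) do
    with {:ok, binary} <- File.read(path) do
      try do
        # :safe refuses to create new atoms or functions from untrusted data.
        {:ok, :erlang.binary_to_term(binary, [:safe])}
      rescue
        ArgumentError -> {:error, :invalid_format}
      end
    end
  end
end

params = %{"dense_0" => %{"kernel" => [[0.1, 0.2]], "bias" => [0.0]}}
path = Path.join(System.tmp_dir!(), "params_sketch.bin")

:ok = ParamsFileSketch.save_params(params, path)
{:ok, ^params} = ParamsFileSketch.load_params(path)
```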