On-device training support for Dala ML.
Provides on-device fine-tuning of pre-trained Axon models, with progress callbacks. All training runs on dirty CPU schedulers so that long-running computation does not block the BEAM's normal schedulers.
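To see how much dirty CPU capacity the running VM actually has, you can query the standard `:erlang.system_info/1` keys below. This is a small sketch that uses plain Erlang/OTP and does not touch Dala. How Dala dispatches its own work onto these schedulers is not shown here. The `SchedulerInfo` module is hypothetical.

```elixir
defmodule SchedulerInfo do
  # Hypothetical helper: reports how many normal and dirty CPU schedulers
  # the VM is running. Uses only :erlang.system_info/1; not part of Dala.
  def summary do
    %{
      schedulers_online: :erlang.system_info(:schedulers_online),
      dirty_cpu_schedulers: :erlang.system_info(:dirty_cpu_schedulers),
      dirty_cpu_schedulers_online: :erlang.system_info(:dirty_cpu_schedulers_online)
    }
  end
end

SchedulerInfo.summary()
```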
Usage
```elixir
# A small classifier: 784 input features, 10 output classes
model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 784}, :f32), %{})

# Fine-tune on local data (train_data and train_labels hold the local training set)
Dala.ML.Training.fine_tune(
  model,
  params,
  {train_data, train_labels},
  epochs: 5,
  batch_size: 32,
  learning_rate: 0.001,
  progress: fn epoch, loss -> IO.puts("Epoch #{epoch}: loss=#{loss}") end
)
```
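The `:progress` callback does not have to print. A common BEAM pattern is to forward each update as a message to another process, such as a LiveView or a GenServer. The sketch below assumes only what the example above shows: the callback receives an epoch and a loss and returns `:ok`. The `ProgressForwarder` module and the `{:training_progress, epoch, loss}` message shape are hypothetical.

```elixir
defmodule ProgressForwarder do
  # Hypothetical helper: builds a progress callback that sends each
  # update to `pid` instead of printing it.
  def callback(pid) do
    fn epoch, loss ->
      send(pid, {:training_progress, epoch, loss})
      :ok
    end
  end

  # Drains every progress message currently in the mailbox, in arrival order.
  def collect(acc \\ []) do
    receive do
      {:training_progress, epoch, loss} -> collect([{epoch, loss} | acc])
    after
      0 -> Enum.reverse(acc)
    end
  end
end

# Simulate three epochs by invoking the callback the way a trainer would.
progress = ProgressForwarder.callback(self())
Enum.each([{1, 0.9}, {2, 0.5}, {3, 0.3}], fn {e, l} -> progress.(e, l) end)
ProgressForwarder.collect()
# → [{1, 0.9}, {2, 0.5}, {3, 0.3}]
```

This page does not say whether `fine_tune` blocks its caller. If it does, the messages are only useful when the call itself runs in a separate process (for example a `Task`), with `pid` set to the process that displays progress.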
Functions
Evaluates a model on validation data.
Returns {%{metrics: map()}, updated_params}.
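This page does not list the keys inside `metrics`. One metric a validation pass commonly reports is classification accuracy: the fraction of samples whose highest-scoring class matches the label. The sketch below computes it over plain lists for a softmax output like the 10-class model in the usage example. `EvalSketch` is hypothetical and is not how Dala computes its metrics.

```elixir
defmodule EvalSketch do
  # Index of the largest score in a list of class scores (argmax).
  def argmax(scores) do
    scores
    |> Enum.with_index()
    |> Enum.max_by(fn {score, _idx} -> score end)
    |> elem(1)
  end

  # Fraction of predictions whose argmax equals the integer label.
  # Expects non-empty lists of equal length.
  def accuracy(predictions, labels) when length(predictions) == length(labels) do
    correct =
      predictions
      |> Enum.zip(labels)
      |> Enum.count(fn {scores, label} -> argmax(scores) == label end)

    correct / length(labels)
  end
end

preds = [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1], [0.2, 0.3, 0.5], [0.6, 0.3, 0.1]]
labels = [1, 0, 2, 1]
EvalSketch.accuracy(preds, labels)
# → 0.75
```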
Fine-tunes a model on-device with progress reporting.
Options
- :epochs - Number of training epochs (default: 5)
- :batch_size - Mini-batch size (default: 32)
- :learning_rate - Optimizer learning rate (default: 0.001)
- :optimizer - Optimizer function (default: Polaris.Optimizers.adam/1)
- :loss - Loss function (default: :categorical_cross_entropy)
- :progress - Callback of the form (epoch, loss) -> :ok (default: no-op)
- :validation_data - {val_data, val_labels} tuple used for evaluation
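The following shows what `:epochs` and `:batch_size` mean in practice. Under the usual mini-batch scheme, each epoch splits the training set into chunks of `batch_size` samples and takes one optimizer step per chunk. The sketch below works through that arithmetic with `Enum.chunk_every/2` on a plain list. This page does not say whether Dala keeps or drops a trailing partial batch, so the sketch assumes it is kept. `BatchSketch` is hypothetical.

```elixir
defmodule BatchSketch do
  # Splits samples into mini-batches; the final batch may be smaller.
  def batches(samples, batch_size) when batch_size > 0 do
    Enum.chunk_every(samples, batch_size)
  end

  # Total optimizer steps across all epochs, assuming one step per batch
  # and that a trailing partial batch is kept (ceiling division).
  def total_steps(num_samples, batch_size, epochs) when batch_size > 0 do
    div(num_samples + batch_size - 1, batch_size) * epochs
  end
end

# 1_000 samples with the defaults (batch_size: 32, epochs: 5):
# 32 batches per epoch (31 full batches plus one of 8 samples).
BatchSketch.total_steps(1_000, 32, 5)
# → 160
```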
Loads model parameters from a file.
Saves model parameters to a file for later loading.
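This page does not specify the file format these functions use. For comparison, a minimal stdlib-only way to persist a parameter map is to write `:erlang.term_to_binary/1` output with `File.write!/2`, then read it back with `:erlang.binary_to_term/2` and the `:safe` option. `ParamsFile` and the plain-map parameters are hypothetical stand-ins. For Nx tensors specifically, Nx provides `Nx.serialize/2` and `Nx.deserialize/2`.

```elixir
defmodule ParamsFile do
  # Hypothetical helper: writes any Erlang term (here a params map) to `path`.
  def save(params, path) do
    File.write!(path, :erlang.term_to_binary(params))
  end

  # Reads the term back. :safe refuses to create new atoms or funs, which
  # guards against untrusted files but requires any atom keys to already exist.
  def load(path) do
    path
    |> File.read!()
    |> :erlang.binary_to_term([:safe])
  end
end

params = %{"dense_0" => %{"kernel" => [[0.1, -0.2]], "bias" => [0.0]}}
path = Path.join(System.tmp_dir!(), "params_sketch.bin")
ParamsFile.save(params, path)
ParamsFile.load(path) == params
# → true
```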