Training Models with ExBurn

Copy Markdown View Source

Defining a Model with Axon

# A small multilayer perceptron: 784 input features -> 128 -> 64 -> 10 outputs.
model =
  # `nil` leaves the batch dimension unspecified; each example has 784 features.
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu, name: "hidden1")
  # Dropout with rate 0.2 for regularization.
  |> Axon.dropout(rate: 0.2)
  |> Axon.dense(64, activation: :relu, name: "hidden2")
  # No activation is given, so the 10 outputs are unnormalized scores (logits).
  |> Axon.dense(10, name: "output")

Compiling

# Compile the Axon graph with a loss function, an optimizer and a learning rate.
compiled =
  model
  |> ExBurn.Model.compile(
    # alternatives: :mse, :binary_cross_entropy
    loss: :cross_entropy,
    # alternatives: :sgd, :rmsprop
    optimizer: :adam,
    learning_rate: 0.001
  )

Training

# Fit the compiled model on {train_x, train_y} for 10 epochs with batches of 32,
# supplying {val_x, val_y} as validation data and enabling verbose output.
trained =
  compiled
  |> ExBurn.Training.fit(
    {train_x, train_y},
    epochs: 10,
    batch_size: 32,
    validation_data: {val_x, val_y},
    verbose: true
  )

Callbacks

Logging

# Attach the built-in LoggingCallback, passed as a capture of its log/1 function.
# As in the main training example, fit receives the compiled model and the
# {inputs, targets} tuple.
ExBurn.Training.fit(compiled, {train_x, train_y},
  callbacks: [&ExBurn.Training.LoggingCallback.log/1]
)

Early Stopping

# Attach an early-stopping callback built by EarlyStoppingCallback.wait/2.
# The arguments presumably are the patience (5 epochs) and the minimum
# improvement (1.0e-4); see the EarlyStoppingCallback docs to confirm.
ExBurn.Training.fit(compiled, {train_x, train_y},
  callbacks: [ExBurn.Training.EarlyStoppingCallback.wait(5, 1.0e-4)]
)

Checkpointing

# Attach a checkpointing callback built by CheckpointCallback.every/2.
# It presumably writes a checkpoint every 5 epochs into "/checkpoints",
# which must then be an existing, writable directory.
ExBurn.Training.fit(compiled, {train_x, train_y},
  callbacks: [ExBurn.Training.CheckpointCallback.every(5, "/checkpoints")]
)

Custom Callbacks

# A custom callback receives the metrics map and returns it.
# Once `loss` drops below 0.01, it prints a message and adds
# `stop_training: true` to request that training stop. Any other map passes
# through unchanged. The full map is returned, so extra keys (for example
# validation metrics) are preserved.
#
# Pass it to fit like the built-in callbacks:
#   ExBurn.Training.fit(compiled, {train_x, train_y}, callbacks: [custom_callback])
custom_callback = fn
  %{epoch: epoch, loss: loss} = metrics when loss < 0.01 ->
    IO.puts("Converged at epoch #{epoch}!")
    # Map.put keeps every existing key and only adds the stop flag.
    Map.put(metrics, :stop_training, true)

  metrics ->
    metrics
end

Saving and Loading

# Write the trained model to "model.bin".
trained |> ExBurn.Model.save("model.bin")

# Read it back from "model.bin". The match raises a MatchError unless load
# returns {:ok, _}. Note that this rebinds `model`, replacing the Axon graph
# defined earlier on this page.
{:ok, model} = trained |> ExBurn.Model.load("model.bin")

Inference

# Single prediction. The input layer was declared with shape {nil, 784}, so
# `input` should still carry a leading batch axis: shape {1, 784}.
# `params` holds the model's parameters; obtaining it is not shown here.
output = model |> Axon.predict(params, input)

# Batch prediction. `batch_input` is expected to be {batch_size, 784}.
outputs = model |> Axon.predict(params, batch_input)