Reinforcement Learning

my_app_root = Path.join(__DIR__, "..")

Mix.install(
  [
    {:rein, path: my_app_root, env: :dev}
  ],
  config_path: Path.join(my_app_root, "config/config.exs"),
  lockfile: Path.join(my_app_root, "mix.lock"),
  # change to "cuda118" to use CUDA
  system_env: %{"XLA_TARGET" => "cpu"}
)

Section

alias VegaLite, as: Vl

{min_x, max_x, min_y, max_y} = Rein.Environments.Gridworld.bounding_box()

# A single fixed target, immediately replaced below by the full grid of candidate targets
possible_targets_l = [[div(min_x + max_x, 2), max_y - 2]]

possible_targets_l =
  for x <- (min_x + 2)..(max_x - 2), y <- 2..max_y do
    [x, y]
  end

possible_targets = Nx.tensor(Enum.shuffle(possible_targets_l))

width = 600
height = 600

grid_widget =
  Vl.new(width: width, height: height)
  |> Vl.layers([
    Vl.new()
    |> Vl.data(name: "target")
    |> Vl.mark(:point,
      fill: true,
      tooltip: [content: "data"],
      grid: true,
      size: [expr: "height * 4 * #{:math.pi()} / #{max_y - min_y}"]
    )
    |> Vl.encode_field(:x, "x", type: :quantitative)
    |> Vl.encode_field(:y, "y", type: :quantitative)
    |> Vl.encode_field(:color, "episode",
      type: :nominal,
      scale: [scheme: "blues"],
      legend: false
    ),
    Vl.new()
    |> Vl.data(name: "trajectory")
    |> Vl.mark(:line, opacity: 0.5, tooltip: [content: "data"])
    |> Vl.encode_field(:x, "x", type: :quantitative, scale: [domain: [min_x, max_x], clamp: true])
    |> Vl.encode_field(:y, "y", type: :quantitative, scale: [domain: [min_y, max_y], clamp: true])
    |> Vl.encode_field(:color, "episode",
      type: :nominal,
      scale: [scheme: "blues"],
      legend: false
    )
    |> Vl.encode_field(:order, "index")
  ])
  |> Kino.VegaLite.new()
  |> Kino.render()

loss_widget =
  Vl.new(width: width, height: height, title: "Loss")
  |> Vl.data(name: "loss")
  |> Vl.mark(:line,
    grid: true,
    tooltip: [content: "data"],
    interpolate: "step-after",
    point: true,
    color: "blue"
  )
  |> Vl.encode_field(:x, "episode", type: :quantitative)
  |> Vl.encode_field(:y, "loss",
    type: :quantitative,
    scale: [
      domain: [0, 1],
      type: "linear",
      base: 10,
      clamp: true
    ]
  )
  |> Vl.encode_field(:order, "episode")
  |> Kino.VegaLite.new()
  |> Kino.render()

reward_widget =
  Vl.new(width: width, height: height, title: "Total Reward per Epoch")
  |> Vl.data(name: "reward")
  |> Vl.mark(:line,
    grid: true,
    tooltip: [content: "data"],
    interpolate: "step-after",
    point: true,
    color: "blue"
  )
  |> Vl.encode_field(:x, "episode", type: :quantitative)
  |> Vl.encode_field(:y, "reward",
    type: :quantitative,
    scale: [
      domain: [-2, 2],
      type: "symlog",
      base: 10,
      clamp: true
    ]
  )
  |> Vl.encode_field(:order, "episode")
  |> Kino.VegaLite.new()
  |> Kino.render()

# Return nil so this cell's output doesn't re-render the last widget
nil
# Upper bound on the number of points pushed to the plots
max_points = 1000

# Plotting callback passed to Rein.train below; it receives the Axon loop
# state and refreshes the widgets as episodes finish
plot_fn = fn axon_state ->
  if axon_state.iteration > 1 do
    episode = axon_state.epoch

    Kino.VegaLite.clear(grid_widget, dataset: "target")
    Kino.VegaLite.clear(grid_widget, dataset: "trajectory")

    Kino.VegaLite.push(
      grid_widget,
      %{
        x: Nx.to_number(axon_state.step_state.environment_state.target_x),
        y: Nx.to_number(axon_state.step_state.environment_state.target_y),
        episode: episode,
        episode_group: rem(episode, 15)
      },
      dataset: "target"
    )

    Kino.VegaLite.push(
      loss_widget,
      %{
        episode: episode,
        loss: Nx.to_number(Nx.mean(axon_state.step_state.agent_state.loss))
      },
      dataset: "loss"
    )

    IO.puts("Episode #{episode} ended")

    trajectory = axon_state.step_state.trajectory

    idx = Nx.to_flat_list(trajectory[0][0..(axon_state.iteration - 1)//1] |> Nx.as_type(:s64))
    x = Nx.to_flat_list(trajectory[1][0..(axon_state.iteration - 1)//1])
    y = Nx.to_flat_list(trajectory[2][0..(axon_state.iteration - 1)//1])

    points =
      [idx, x, y]
      |> Enum.zip_with(fn [index, x, y] ->
        %{
          x: x,
          y: y,
          index: index,
          episode: episode,
          episode_group: rem(episode, 15)
        }
      end)

    Kino.VegaLite.push_many(grid_widget, points, dataset: "trajectory")
  end

  Kino.VegaLite.push(
    reward_widget,
    %{
      episode: axon_state.epoch,
      reward: Nx.to_number(axon_state.step_state.agent_state.total_reward)
    },
    dataset: "reward"
  )

  axon_state
end
# Checkpoint file with the trained policy and replay buffer; adjust the path as needed
filename = "/Users/paulo.valente/Desktop/gridworld.bin"

{
  q_policy,
  experience_replay_buffer_index,
  persisted_experience_replay_buffer_entries,
  experience_replay_buffer,
  total_episodes
} =
  try do
    contents = File.read!(filename)
    File.write!(filename <> "_bak", contents)

    %{serialized: serialized, total_episodes: total_episodes} = :erlang.binary_to_term(contents)

    %{
      q_policy: q_policy,
      experience_replay_buffer_index: experience_replay_buffer_index,
      persisted_experience_replay_buffer_entries: persisted_experience_replay_buffer_entries,
      experience_replay_buffer: exp_replay_buffer
    } = Nx.deserialize(serialized)

    {q_policy, experience_replay_buffer_index, persisted_experience_replay_buffer_entries,
     exp_replay_buffer, total_episodes}
  rescue
    File.Error ->
      {%{}, 0, 0, nil, 0}
  end

# q_policy = %{}
# total_episodes = 0
# experience_replay_buffer = nil
# persisted_experience_replay_buffer_entries = 0
# experience_replay_buffer_index = 0
# Inspect the restored episode count and policy parameters
total_episodes
q_policy
num_observations = Rein.Environments.Gridworld.state_vector_size()
num_actions = Rein.Environments.Gridworld.num_actions()

policy_net =
  Axon.input("state", shape: {nil, num_observations})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(64, activation: :relu)
  |> Axon.dense(64, activation: :relu)
  |> Axon.dense(num_actions)

# These might seem redundant, but will make more sense for multi-input models
# (see the sketch after this cell)

environment_to_input_fn = fn env_state ->
  %{"state" => Rein.Environments.Gridworld.as_state_vector(env_state)}
end

state_vector_to_input_fn = fn state_vector ->
  %{"state" => state_vector}
end

environment_to_state_vector_fn = &Rein.Environments.Gridworld.as_state_vector/1
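
As an illustration of the comment above, here is a minimal sketch of how these functions could diverge for a hypothetical two-input model: alongside the flattened state vector, the network also receives a separate "goal" input, so each function has to build the matching input map. The "goal" input name, goal_size, and the target_x/target_y fields are assumptions made for this example only; they are not used by the training run below.

# Hypothetical two-input variant (illustration only)
goal_size = 2

multi_input_net =
  Axon.input("state", shape: {nil, num_observations})
  |> Axon.concatenate(Axon.input("goal", shape: {nil, goal_size}))
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(num_actions)

multi_input_environment_to_input_fn = fn env_state ->
  %{
    "state" => Rein.Environments.Gridworld.as_state_vector(env_state),
    # target_x/target_y are assumed to be scalar tensors in the environment state;
    # stack them into a batch of one, shape {1, goal_size}
    "goal" =>
      [env_state.target_x, env_state.target_y]
      |> Nx.stack()
      |> Nx.as_type(:f32)
      |> Nx.new_axis(0)
  }
end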
Kino.VegaLite.clear(grid_widget)
Kino.VegaLite.clear(loss_widget, dataset: "loss")
Kino.VegaLite.clear(reward_widget, dataset: "reward")

episodes = 5000
max_iter = 200

{t, result} =
  :timer.tc(fn ->
    Rein.train(
      {Rein.Environments.Gridworld, possible_targets: possible_targets},
      {Rein.Agents.DQN,
       policy_net: policy_net,
       eps_max_iter: -1,
       q_policy: q_policy,
       persisted_experience_replay_buffer_entries: persisted_experience_replay_buffer_entries,
       experience_replay_buffer_index: experience_replay_buffer_index,
       experience_replay_buffer: experience_replay_buffer,
       environment_to_input_fn: environment_to_input_fn,
       environment_to_state_vector_fn: environment_to_state_vector_fn,
       state_vector_to_input_fn: state_vector_to_input_fn},
      plot_fn,
      num_episodes: episodes,
      max_iter: max_iter
    )
  end)

# File.write!("/Users/paulo.valente/Desktop/results_#{NaiveDateTime.utc_now() |> NaiveDateTime.to_iso8601() |> String.replace(":", "")}.bin", )
serialized =
  Nx.serialize(
    Map.take(result.step_state.agent_state, [
      :q_policy,
      :experience_replay_buffer_index,
      :experience_replay_buffer,
      :persisted_experience_replay_buffer_entries
    ])
  )

contents =
  :erlang.term_to_binary(%{serialized: serialized, total_episodes: total_episodes + episodes})

# File.write!(filename, contents)

"#{Float.round(t / 1_000_000, 3)} s"
File.write!(filename, contents)
result.step_state.agent_state.q_policy