Rein.Agents.SAC (rein v0.0.1)
Soft Actor-Critic implementation.
This assumes that the Actor network outputs a tensor of shape `{nil, num_actions, 2}`, where for each action it outputs the $\mu$ and $\sigma$ values of a random normal distribution, and that the Critic network accepts an `"actions"` input of shape `{nil, num_actions}`, where the action is calculated by sampling from said random distribution. Actions are assumed to lie in a continuous space of type `:f32`.
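As a rough sketch of what those shape conventions imply, the Actor and Critic heads could be defined in Axon as below. The `state_size`, layer sizes, and layer structure are hypothetical placeholders, not part of Rein's API; only the input and output shapes mirror the contract described above.

```elixir
state_size = 8
num_actions = 2

# Actor: emits {batch, num_actions, 2}, i.e. a (mu, sigma) pair per action.
actor =
  Axon.input("state", shape: {nil, state_size})
  |> Axon.dense(64, activation: :relu)
  |> Axon.dense(num_actions * 2)
  |> Axon.reshape({:batch, num_actions, 2})

# Critic: consumes the state together with an "actions" input of shape
# {nil, num_actions} and emits a single Q-value per example.
state = Axon.input("state", shape: {nil, state_size})
actions = Axon.input("actions", shape: {nil, num_actions})

critic =
  Axon.concatenate(state, actions)
  |> Axon.dense(64, activation: :relu)
  |> Axon.dense(1)
```

Sampling an action from that Actor output could then look like the following `defn` (again a sketch with hypothetical names, not Rein's internal function):

```elixir
defmodule SACSketch do
  import Nx.Defn

  # Splits the {batch, num_actions, 2} actor output into mu and sigma
  # and samples an :f32 action from the corresponding normal distribution.
  defn sample_action(actor_output, random_key) do
    mu = Nx.slice_along_axis(actor_output, 0, 1, axis: 2) |> Nx.squeeze(axes: [2])
    sigma = Nx.slice_along_axis(actor_output, 1, 1, axis: 2) |> Nx.squeeze(axes: [2])

    {noise, new_key} = Nx.Random.normal(random_key, shape: Nx.shape(mu), type: :f32)
    {mu + sigma * noise, new_key}
  end
end
```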
To simplify the Dual Q implementation, `:critic_params` is vectorized along the `:critics` axis, with a default size of 2. Likewise, `:critic_target_params` and `:critic_optimizer_state` are vectorized in the same way.
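To illustrate what that vectorization means (the tensor shapes and values below are hypothetical), two stacked sets of critic weights vectorized along `:critics` make every critic computation run once per critic implicitly:

```elixir
critic_1_w = Nx.broadcast(0.1, {8, 1})
critic_2_w = Nx.broadcast(0.2, {8, 1})

critic_params =
  [critic_1_w, critic_2_w]
  # stack into shape {2, 8, 1}
  |> Nx.stack()
  # turn the leading axis into the vectorized :critics axis (size 2)
  |> Nx.vectorize(:critics)

# Any defn applied to critic_params is now evaluated per critic, so a single
# loss/update expression covers both Q networks.
```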
Vectorized axes from `:random_key` are still propagated normally throughout the agent state for parallel training.
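For instance (the `:envs` axis name and the seed values are purely illustrative, not Rein's convention), a vectorized random key for several parallel runs can be built by stacking per-run keys:

```elixir
random_key =
  [0, 1, 2, 3]
  |> Enum.map(&Nx.Random.key/1)
  |> Nx.stack()
  |> Nx.vectorize(:envs)

# Nx.Random calls that consume this key draw independent samples for each
# entry of the :envs axis, which is what lets parallel training runs share
# one agent-state pipeline.
```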