Rein.Agents.SAC (rein v0.0.1)

Soft Actor-Critic implementation.

This implementation assumes that the Actor network outputs a tensor of shape `{nil, num_actions, 2}`, where for each action it outputs the $\mu$ and $\sigma$ parameters of a normal distribution, and that the Critic network accepts an `"actions"` input of shape `{nil, num_actions}`, where each action is obtained by sampling from that distribution.
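
For illustration, here is a minimal sketch of actor and critic networks that match these shape conventions, assuming the models are built with Axon; the layer sizes, layer names, and `state_size` are hypothetical and not part of this module's API:

```elixir
num_actions = 2
state_size = 8

# Actor: emits {nil, num_actions, 2}, where the last axis holds the
# mu and sigma of each action's normal distribution.
actor =
  Axon.input("state", shape: {nil, state_size})
  |> Axon.dense(64, activation: :relu)
  |> Axon.dense(num_actions * 2)
  |> Axon.reshape({:batch, num_actions, 2})

# Critic: consumes the state together with an "actions" input of shape
# {nil, num_actions} and produces a single Q-value per batch entry.
state_input = Axon.input("state", shape: {nil, state_size})
action_input = Axon.input("actions", shape: {nil, num_actions})

critic =
  Axon.concatenate(state_input, action_input)
  |> Axon.dense(64, activation: :relu)
  |> Axon.dense(1)
```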

Actions are assumed to lie in a continuous space of type `:f32`.

To simplify the Dual Q implementation, `:critic_params` is vectorized along the `:critics` axis, with a default size of 2. Likewise, `:critic_target_params` and `:critic_optimizer_state` are vectorized in the same way.
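
As a rough sketch of what that vectorization means in Nx terms (using a single weight tensor as a stand-in for the full critic parameter map):

```elixir
# Two independent critics' copies of one (hypothetical) weight tensor.
w1 = Nx.broadcast(0.1, {4, 4})
w2 = Nx.broadcast(0.2, {4, 4})

# Stack them along a new leading axis and vectorize it as :critics,
# so subsequent defn calls run once per critic automatically.
stacked = Nx.stack([w1, w2])                 # shape {2, 4, 4}
vectorized = Nx.vectorize(stacked, :critics) # vectorized axes [critics: 2], shape {4, 4}
```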

Vectorized axes from `:random_key` are still propagated normally throughout the agent state for parallel training.
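
For example, a `:random_key` vectorized over a hypothetical `:environments` axis (the axis name is illustrative) makes every downstream sampling step run once per parallel instance:

```elixir
key = Nx.Random.key(42)
# Split the base key into 4 independent keys, shape {4, 2} ...
keys = Nx.Random.split(key, parts: 4)
# ... and vectorize the leading axis so each parallel run gets its own key.
random_key = Nx.vectorize(keys, :environments)
```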