Axon.Recurrent (Axon v0.1.0)
Functional implementations of common recurrent neural network routines.
Recurrent Neural Networks are commonly used for working with sequences of data where there is some level of dependence between outputs at different timesteps.
This module contains 3 RNN cell functions and functions to "unroll" cells over an entire sequence. Each cell function returns a tuple:

{new_carry, output}

where new_carry is an updated carry state and output is the output for a single timestep. In order to apply an RNN across multiple timesteps, you need to use either static_unroll or dynamic_unroll.

Unrolling an RNN is equivalent to a map_reduce or scan starting from an initial carry state and ending with a final carry state and an output sequence.

All of the functions in this module are implemented as numerical functions and can be JIT or AOT compiled with any supported Nx compiler.
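As a concrete illustration of this flow, the sketch below wires lstm_cell into static_unroll by hand. The parameter layout (four-tuples of per-gate kernels and biases), the {batch, time, features} input shape, and the carry shapes are assumptions made for this example rather than guarantees of the API; zero-initialized tensors are used only to keep the snippet self-contained.

```elixir
# A minimal, hedged sketch of unrolling an LSTM cell over a sequence.
# The tuple structure of the carry, kernels, and biases below is an
# assumption for illustration; consult the cell implementations for the
# exact expected layout.
batch = 4
time = 10
features = 8
units = 16

zeros = fn shape -> Nx.broadcast(0.0, shape) end

input_sequence = Nx.iota({batch, time, features}, type: {:f, 32})

# Assumed carry: {cell_state, hidden_state}, each {batch, 1, units}.
carry = {zeros.({batch, 1, units}), zeros.({batch, 1, units})}

# Assumed parameters: one kernel/bias per LSTM gate.
input_kernel = {zeros.({features, units}), zeros.({features, units}),
                zeros.({features, units}), zeros.({features, units})}
hidden_kernel = {zeros.({units, units}), zeros.({units, units}),
                 zeros.({units, units}), zeros.({units, units})}
bias = {zeros.({units}), zeros.({units}), zeros.({units}), zeros.({units})}

{final_carry, output_sequence} =
  Axon.Recurrent.static_unroll(
    &Axon.Recurrent.lstm_cell(&1, &2, &3, &4, &5),
    input_sequence,
    carry,
    input_kernel,
    hidden_kernel,
    bias
  )
```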
Functions
conv_lstm_cell(input, carry, input_kernel, hidden_kernel, bias, opts \\ [])
ConvLSTM Cell.
When combined with Axon.Recurrent.*_unroll, implements a ConvLSTM-based RNN. More memory efficient than traditional LSTM.
Options

* :strides - convolution strides. Defaults to 1.
* :padding - convolution padding. Defaults to :same.
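A short, hedged sketch of forwarding these options is shown below. The input, carry, and parameter variables are placeholders assumed to be bound to appropriately shaped tensors; the exact ConvLSTM parameter layout is not shown here.

```elixir
# Placeholder variables; their shapes and tuple structure are assumptions
# not shown in this snippet.
{new_carry, output} =
  Axon.Recurrent.conv_lstm_cell(
    input,
    carry,
    input_kernel,
    hidden_kernel,
    bias,
    strides: 1,
    padding: :same
  )
```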
dynamic_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias)
Dynamically unrolls an RNN.
Unrolls implement a scan operation which applies a transformation on the leading axis of input_sequence, carrying some state. In this instance cell_fn is an RNN cell function such as lstm_cell or gru_cell.

This function makes use of a defn while-loop and thus may be more efficient for long sequences.
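Because dynamic_unroll shares its interface with static_unroll, the setup from the module-level sketch above can be reused unchanged; only the unroll function differs.

```elixir
# Same (assumed) parameters and input_sequence as in the module-level sketch.
{final_carry, output_sequence} =
  Axon.Recurrent.dynamic_unroll(
    &Axon.Recurrent.lstm_cell(&1, &2, &3, &4, &5),
    input_sequence,
    carry,
    input_kernel,
    hidden_kernel,
    bias
  )
```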
gru_cell(input, carry, input_kernel, hidden_kernel, bias, gate_fn \\ &sigmoid/1, activation_fn \\ &tanh/1)
GRU Cell.
When combined with Axon.Recurrent.*_unroll, implements a GRU-based RNN. More memory efficient than traditional LSTM.
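A hedged sketch of supplying custom gate and activation functions is shown below; the input, carry, and parameter variables are placeholders whose exact GRU-specific structure is not shown here.

```elixir
# Placeholder variables; shapes and tuple structure are assumptions.
# The last two arguments override the default gate and activation functions.
{new_carry, output} =
  Axon.Recurrent.gru_cell(
    input,
    carry,
    input_kernel,
    hidden_kernel,
    bias,
    &Axon.Activations.sigmoid/1,
    &Axon.Activations.relu/1
  )
```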
lstm_cell(input, carry, input_kernel, hidden_kernel, bias, gate_fn \\ &sigmoid/1, activation_fn \\ &tanh/1)
LSTM Cell.
When combined with Axon.Recurrent.*_unroll, implements an LSTM-based RNN.
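A single timestep looks like the sketch below, reusing the (assumed) zero-initialized carry and parameters from the module-level example; the input covers one timestep only.

```elixir
# One timestep of input; the {batch, 1, features} shape is an assumption.
input = Nx.iota({batch, 1, features}, type: {:f, 32})

# new_carry holds the updated cell and hidden states for the next timestep.
{new_carry, output} =
  Axon.Recurrent.lstm_cell(input, carry, input_kernel, hidden_kernel, bias)
```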
static_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias)
Statically unrolls an RNN.
Unrolls implement a scan operation which applies a transformation on the leading axis of input_sequence, carrying some state. In this instance cell_fn is an RNN cell function such as lstm_cell or gru_cell.
This function inlines the unrolling of the sequence such that the entire operation appears as a part of the compilation graph. This makes it suitable for shorter sequences.
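To make the scan semantics concrete, here is a conceptual plain-Elixir sketch of static unrolling. It is not the library implementation; the batch-first, time-second layout and the slicing axis are assumptions for illustration.

```elixir
# Conceptual sketch only: scan over the time axis, threading the carry and
# collecting each timestep's output.
conceptual_static_unroll = fn cell_fn, input_sequence, carry, input_kernel, hidden_kernel, bias ->
  {batch, time, features} = Nx.shape(input_sequence)

  {final_carry, outputs} =
    Enum.reduce(0..(time - 1), {carry, []}, fn t, {prev_carry, outputs} ->
      # Slice timestep t along axis 1 (assumed time axis): {batch, 1, features}.
      input = Nx.slice(input_sequence, [0, t, 0], [batch, 1, features])
      {next_carry, output} = cell_fn.(input, prev_carry, input_kernel, hidden_kernel, bias)
      {next_carry, [output | outputs]}
    end)

  {final_carry, Nx.concatenate(Enum.reverse(outputs), axis: 1)}
end
```

Under these assumptions, calling conceptual_static_unroll with an RNN cell function such as &Axon.Recurrent.lstm_cell(&1, &2, &3, &4, &5) mirrors the {final_carry, output_sequence} result described above.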