Bumblebee.Text.Generation behaviour (Bumblebee v0.3.1)
An interface for language models supporting sequence generation.
Summary
Callbacks
Initializes an opaque cache input for iterative inference.
Traverses all batched tensors in the cache.
Functions
Builds a numerical definition that generates sequences of tokens using the given language model.
Initializes an opaque cache input for iterative inference.
Calls fun for every batched tensor in the cache.
Types
@type cache() :: Nx.Tensor.t() | Nx.Container.t()
Callbacks
@callback init_cache( spec :: Bumblebee.ModelSpec.t(), batch_size :: pos_integer(), max_length :: pos_integer(), inputs :: map() ) :: cache()
Initializes an opaque cache input for iterative inference.
@callback traverse_cache( spec :: Bumblebee.ModelSpec.t(), cache(), (Nx.Tensor.t() -> Nx.Tensor.t()) ) :: cache()
Traverses all batched tensors in the cache.
This function is used when the cache needs to be inflated or deflated for a different batch size.
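To show what inflating or deflating a cache means, here is a minimal sketch that uses only Nx. The CacheSketch module, its init/3 and traverse/2 functions, and the cache layout are hypothetical stand-ins: real Bumblebee caches are opaque and model-specific. The sketch assumes the batched tensors have the batch on axis 0 and sit next to a non-batched offset that a traversal must leave untouched.

```elixir
Mix.install([{:nx, "~> 0.5"}])

defmodule CacheSketch do
  # Hypothetical stand-in for a model cache. Batched tensors live under :blocks,
  # while :offset is a non-batched scalar that must not be resized.
  def init(batch_size, max_length, hidden_size) do
    %{
      blocks: %{
        key: Nx.broadcast(0.0, {batch_size, max_length, hidden_size}),
        value: Nx.broadcast(0.0, {batch_size, max_length, hidden_size})
      },
      offset: Nx.tensor(0)
    }
  end

  # Mirrors the traverse_cache/3 contract: apply `fun` to every batched tensor
  # and return the cache with the same structure.
  def traverse(cache, fun) do
    blocks = Map.new(cache.blocks, fn {name, tensor} -> {name, fun.(tensor)} end)
    %{cache | blocks: blocks}
  end
end

cache = CacheSketch.init(2, 8, 4)

# Inflate from batch size 2 to 4 by repeating each entry twice along axis 0.
indices = Nx.tensor([0, 0, 1, 1])
inflated = CacheSketch.traverse(cache, &Nx.take(&1, indices, axis: 0))

# Deflate back to batch size 2 by keeping the first two entries.
deflated = CacheSketch.traverse(inflated, &Nx.slice_along_axis(&1, 0, 2, axis: 0))

IO.inspect(Nx.shape(inflated.blocks.key))
IO.inspect(Nx.shape(deflated.blocks.key))
```

The offset is the reason a plain map over every tensor would not do: only the model knows which entries carry a batch axis, so the traversal is delegated to the model through this callback.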
Functions
@spec build_generate( Axon.t(), Bumblebee.ModelSpec.t(), Bumblebee.Text.GenerationConfig.t(), keyword() ) :: (params :: map(), inputs :: map() -> Nx.t())
Builds a numerical definition that generates sequences of tokens using the given language model.
The model should be either a decoder or an encoder-decoder. Tokens are generated by iterative inference with the decoder (autoregression) until the termination criteria are met.
For encoder-decoder models, the encoder is run only once and its intermediate state is reused across all iterations.
Generation is controlled by a number of options given as %Bumblebee.Text.GenerationConfig{}; see the corresponding docs for more details.
Options
:seed - random seed to use when sampling. By default, the current timestamp is used.
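A minimal end-to-end sketch of calling build_generate/4. It assumes the GPT-2 checkpoint from the Hugging Face Hub, EXLA as the compiler, and that the generation config accepts the :max_new_tokens and :strategy options used below (check the Bumblebee.Text.GenerationConfig docs for your version). Multinomial sampling is selected so that the :seed option has an effect; with a fixed seed, repeated runs should produce the same tokens.

```elixir
Mix.install([
  {:bumblebee, "~> 0.3.1"},
  {:exla, ">= 0.0.0"}
])

# Run tensor operations on EXLA; the default binary backend is very slow for this.
Nx.global_default_backend(EXLA.Backend)

repo = {:hf, "gpt2"}

{:ok, model_info} = Bumblebee.load_model(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

# Assumed options: cap the number of new tokens and sample instead of greedy search.
generation_config =
  Bumblebee.configure(generation_config,
    max_new_tokens: 10,
    strategy: %{type: :multinomial_sampling}
  )

# The result is a numerical definition taking (params, inputs).
generate =
  Bumblebee.Text.Generation.build_generate(
    model_info.model,
    model_info.spec,
    generation_config,
    seed: 42
  )

inputs = Bumblebee.apply_tokenizer(tokenizer, ["Elixir is a"])

# JIT-compile the definition, then call it with the parameters and tokenized inputs.
token_ids = Nx.Defn.jit(generate, compiler: EXLA).(model_info.params, inputs)

tokenizer
|> Bumblebee.Tokenizer.decode(token_ids)
|> IO.inspect()
```

Because Nx.Defn.jit caches compiled code per input shape, keeping the returned function around and calling it again with same-shaped inputs avoids recompiling.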
@spec init_cache(Bumblebee.ModelSpec.t(), pos_integer(), pos_integer(), map()) :: cache()
Initializes an opaque cache input for iterative inference.
@spec traverse_cache( Bumblebee.ModelSpec.t(), cache(), (Nx.Tensor.t() -> Nx.Tensor.t()) ) :: cache()
Calls fun for every batched tensor in the cache.
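A sketch of the two cache functions used together. It assumes the GPT-2 spec and that a decoder-only model needs no extra inputs to build its cache, hence the empty map passed to init_cache/4. It allocates a cache for batch size 1, then uses traverse_cache/3 to inflate it to batch size 2 by repeating the single entry, which assumes the batch is the leading axis of every batched tensor.

```elixir
Mix.install([
  {:bumblebee, "~> 0.3.1"}
])

# Only the spec is needed to size the cache, so parameters are not loaded.
{:ok, spec} = Bumblebee.load_spec({:hf, "gpt2"})

batch_size = 1
max_length = 16

# Assumption: an empty inputs map is enough for a decoder-only model.
cache = Bumblebee.Text.Generation.init_cache(spec, batch_size, max_length, %{})

# Repeat the single batch entry twice. This works for any rank because it
# only indexes axis 0.
inflated =
  Bumblebee.Text.Generation.traverse_cache(spec, cache, fn tensor ->
    Nx.take(tensor, Nx.tensor([0, 0]), axis: 0)
  end)

# Print the shape of every batched tensor after inflation, returning each
# tensor unchanged so the traversal leaves the cache intact.
Bumblebee.Text.Generation.traverse_cache(spec, inflated, fn tensor ->
  IO.inspect(Nx.shape(tensor))
  tensor
end)
```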