EXLA (EXLA v0.2.0)
Google's XLA (Accelerated Linear Algebra) compiler/backend for Nx.
It supports just-in-time (JIT) compilation to GPU (both CUDA and ROCm) and TPUs.
Configuration
In projects
EXLA works both as a backend for Nx
tensors and an optimized Nx.Defn
compiler.
To enable both globally, add a config/config.exs (or config/ENV.exs)
with the following:
import Config
config :nx, :default_backend, EXLA.Backend
config :nx, :default_defn_options, [compiler: EXLA]
Now you can use Nx
as usual and it will use EXLA
by default.
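For example, a minimal sketch (the module and function names are illustrative): with the configuration above, numerical definitions are compiled with EXLA without passing any extra option:

defmodule MyDefn do
  import Nx.Defn

  # Compiled by EXLA because of :default_defn_options above
  defn softmax(t) do
    Nx.exp(t) / Nx.sum(Nx.exp(t))
  end
end

MyDefn.softmax(Nx.tensor([1.0, 2.0, 3.0]))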
You can also use cuda/rocm/tpu as the target by setting the :client
option in both configurations:
import Config
config :nx, :default_backend, {EXLA.Backend, client: :cuda}
config :nx, :default_defn_options, [compiler: EXLA, client: :cuda]
To use GPUs/TPUs, you must also set the appropriate value for the
XLA_TARGET
environment
variable. For CUDA, setting ELIXIR_ERL_OPTIONS="+sssdio 128"
is also
required for more complex operations to increase CUDA's compiler stack size.
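For example, on a CUDA machine you might export something along these lines before compiling the project (the exact XLA_TARGET value depends on your CUDA installation; check the XLA_TARGET documentation for the value matching your setup):

export XLA_TARGET=cuda111
export ELIXIR_ERL_OPTIONS="+sssdio 128"
mix compile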
In scripts/notebooks
The simplest way to configure EXLA in notebooks is by calling:
EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host])
Then EXLA will pick the first platform available for the current
notebook user. From then on, you can use Nx
as usual and it will
use EXLA
by default.
As in the project configuration above, you must also set the appropriate
value for the XLA_TARGET
environment variable if you intend to use GPU/TPUs.
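For example, a minimal notebook setup might look like this (a sketch; the version requirement is illustrative):

Mix.install([
  {:exla, "~> 0.2"}
])

# Picks the first platform available on this machine
EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host])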
Options
The options accepted by EXLA configuration are:
  * :client - an atom representing the client to use. Defaults to :host.
    See the "Clients" section

  * :device_id - the default device id to run the computation on.
    Defaults to the :default_device_id on the client
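For example, both options can be combined in the defn compiler configuration (a sketch, assuming a machine with more than one CUDA device):

import Config
config :nx, :default_defn_options, [compiler: EXLA, client: :cuda, device_id: 1]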
Clients
The EXLA
library uses a client for compiling and executing code.
Those clients are typically bound to a platform, such as CPU or
GPU.
Those clients are singleton resources in Google's XLA library, therefore they are treated as singleton resources in this library too. EXLA ships with a client configuration for each supported platform, which is equivalent to this:
config :exla, :clients,
  host: [platform: :host],
  cuda: [platform: :cuda],
  rocm: [platform: :rocm],
  tpu: [platform: :tpu]
Important! You should avoid using multiple clients for the same platform. If you have multiple clients per platform, they can race each other and fight for resources, such as memory. Therefore, we recommend developers stick with the default clients above.
Client options
Each client configuration accepts the following options:
  * :platform - the platform the client runs on. It can be :host (CPU),
    :cuda, :rocm, or :tpu.

  * :default_device_id - the default device ID to run on. For example,
    if you have two GPUs, you can choose a different one as the default.
    Defaults to device 0 (the first device).

  * :preallocate - if the memory should be preallocated on GPU devices.
    Defaults to true.

  * :memory_fraction - how much memory of a GPU device to allocate.
    Defaults to 0.9.
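For example, you can tune the default :cuda client rather than defining a new one (a sketch; the values are illustrative):

import Config

config :exla, :clients,
  cuda: [platform: :cuda, preallocate: false, memory_fraction: 0.5]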
GPU Runtime Issues
GPU executions run in dirty IO threads, which have a considerably smaller stack size than regular scheduler threads. This may lead to problems with certain CUDA or cuDNN versions, resulting in segmentation faults. In a development environment, it is suggested to set:
ELIXIR_ERL_OPTIONS="+sssdio 128"
to increase the stack size of dirty IO threads from 40 kilowords to
128 kilowords. In a release, you can set this flag in your vm.args
.
Device allocation
EXLA also ships with an EXLA.Backend
that allows data to be explicitly
allocated on the EXLA device. You can create tensors with EXLA.Backend
directly:
Nx.tensor([1, 2, 3, 4], backend: EXLA.Backend)
or you can configure EXLA.Backend
as the default backend, so that
all tensors are allocated on the EXLA device by default.
In some cases you may want to explicitly move an existing tensor to the device:
tensor = Nx.tensor([1, 2, 3, 4], backend: Nx.BinaryBackend)
Nx.backend_transfer(tensor, EXLA.Backend)
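Nx.backend_transfer/2 returns the transferred tensor rather than mutating its argument, so keep the result. The transfer also works in the opposite direction when you need the data back in the Erlang VM memory, as in this sketch:

device_tensor = Nx.backend_transfer(tensor, EXLA.Backend)
# ... compute on the device ...
Nx.backend_transfer(device_tensor, Nx.BinaryBackend)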
Note that you can use regular Nx
operations, so the following works:
tensor = Nx.tensor([1, 2, 3, 4], backend: EXLA.Backend)
Nx.sum(tensor)
Under the hood, EXLA will create a computation for the sum operation
and invoke it on the device. This is essentially an "eager mode"
that provides acceleration during prototyping. However, eventually
you should wrap your computations in a defn
to utilize the full
performance of JIT.
Docker considerations
EXLA should run fine on Docker with one important consideration: you must not start the Erlang VM as the root process (PID 1) of the Docker container. That's because when the Erlang VM runs as the root process, it has to manage all child programs.
At the same time, Google's XLA shells out to child programs during compilation and it must retain control over how those programs terminate.
To address this, simply make sure you wrap the Erlang VM in another process, such as a shell. In other words, if you are using releases, instead of this:
CMD path/to/release start
do this:
RUN sh -c "path/to/release start"
If you are using Mix inside your Docker containers, instead of this:
CMD mix run
do this:
RUN sh -c "mix run"
Alternatively, you can pass the --init
flag to docker run
, so
it runs an init
inside the container that forwards signals and
reaps processes.
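For example (a sketch; the image name is illustrative):

docker run --init my_app_image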
Summary
Functions
A shortcut for Nx.Defn.jit/3
with the EXLA compiler.
Checks if the JIT compilation of function with args is cached.
Sets the global defn options to the EXLA compiler with the preferred client based on their availability.
Starts streaming the given anonymous function with just-in-time compilation.
Checks if the JIT compilation of stream with args is cached.
Functions
A shortcut for Nx.Defn.jit/3
with the EXLA compiler.
iex> EXLA.jit(&Nx.add(&1, &1), [Nx.tensor([1, 2, 3])])
#Nx.Tensor<
  s64[3]
  [2, 4, 6]
>
See the moduledoc for options.
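For instance, assuming the options from the moduledoc's "Options" section (such as :client) can also be passed per call, a sketch:

EXLA.jit(&Nx.multiply(&1, 2), [Nx.tensor([1, 2, 3])], client: :host)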
Checks if the JIT compilation of function with args is cached.
Note that hooks are part of the cache, and therefore they must be included in the options.
Examples
iex> fun = fn a, b -> Nx.add(a, b) end
iex> left = Nx.tensor(1, type: {:u, 8})
iex> right = Nx.tensor([1, 2, 3], type: {:u, 16})
iex> EXLA.jit(fun, [left, right])
iex> EXLA.jit_cached?(fun, [left, right])
true
iex> EXLA.jit_cached?(fun, [left, Nx.tensor([1, 2, 3, 4], type: {:u, 16})])
false
Sets the global defn options to the EXLA compiler with the preferred client based on their availability.
This function is typically invoked at the top of scripts and code
notebooks which might be executed on multiple platforms.
Do not invoke this function at runtime, as it changes Nx.Defn
options globally. If you have a specific client that you want to use
throughout your project, use configuration files instead:
import Config
config :nx, :default_backend, {EXLA.Backend, client: :cuda}
config :nx, :default_defn_options, [compiler: EXLA, client: :cuda]
Examples
EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host])
The above will try to find the first available client and set EXLA,
with that client, as the compiler for Nx.Defn.
If no client is found, EXLA is not set as the compiler at all,
therefore it is common to add :host as the last option.
If additional options are given, they are passed along as compiler options:
EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host], device_id: 1)
To use the GPU or TPUs, don't forget to also set the appropriate value
for the XLA_TARGET
environment variable.
Starts streaming the given anonymous function with just-in-time compilation.
At least two arguments are expected:

  * the first argument is a tensor template of the data to be streamed in

  * the second argument is a tensor with the stream initial state

The streaming function must return a two-element tuple: the first element
is the data to be sent and the second is the accumulator.
For each streamed chunk, you must call Nx.Stream.send/2
and
Nx.Stream.recv/1
. You don't need to call recv
immediately
after send
, but doing so can be a useful mechanism to provide
backpressure. Once all chunks are sent, you must use Nx.Stream.done/1
to receive the accumulated result. Let's see an example:
defmodule Streamed do
  import Nx.Defn

  defn sum(tensor, acc) do
    {acc, tensor + acc}
  end
end
Now let's invoke it:
stream = EXLA.stream(&Streamed.sum/2, [Nx.template({}, {:s, 64}), 0])
for i <- 1..5 do
  Nx.Stream.send(stream, i)
  IO.inspect {:chunk, Nx.Stream.recv(stream)}
end

IO.inspect {:result, Nx.Stream.done(stream)}
It will print:
{:chunk, 0}
{:chunk, 1}
{:chunk, 2}
{:chunk, 3}
{:chunk, 4}
{:result, 5}
Note: While any process can call Nx.Stream.send/2
, EXLA
expects the process that starts the streaming to be the one
calling Nx.Stream.recv/1
and Nx.Stream.done/1
.
Checks if the JIT compilation of stream with args is cached.
Note that hooks are part of the cache, and therefore they must be included in the options.
Examples
iex> left = Nx.tensor(1, type: {:u, 8})
iex> right = Nx.tensor([1, 2, 3], type: {:u, 16})
iex> fun = fn x, acc -> {acc, Nx.add(x, acc)} end
iex> stream = EXLA.stream(fun, [left, right])
iex> Nx.Stream.done(stream)
iex> EXLA.stream_cached?(fun, [left, right])
true
iex> EXLA.stream_cached?(fun, [left, Nx.tensor([1, 2, 3, 4], type: {:u, 16})])
false