NxIREE (NxIREE v0.0.1)

IREE compiler and runtime bindings for Nx.

NxIREE provides an Nx.Defn compiler which runs on IREE. The following example shows how we can compile a function to run on the CPU:

f = &Nx.add/2
compiler_flags = ["--iree-hal-target-backends=llvm-cpu", "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal"]
opts = [compiler: NxIREE.Compiler, iree_compiler_flags: compiler_flags, iree_runtime_options: [device: "local-sync://"]]
Nx.Defn.compile(f, [Nx.template({1}, :f32), Nx.template({3}, :f32)], opts)

The next example compiles the same function to run on an Apple GPU through Metal:

f = &Nx.add/2
compiler_flags = ["--iree-hal-target-backends=metal-spirv", "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal"]
opts = [compiler: NxIREE.Compiler, iree_compiler_flags: compiler_flags, iree_runtime_options: [device: "metal://"]]
Nx.Defn.compile(f, [Nx.template({1}, :f32), Nx.template({3}, :f32)], opts)
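
The two examples above differ only in the target backend flag and the device URI. As a minimal sketch, a small helper can build the options for either target. The TargetOpts module and its for_target/1 function are hypothetical names introduced for illustration and are not part of NxIREE; the flag and device values are copied from the examples above.

```elixir
defmodule TargetOpts do
  # Hypothetical helper, not part of NxIREE.
  # Flags shared by both examples in this section.
  @common_flags [
    "--iree-input-type=stablehlo_xla",
    "--iree-execution-model=async-internal"
  ]

  # CPU target, as in the first example.
  def for_target(:cpu), do: build("llvm-cpu", "local-sync://")

  # Apple GPU through Metal, as in the second example.
  def for_target(:metal), do: build("metal-spirv", "metal://")

  # Assembles the keyword list expected by Nx.Defn.compile/3 in the examples.
  defp build(backend, device) do
    [
      compiler: NxIREE.Compiler,
      iree_compiler_flags: ["--iree-hal-target-backends=#{backend}" | @common_flags],
      iree_runtime_options: [device: device]
    ]
  end
end

# Usage, mirroring the examples above (requires Nx and NxIREE to be installed):
# Nx.Defn.compile(&Nx.add/2, [Nx.template({1}, :f32), Nx.template({3}, :f32)], TargetOpts.for_target(:cpu))
```

Keeping the shared flags in one place means a new target only needs its backend name and device URI.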

NxIREE.Compiler also provides a to_bytecode function, which outputs the bytecode for use on embedded devices, such as the iOS devices supported through LiveNxIREE.

Summary

Functions

call(module, inputs, opts \\ [])
Calls a function in the given module with the provided Nx inputs.

compile(mlir_module, flags, opts \\ [])
Compiles the given MLIR module with the given list of flags.

list_devices()
Lists all devices available for running IREE modules.

list_devices(driver)
Lists all devices available in a given driver for running IREE modules.

list_drivers()
Lists all drivers available for running IREE modules.

Functions

call(module, inputs, opts \\ [])

Calls a function in the given module with the provided Nx inputs.

Options

  • :function - The name of the function to call in the module. Defaults to "main".

  • :device - The device to run the module on. If not provided, known GPU devices (CUDA, ROCm, Metal, Vulkan) are preferred over others. Valid values can be obtained through list_devices/0 or list_devices/1.
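
The options above can be combined to override the default GPU preference, for example to force a CPU device. The following is a minimal sketch under stated assumptions: the CallSketch module and its functions are hypothetical, the module argument is assumed to be whatever value call/3 accepts, and device strings are assumed to begin with the driver name, as in "local-sync://" and "metal://" from the examples at the top of this page.

```elixir
defmodule CallSketch do
  # Hypothetical helper, not part of NxIREE.

  # Returns the first device whose string starts with the given prefix,
  # or nil when none matches. Assumes device strings look like "metal://...".
  def pick_device(devices, prefix) do
    Enum.find(devices, &String.starts_with?(&1, prefix))
  end

  # Looks up the available devices, picks one by prefix, and calls "main" on it.
  # The return value of NxIREE.call/3 is passed through unchanged.
  def call_on(module, inputs, prefix) do
    {:ok, devices} = NxIREE.list_devices()

    case pick_device(devices, prefix) do
      nil -> {:error, :no_matching_device}
      device -> NxIREE.call(module, inputs, function: "main", device: device)
    end
  end
end
```

Only pick_device/2 is pure; call_on/3 needs NxIREE loaded and a compiled module to run.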

compile(mlir_module, flags, opts \\ [])

Compiles the given MLIR module with the given list of flags.

Returns the bytecode for the compiled module.

Examples

iex> mlir_module = """
...> func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
...>   %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
...>   return %0 : tensor<4xf32>
...> }
...> """
iex> flags = ["--iree-hal-target-backends=llvm-cpu", "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal"]
iex> NxIREE.compile(mlir_module, flags)

list_devices()

@spec list_devices() :: {:ok, [String.t()]}

Lists all devices available for running IREE modules.

list_devices(driver)

@spec list_devices(String.t()) :: {:ok, [String.t()]} | {:error, :unknown_driver}

Lists all devices available in a given driver for running IREE modules.

Valid drivers can be obtained through list_drivers/0.

list_drivers()

@spec list_drivers() :: {:ok, [String.t()]} | {:error, :unknown_driver}

Lists all drivers available for running IREE modules.
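
The listing functions can be combined to build an inventory of devices per driver. This is a minimal sketch that relies only on the return shapes given in the specs above. The DeviceInventory module is a hypothetical name, and the listing functions are passed in as arguments (defaulting to the NxIREE ones) so the logic can be exercised without IREE installed.

```elixir
defmodule DeviceInventory do
  # Hypothetical helper, not part of NxIREE.
  # Returns a map of driver name => list of device strings.
  def by_driver(
        list_drivers \\ &NxIREE.list_drivers/0,
        list_devices \\ &NxIREE.list_devices/1
      ) do
    case list_drivers.() do
      {:ok, drivers} ->
        Map.new(drivers, fn driver ->
          case list_devices.(driver) do
            {:ok, devices} -> {driver, devices}
            # A driver that cannot be queried is reported with no devices.
            {:error, :unknown_driver} -> {driver, []}
          end
        end)

      {:error, _reason} ->
        %{}
    end
  end
end

# Usage with the real functions (requires NxIREE):
# DeviceInventory.by_driver()
```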