NxIREE (NxIREE v0.0.1)
IREE compiler and runtime bindings for Nx.
NxIREE provides an Nx.Defn compiler that runs on IREE.
The following example shows how we can compile a function to run on the CPU:
f = &Nx.add/2
compiler_flags = ["--iree-hal-target-backends=llvm-cpu", "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal"]
opts = [compiler: NxIREE.Compiler, iree_compiler_flags: compiler_flags, iree_runtime_options: [device: "local-sync://"]]
Nx.Defn.compile(f, [Nx.template({1}, :f32), Nx.template({3}, :f32)], opts)
The next example compiles the same function to run on an Apple GPU through Metal:
f = &Nx.add/2
compiler_flags = ["--iree-hal-target-backends=metal-spirv", "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal"]
opts = [compiler: NxIREE.Compiler, iree_compiler_flags: compiler_flags, iree_runtime_options: [device: "metal://"]]
Nx.Defn.compile(f, [Nx.template({1}, :f32), Nx.template({3}, :f32)], opts)
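The two examples above differ only in the --iree-hal-target-backends value and the runtime device URI, so the option list can be built by a small helper. The sketch below is hypothetical (IREEOpts.for_target/1 is not part of NxIREE) and only reuses the flag values and device URIs shown above.

```elixir
defmodule IREEOpts do
  # Hypothetical helper, not part of NxIREE: builds the options passed to
  # Nx.Defn.compile/3 for the two targets used in the examples above.
  @common_flags ["--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal"]

  # Target backend flag value and runtime device URI for each target.
  @targets %{
    cpu: {"llvm-cpu", "local-sync://"},
    metal: {"metal-spirv", "metal://"}
  }

  def for_target(target) do
    # Raises KeyError for any target other than :cpu or :metal.
    {backend, device} = Map.fetch!(@targets, target)

    [
      compiler: NxIREE.Compiler,
      iree_compiler_flags: ["--iree-hal-target-backends=#{backend}" | @common_flags],
      iree_runtime_options: [device: device]
    ]
  end
end
```

With NxIREE installed, the result would be used as add = Nx.Defn.compile(&Nx.add/2, [Nx.template({1}, :f32), Nx.template({3}, :f32)], IREEOpts.for_target(:metal)), after which add.(Nx.tensor([1.0]), Nx.tensor([1.0, 2.0, 3.0])) runs the compiled function. Keeping the shared flags in one place makes it harder for the two targets to drift apart.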
NxIREE.Compiler also provides a to_bytecode function, which outputs the bytecode for use with embedded devices, such as the iOS devices usable through LiveNxIREE.
Summary
Functions
Calls a function in the given module with the provided Nx inputs.
Compiles the given MLIR module with the given list of flags.
Lists all devices available for running IREE modules.
Lists all devices available in a given driver for running IREE modules.
Lists all drivers available for running IREE modules.
Functions
Calls a function in the given module with the provided Nx inputs.
Options
:function - The name of the function to call in the module. Defaults to "main" if not provided.
:device - The device to run the module on. If not provided, known GPU devices (CUDA, ROCm, Metal, Vulkan) are preferred over other devices. Valid values can be obtained through list_devices/0 or list_devices/1.
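The default-device rule for :device can be pictured with the following sketch. It is an illustration, not NxIREE's implementation: DevicePick.pick/1 is hypothetical, and both the GPU driver names and the "driver://..." URI shape are assumptions based on the "metal://" and "local-sync://" values used above.

```elixir
defmodule DevicePick do
  # Hypothetical sketch, not NxIREE's implementation. Assumed driver names
  # for the GPU backends listed in the docs (CUDA, ROCm, Metal, Vulkan).
  @gpu_drivers ~w(cuda rocm hip metal vulkan)

  # Returns the first GPU device in `devices`, falling back to the first
  # device of any kind, or nil when the list is empty.
  def pick(devices) when is_list(devices) do
    Enum.find(devices, &gpu?/1) || List.first(devices)
  end

  # Assumes URIs look like "driver://..." (e.g. "metal://", "local-sync://").
  defp gpu?(uri) do
    [driver | _] = String.split(uri, "://", parts: 2)
    driver in @gpu_drivers
  end
end
```

For example, DevicePick.pick(["local-sync://", "metal://"]) returns "metal://". The sketch keeps the input order among GPU devices, since the documentation does not say how one GPU backend is ranked against another.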
Compiles the given MLIR module with the given list of flags.
Returns the bytecode for the compiled module.
Examples
iex> mlir_module = """
...> func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
...> %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
...> return %0 : tensor<4xf32>
...> }
...> """
iex> flags = ["--iree-hal-target-backends=llvm-cpu", "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal"]
iex> NxIREE.compile(mlir_module, flags)
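The MLIR module above applies a single element-wise StableHLO op to two rank-1 f32 tensors. A hypothetical helper that emits modules of the same shape for other binary ops (for example stablehlo.add) could look like this sketch; MLIRSketch.binary_op/3 is not part of NxIREE, and its output is meant to be passed to NxIREE.compile/2 with the same flags.

```elixir
defmodule MLIRSketch do
  # Hypothetical helper, not part of NxIREE: emits a @main function that
  # applies one binary StableHLO op element-wise to two tensor<SIZExELEM> args.
  def binary_op(op, size, elem \\ "f32")
      when is_binary(op) and is_integer(size) and size > 0 do
    type = "tensor<#{size}x#{elem}>"

    """
    func.func @main(%arg0: #{type}, %arg1: #{type}) -> #{type} {
      %0 = "stablehlo.#{op}"(%arg0, %arg1) : (#{type}, #{type}) -> #{type}
      return %0 : #{type}
    }
    """
  end
end
```

MLIRSketch.binary_op("multiply", 4) reproduces the function from the example above; NxIREE.compile(MLIRSketch.binary_op("add", 4), flags) would compile the addition variant instead.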
@spec list_devices() :: {:ok, [String.t()]}
Lists all devices available for running IREE modules.
@spec list_devices(String.t()) :: {:ok, [String.t()]} | {:error, :unknown_driver}
Lists all devices available in a given driver for running IREE modules.
Valid drivers can be obtained through list_drivers/0.
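Per its spec, list_devices/1 returns either {:ok, devices} or {:error, :unknown_driver}. The sketch below matches on both cases to build a map from driver to devices. DeviceList.by_driver/2 is hypothetical; it takes the lister as a function so that it can be exercised without the native runtime.

```elixir
defmodule DeviceList do
  # Hypothetical helper: builds %{driver => devices} from a driver list and a
  # lister function that follows the list_devices/1 spec shown above.
  # Drivers reported as {:error, :unknown_driver} are skipped.
  def by_driver(drivers, lister) when is_list(drivers) and is_function(lister, 1) do
    Enum.reduce(drivers, %{}, fn driver, acc ->
      case lister.(driver) do
        {:ok, devices} -> Map.put(acc, driver, devices)
        {:error, :unknown_driver} -> acc
      end
    end)
  end
end
```

With NxIREE available, this would be called as {:ok, drivers} = NxIREE.list_drivers() followed by DeviceList.by_driver(drivers, &NxIREE.list_devices/1). Skipping unknown drivers rather than raising suits discovery code that only needs whatever devices are present.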
@spec list_drivers() :: {:ok, [String.t()]} | {:error, :unknown_driver}
Lists all drivers available for running IREE modules.