Nx.Backend implementation backed by Apple's MLX.
## Public API

Users rarely call functions on this module directly. Install it as
the default backend (or pass it via the per-tensor `backend:` option) and
Nx does the dispatch:
```elixir
Nx.global_default_backend({Emily.Backend, device: :gpu})

Nx.tensor([[1.0, 2.0], [3.0, 4.0]])
|> Nx.sum()
|> Nx.to_flat_list()
# => [10.0]
```

Every function defined here implements a callback from the
`Nx.Backend` behaviour (see `@impl true` in the source); together they form
the Nx dispatch table, not a user-facing API. The handful of
`fast_*` functions are the dispatch targets for optional-expression
nodes emitted by `Emily.Fast` and are likewise internal.
## Options

  * `:device` — `:gpu` (default) or `:cpu`. Stored per-tensor; MLX dispatches
    the computation on that device.
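As a sketch of how the option is used (assuming the standard `Nx.tensor/2` `:backend` option and `Nx.backend_transfer/2`, which are regular Nx API; nothing here is Emily-specific beyond the `device:` key documented above):

```elixir
# Allocate one tensor on the CPU even if the global default is the GPU.
t = Nx.tensor([1.0, 2.0, 3.0], backend: {Emily.Backend, device: :cpu})

# Move an existing tensor to the GPU via the standard Nx transfer API.
gpu_t = Nx.backend_transfer(t, {Emily.Backend, device: :gpu})
```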
## Divergences from Nx.BinaryBackend

  * `{:f, 64}` is not supported — Metal cannot execute f64. Allocations at
    f64 raise `ArgumentError`; cast to f32 instead.
  * `from_pointer`, `to_pointer`, `population_count`, `count_leading_zeros`,
    and interior-padding `pad` raise — MLX has no primitive for them.
  * `qr` with `mode: :complete` falls back to `Nx.BinaryBackend` (MLX only
    supports reduced QR).
  * `determinant` uses Nx's default implementation, which calls the native
    `lu` for matrices larger than 3×3.
  * `quotient` uses MLX `floor_divide` semantics (floor toward -∞ rather than
    Nx's truncate-toward-zero). For non-negative integer operands the results
    agree; mixed-sign inputs diverge by one.
  * Duplicate-index `indexed_put`: MLX's underlying scatter is unordered on
    duplicates, while `Nx.BinaryBackend` is deterministic last-write.
    `indexed_add` is commutative, so duplicates accumulate identically on
    both backends.
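A minimal illustration of the `quotient` divergence; the expected values below follow from the floor-vs-truncate semantics stated above rather than from a verified run:

```elixir
a = Nx.tensor(-7)
b = Nx.tensor(2)

# Under Emily.Backend (floor_divide): floor(-7 / 2) = -4
Nx.quotient(a, b)

# Under Nx.BinaryBackend (truncate toward zero): trunc(-7 / 2) = -3
Nx.quotient(
  Nx.backend_transfer(a, Nx.BinaryBackend),
  Nx.backend_transfer(b, Nx.BinaryBackend)
)
```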
## Fallbacks

A handful of ops have no direct MLX primitive and fall back to
`Nx.BinaryBackend` via a transparent round-trip
(`from_pointer`-free, one memcpy each way). The fallback emits
`[:emily, :fallback, *]` telemetry spans; see `Emily.Telemetry` for
the full catalogue and for opt-in one-shot warnings.
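A sketch of observing fallback spans, assuming the events follow the standard `:telemetry` span convention (`:start`/`:stop` suffixes); the exact event names are an assumption here, so consult `Emily.Telemetry` for the real catalogue:

```elixir
# Hypothetical handler: the event name [:emily, :fallback, :stop] is assumed,
# not taken from the Emily.Telemetry catalogue.
:telemetry.attach(
  "log-emily-fallbacks",
  [:emily, :fallback, :stop],
  fn _event, measurements, metadata, _config ->
    IO.inspect({metadata, measurements}, label: "MLX fallback")
  end,
  nil
)
```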
Every op below has a native MLX path for its hot shape/dtype and falls back only when the input hits the listed guard:
  * `gather` — indices tensor not in the `{batch…, rank_of_axes}` layout MLX
    gather accepts.
  * `cumulative_sum`/`cumulative_product`/`cumulative_max`/`cumulative_min` —
    when `axis` is not the last axis (MLX's factoring raises on some
    interior-axis views).
  * `dot` — batched dot on integer / pred types (MLX matmul is float-only).
    The non-batched tensordot path handles ints natively.
  * `conv` — `batch_group_size > 1`, or complex-typed inputs.
  * `reduce` — always, since the reducer is a user-supplied BEAM function
    that can't be JITed into Metal.
  * `window_reduce` — same reason. The fixed
    `window_sum`/`window_product`/`window_max`/`window_min` variants all run
    native.
  * `indexed_add`/`indexed_put` — indices tensor not in MLX's native scatter
    layout.
  * `qr` with `mode: :complete`; `mode: :reduced` is native.
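For instance, the fixed-function window variants stay on the native path while the generic reducer falls back (a sketch using standard Nx calls; both expressions compute the same sums):

```elixir
t = Nx.iota({6}, type: :f32)

# Native MLX path: the fixed window_sum variant.
Nx.window_sum(t, {2})

# Falls back to Nx.BinaryBackend: the anonymous reducer is a
# user-supplied BEAM function and cannot be JITed into Metal.
Nx.window_reduce(t, 0, {2}, fn x, acc -> Nx.add(x, acc) end)
```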
## Debug assertions

The compile-time flags `:debug_bounds_check` and `:debug_detect_nan_inf`
re-enable runtime assertions on hot paths. Both default to `false` and
carry zero cost when disabled. See the `Emily` moduledoc for details.
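How the flags are set is build-specific; the following is only a sketch assuming they are read through standard Mix compile-time application configuration (the `Emily` moduledoc is authoritative for the actual mechanism):

```elixir
# config/dev.exs: assumed Application-config wiring for the debug flags.
import Config

config :emily,
  debug_bounds_check: true,
  debug_detect_nan_inf: true
```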