Emily.Backend (emily v0.3.2)


Nx.Backend implementation backed by Apple's MLX.

Public API

Users rarely call functions on this module directly. Install it as the default backend (or pass it as a per-tensor backend: option) and Nx does the dispatch:

Nx.global_default_backend({Emily.Backend, device: :gpu})

Nx.tensor([[1.0, 2.0], [3.0, 4.0]])
|> Nx.sum()
|> Nx.to_flat_list()
# => [10.0]

Every function defined here implements a callback from the Nx.Backend behaviour (see @impl true in the source); they form the Nx dispatch table, not a user-facing API. The handful of fast_* functions are the dispatch targets for optional-expression nodes emitted by Emily.Fast — again internal.

Options

  • :device — :gpu (default) or :cpu. Stored per-tensor; MLX dispatches the computation on that device.
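For illustration, the device can be chosen per tensor at allocation time or changed afterwards with a backend transfer (a sketch assuming a standard Nx setup with Emily installed):

```elixir
# Allocate directly on the CPU instead of the default GPU.
t = Nx.tensor([1.0, 2.0, 3.0], backend: {Emily.Backend, device: :cpu})

# Move an existing tensor to the GPU; Nx.backend_transfer/2 copies the data.
t_gpu = Nx.backend_transfer(t, {Emily.Backend, device: :gpu})
```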

Divergences from Nx.BinaryBackend

  • {:f, 64} is not supported — Metal cannot execute f64. Allocations at f64 raise ArgumentError; cast to f32 instead.
  • from_pointer, to_pointer, population_count, count_leading_zeros, and interior-padding pad raise — MLX has no primitive.
  • qr with mode: :complete falls back to Nx.BinaryBackend (MLX only supports reduced QR). determinant uses Nx's default implementation, which calls the native lu for matrices larger than 3×3.
  • quotient uses MLX floor_divide semantics (floor toward -∞ rather than Nx's truncate-toward-zero). For non-negative integer operands the results agree; mixed-sign inputs diverge by one.
  • Duplicate-index indexed_put: MLX's underlying scatter is unordered on duplicates, while Nx.BinaryBackend is deterministic last-write. indexed_add is commutative so duplicates accumulate identically on both backends.
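The quotient divergence can be reproduced with plain Elixir integer arithmetic, which mirrors the two conventions: div/2 truncates toward zero like Nx.BinaryBackend, while Integer.floor_div/2 floors toward -∞ like MLX's floor_divide:

```elixir
# Truncate toward zero (Nx.BinaryBackend semantics for quotient):
div(-7, 2)               # => -3

# Floor toward negative infinity (MLX floor_divide semantics):
Integer.floor_div(-7, 2) # => -4

# Non-negative operands agree:
div(7, 2) == Integer.floor_div(7, 2) # => true
```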

Fallbacks

A handful of ops have no direct MLX primitive and fall back to Nx.BinaryBackend via a transparent round-trip (from_pointer-free, one memcpy each way). The fallback emits [:emily, :fallback, *] telemetry spans; see Emily.Telemetry for the full catalogue and for opt-in one-shot warnings.
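To observe fallbacks during development, one can attach a :telemetry handler. The exact event name and metadata keys below are assumptions (span-convention :stop events with a :duration measurement); consult Emily.Telemetry for the authoritative shapes:

```elixir
:telemetry.attach(
  "emily-fallback-logger",
  # Assumed event name; span conventions emit :start/:stop pairs,
  # and :stop carries the native-unit duration.
  [:emily, :fallback, :stop],
  fn _event, %{duration: duration}, metadata, _config ->
    ms = System.convert_time_unit(duration, :native, :millisecond)
    IO.puts("Emily fell back to Nx.BinaryBackend (#{inspect(metadata)}) in #{ms}ms")
  end,
  nil
)
```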

Every op below has a native MLX path for its hot shape/dtype and falls back only when the input hits the listed guard:

  • gather — indices tensor not in the {batch…, rank_of_axes} layout MLX gather accepts.
  • cumulative_sum / cumulative_product / cumulative_max / cumulative_min — when axis is not the last axis (MLX's factoring raises on some interior-axis views).
  • dot — batched dot on integer / pred types (MLX matmul is float-only). The non-batched tensordot path handles ints natively.
  • conv — batch_group_size > 1, or complex-typed inputs.
  • reduce — always, since the reducer is a user-supplied BEAM function that can't be JITed into Metal.
  • window_reduce — same reason. The fixed window_sum / window_product / window_max / window_min variants all run native.
  • indexed_add / indexed_put — indices tensor not in MLX's native scatter layout.
  • qr with mode: :complete. mode: :reduced is native.
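As a practical consequence of the reduce guards, preferring the fixed-function variants keeps work on the GPU. A hypothetical illustration using the standard Nx API:

```elixir
t = Nx.iota({4, 4}, type: :f32)

# Stays native: fixed reduction with a built-in combiner.
Nx.window_sum(t, {2, 2})

# Falls back: the reducer is a BEAM closure MLX cannot compile to Metal.
Nx.window_reduce(t, 0.0, {2, 2}, fn x, acc -> Nx.add(x, acc) end)
```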

Debug assertions

Compile-time flags :debug_bounds_check and :debug_detect_nan_inf re-enable runtime assertions on hot paths. Both default to false with zero cost. See Emily moduledoc for details.
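Assuming the flags live under the :emily application environment as is conventional for compile-time configuration (the Emily moduledoc is authoritative), enabling them would look like:

```elixir
# config/dev.exs — compile-time flags, so Emily must be
# recompiled (mix deps.compile emily --force) after changing them.
import Config

config :emily,
  debug_bounds_check: true,
  debug_detect_nan_inf: true
```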