viva_math/autodiff_reverse

Reverse-mode automatic differentiation via a computation tape.

Forward-mode AD (viva_math/autodiff) is most efficient when the input dimension is small relative to the output dimension (e.g. the derivative of a one-variable function), because it needs one pass per input. Reverse-mode AD (“backprop”) is optimal for the opposite case: a scalar output with a high-dimensional input, which is exactly the regime of neural networks and gradient-based MAP inference.

How it works

  1. The forward pass builds a directed acyclic graph (the “tape”) that records every elementary operation and its operands.
  2. The backward pass walks the tape in reverse, applying the chain rule to accumulate ∂output/∂node for every node.

All inputs that share a tape are differentiated in a single backward pass, however many there are. That is why reverse-mode AD needs O(1) backward passes regardless of input dimension: one walk over the tape yields the entire gradient.
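To make the two passes concrete, here is a self-contained toy tape that supports only addition and multiplication. It is a sketch under stated assumptions, not this module's implementation: `MiniOp`, `MiniTape`, `new_tape`, `leaf`, `add`, `mul` and `grads` are hypothetical names, and the real `Tape`, `Node` and `Op` types are opaque.

```gleam
import gleam/dict.{type Dict}
import gleam/list

// Hypothetical toy op: records how a node was produced (NOT this module's Op).
pub type MiniOp {
  Leaf
  Add(Int, Int)
  Mul(Int, Int)
}

// Hypothetical toy tape: nodes stored newest-first as #(id, op, forward value).
pub type MiniTape {
  MiniTape(nodes: List(#(Int, MiniOp, Float)), next: Int)
}

pub fn new_tape() -> MiniTape {
  MiniTape(nodes: [], next: 0)
}

// Forward pass: every operation appends a node and returns its id.
fn push(tape: MiniTape, op: MiniOp, v: Float) -> #(Int, MiniTape) {
  let id = tape.next
  #(id, MiniTape(nodes: [#(id, op, v), ..tape.nodes], next: id + 1))
}

// Linear lookup of a forward value; fine for a toy, slow for real use.
fn val(tape: MiniTape, id: Int) -> Float {
  let assert Ok(#(_, _, v)) =
    list.find(tape.nodes, fn(n: #(Int, MiniOp, Float)) { n.0 == id })
  v
}

pub fn leaf(tape: MiniTape, v: Float) -> #(Int, MiniTape) {
  push(tape, Leaf, v)
}

pub fn add(tape: MiniTape, a: Int, b: Int) -> #(Int, MiniTape) {
  push(tape, Add(a, b), val(tape, a) +. val(tape, b))
}

pub fn mul(tape: MiniTape, a: Int, b: Int) -> #(Int, MiniTape) {
  push(tape, Mul(a, b), val(tape, a) *. val(tape, b))
}

// Add g to the adjoint stored for id (an absent entry means 0.0).
fn accumulate(adj: Dict(Int, Float), id: Int, g: Float) -> Dict(Int, Float) {
  case dict.get(adj, id) {
    Ok(old) -> dict.insert(adj, id, old +. g)
    Error(Nil) -> dict.insert(adj, id, g)
  }
}

// Backward pass: seed ∂out/∂out = 1.0, then visit nodes newest-first.
// Operands always have smaller ids than the node using them, so each
// adjoint is complete before the chain rule pushes it on to the operands.
pub fn grads(tape: MiniTape, out: Int) -> Dict(Int, Float) {
  let seed = dict.from_list([#(out, 1.0)])
  list.fold(tape.nodes, seed, fn(adj, node) {
    let #(id, op, _) = node
    case dict.get(adj, id), op {
      Error(Nil), _ -> adj
      Ok(_), Leaf -> adj
      Ok(g), Add(a, b) -> adj |> accumulate(a, g) |> accumulate(b, g)
      Ok(g), Mul(a, b) ->
        adj
        |> accumulate(a, g *. val(tape, b))
        |> accumulate(b, g *. val(tape, a))
    }
  })
}
```

For f(x, y) = (x + y)·y at (1.0, 2.0), `grads` gives ∂f/∂x = 2.0 and ∂f/∂y = 5.0; y receives adjoint contributions from both the addition and the multiplication, and the two are summed.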

Limitations of this implementation

Example

import viva_math/autodiff_reverse as ad

// f(x, y, z) = x² + 2·y·z, evaluated at (x, y, z) = (1.0, 2.0, 3.0)
pub fn main() {
  let tape = ad.empty_tape()
  let #(x, tape) = ad.input(tape, 1.0)
  let #(y, tape) = ad.input(tape, 2.0)
  let #(z, tape) = ad.input(tape, 3.0)
  let #(x_sq, tape) = ad.mul(tape, x, x)
  let #(yz, tape) = ad.mul(tape, y, z)
  let #(yz2, tape) = ad.scale(tape, yz, 2.0)
  let #(out, tape) = ad.add(tape, x_sq, yz2)
  let grads = ad.backward(tape, out)
  // ∂f/∂x = 2x = 2.0
  let dx = ad.grad_of(grads, x)
  // ∂f/∂y = 2z = 6.0
  let dy = ad.grad_of(grads, y)
  // ∂f/∂z = 2y = 4.0
  let dz = ad.grad_of(grads, z)
  #(dx, dy, dz)
}

Types

A single tape node, pairing a forward value with the Op that produced it. The type is exported (so it can appear in tape introspection) but opaque: the integrity of the tape depends on nodes never being constructed outside the module.

pub opaque type Node

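Reference to a node on the tape. NodeId is a plain alias for Int, so the compiler does not enforce opacity; treat it as an opaque handle and only pass ids returned by this module's functions. The function signatures below show these ids as Int.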

pub type NodeId =
  Int


The algebraic operation that produced a tape node. The type is exported because it appears in the body of Node, which is reachable from the public Tape, but it is opaque: callers should build expressions through the forward operations (add, mul, …) and consume gradients through backward and grad_of. Constructing Op values directly would let callers fabricate inconsistent tapes.

pub opaque type Op

Computation tape — append-only graph of forward computations. Opaque (callers should treat it as a token threaded through forward ops).

pub opaque type Tape

Values

pub fn add(tape: Tape, a: Int, b: Int) -> #(Int, Tape)

z = a + b. Local: ∂z/∂a = 1, ∂z/∂b = 1.
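A short usage sketch of gradient accumulation, assuming the `input`, `add`, `backward` and `grad_of` signatures listed on this page; `fan_out_example` is a hypothetical name. When a node feeds an operation more than once, backward sums the contributions, the same behaviour the module example relies on for x·x.

```gleam
import viva_math/autodiff_reverse as ad

// Hypothetical example: f(x) = x + x uses x twice, so
// ∂f/∂x = 1 + 1 = 2.0 once both contributions are summed.
pub fn fan_out_example() -> Float {
  let #(x, tape) = ad.input(ad.empty_tape(), 5.0)
  let #(out, tape) = ad.add(tape, x, x)
  ad.grad_of(ad.backward(tape, out), x)
}
```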

pub fn backward(tape: Tape, output: Int) -> dict.Dict(Int, Float)

Compute ∂output/∂node for every node, given a scalar output node.

Returns a Dict mapping each NodeId to its accumulated gradient. The output node itself has gradient 1.0.
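A sketch of reading the returned Dict directly with `gleam/dict` instead of `grad_of`, assuming the signatures on this page; `backward_example` is a hypothetical name. It builds f(x) = exp(x)·x at x = 0.0, where ∂f/∂x = exp(x)·(1 + x) = 1.0.

```gleam
import gleam/dict
import viva_math/autodiff_reverse as ad

// Hypothetical example: returns #(gradient of the output node, ∂f/∂x).
pub fn backward_example() -> #(Float, Float) {
  let tape = ad.empty_tape()
  let #(x, tape) = ad.input(tape, 0.0)
  let #(e, tape) = ad.exp(tape, x)
  let #(out, tape) = ad.mul(tape, e, x)
  let grads = ad.backward(tape, out)
  // The output node is seeded with gradient 1.0.
  let assert Ok(seed) = dict.get(grads, out)
  // ∂f/∂x = exp(x)·(1 + x), which is 1.0 at x = 0.0.
  let assert Ok(dx) = dict.get(grads, x)
  #(seed, dx)
}
```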

pub fn cos(tape: Tape, a: Int) -> #(Int, Tape)

z = cos(a). Local: ∂z/∂a = −sin(a).

pub fn div(tape: Tape, a: Int, b: Int) -> #(Int, Tape)

z = a / b. Local: ∂z/∂a = 1/b, ∂z/∂b = −a/b².
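A sketch checking the quotient rule at a = 3.0, b = 2.0, assuming the signatures on this page (mathematically, b must be nonzero); `div_example` is a hypothetical name. The local derivatives give ∂z/∂a = 1/b = 0.5 and ∂z/∂b = −a/b² = −0.75.

```gleam
import viva_math/autodiff_reverse as ad

// Hypothetical example: returns #(a / b, ∂z/∂a, ∂z/∂b).
pub fn div_example() -> #(Float, Float, Float) {
  let tape = ad.empty_tape()
  let #(a, tape) = ad.input(tape, 3.0)
  let #(b, tape) = ad.input(tape, 2.0)
  let #(q, tape) = ad.div(tape, a, b)
  let grads = ad.backward(tape, q)
  #(ad.value(tape, q), ad.grad_of(grads, a), ad.grad_of(grads, b))
}
```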

pub fn empty_tape() -> Tape

Start a new empty computation tape.

pub fn exp(tape: Tape, a: Int) -> #(Int, Tape)

z = exp(a). Local: ∂z/∂a = exp(a) = z.

pub fn grad_of(grads: dict.Dict(Int, Float), input: Int) -> Float

Extract gradient ∂output/∂input from the backward result.

pub fn gradients(
  inputs: List(Float),
  build: fn(Tape, List(Int)) -> #(Int, Tape),
) -> List(Float)

Convenience: gradient of a scalar function with respect to many inputs.

Runs the function on a fresh tape, performs the backward pass, and returns the gradient list aligned with the input list.
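The module example rewritten through `gradients`, as a sketch. It assumes that the callback receives one NodeId per input in the same order as the input list, which is how "aligned with the input list" reads; `gradients_example` is a hypothetical name. From the local derivatives, the result should be [2.0, 6.0, 4.0].

```gleam
import viva_math/autodiff_reverse as ad

// Hypothetical example: f(x, y, z) = x² + 2·y·z at (1.0, 2.0, 3.0).
// Assumption: ids arrive in the same order as the input list.
pub fn gradients_example() -> List(Float) {
  ad.gradients([1.0, 2.0, 3.0], fn(tape, ids) {
    let assert [x, y, z] = ids
    let #(x_sq, tape) = ad.mul(tape, x, x)
    let #(yz, tape) = ad.mul(tape, y, z)
    let #(yz2, tape) = ad.scale(tape, yz, 2.0)
    // The callback returns #(output id, tape), as add does.
    ad.add(tape, x_sq, yz2)
  })
}
```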

pub fn input(tape: Tape, value: Float) -> #(Int, Tape)

Register a new input variable on the tape.

pub fn ln(tape: Tape, a: Int) -> #(Int, Tape)

z = ln(a). Local: ∂z/∂a = 1/a. Caller must ensure a > 0.
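A sketch combining exp, add and ln, assuming the signatures on this page: softplus(x) = ln(1 + exp(x)) has derivative exp(x)/(1 + exp(x)) = σ(x), so its reverse-mode gradient should match sigmoid's forward value. No constant-node function is documented, so this assumes registering 1.0 as an extra input is acceptable; `softplus_check` is a hypothetical name.

```gleam
import gleam/float
import viva_math/autodiff_reverse as ad

// Hypothetical check: d/dx ln(1 + exp(x)) should equal σ(x0).
pub fn softplus_check(x0: Float) -> Bool {
  let tape = ad.empty_tape()
  let #(x, tape) = ad.input(tape, x0)
  // Assumption: the constant 1.0 is registered as an extra input
  // whose gradient is simply ignored.
  let #(one, tape) = ad.input(tape, 1.0)
  let #(e, tape) = ad.exp(tape, x)
  let #(s, tape) = ad.add(tape, one, e)
  // 1 + exp(x) > 0, so ln's precondition a > 0 holds for every x0.
  let #(sp, tape) = ad.ln(tape, s)
  let dx = ad.grad_of(ad.backward(tape, sp), x)
  // Reference σ(x0), computed on a separate fresh tape.
  let #(r, ref_tape) = ad.input(ad.empty_tape(), x0)
  let #(sig, ref_tape) = ad.sigmoid(ref_tape, r)
  float.absolute_value(dx -. ad.value(ref_tape, sig)) <. 1.0e-12
}
```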

pub fn mul(tape: Tape, a: Int, b: Int) -> #(Int, Tape)

z = a · b. Local: ∂z/∂a = b, ∂z/∂b = a.

pub fn neg(tape: Tape, a: Int) -> #(Int, Tape)

z = −a. Local: ∂z/∂a = −1.

pub fn pow(tape: Tape, a: Int, n: Float) -> #(Int, Tape)

z = aⁿ (real exponent n). Local: ∂z/∂a = n·aⁿ⁻¹.
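A sketch of gradient checking, a standard technique that compares the reverse-mode derivative with a central difference (f(a + h) − f(a − h)) / 2h. It assumes the signatures on this page; `pow_value` and `pow_check` are hypothetical names. For a non-integer n, aⁿ is only real for a > 0, so the check should use a positive a.

```gleam
import gleam/float
import viva_math/autodiff_reverse as ad

// Hypothetical helper: forward value of a^n on a throwaway tape.
fn pow_value(a0: Float, n: Float) -> Float {
  let #(a, tape) = ad.input(ad.empty_tape(), a0)
  let #(z, tape) = ad.pow(tape, a, n)
  ad.value(tape, z)
}

// Hypothetical check: reverse-mode ∂z/∂a versus a central difference,
// with a tolerance relative to the gradient's magnitude.
pub fn pow_check(a0: Float, n: Float) -> Bool {
  let #(a, tape) = ad.input(ad.empty_tape(), a0)
  let #(z, tape) = ad.pow(tape, a, n)
  let reverse = ad.grad_of(ad.backward(tape, z), a)
  let h = 1.0e-5
  let central =
    { pow_value(a0 +. h, n) -. pow_value(a0 -. h, n) } /. { 2.0 *. h }
  let tolerance = 1.0e-6 *. float.max(1.0, float.absolute_value(reverse))
  float.absolute_value(reverse -. central) <. tolerance
}
```

At a = 4.0, n = 2.5 the exact derivative is 2.5·4^1.5 = 20.0.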

pub fn scale(tape: Tape, a: Int, s: Float) -> #(Int, Tape)

z = s · a, where s is a plain Float constant rather than a tape node, so no gradient is tracked for s. Local: ∂z/∂a = s.

pub fn sigmoid(tape: Tape, a: Int) -> #(Int, Tape)

z = σ(a). Local: ∂z/∂a = σ(a)·(1 − σ(a)) = z·(1 − z).
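A sketch of the MAP-inference use case from the introduction, assuming the signatures on this page: the negative log-likelihood of one positive example under a logistic model, loss(w) = −ln σ(w·x), whose analytic gradient is −(1 − σ(w·x))·x. The data value x is a plain Float, so it enters through `scale`; `logistic_check` is a hypothetical name.

```gleam
import gleam/float
import viva_math/autodiff_reverse as ad

// Hypothetical check: reverse-mode ∂loss/∂w versus −(1 − σ(w·x))·x.
pub fn logistic_check(w0: Float, x0: Float) -> Bool {
  let tape = ad.empty_tape()
  let #(w, tape) = ad.input(tape, w0)
  let #(wx, tape) = ad.scale(tape, w, x0)
  let #(p, tape) = ad.sigmoid(tape, wx)
  // σ(·) > 0, so ln's precondition holds.
  let #(log_p, tape) = ad.ln(tape, p)
  let #(loss, tape) = ad.neg(tape, log_p)
  let dw = ad.grad_of(ad.backward(tape, loss), w)
  let expected = 0.0 -. { 1.0 -. ad.value(tape, p) } *. x0
  float.absolute_value(dw -. expected) <. 1.0e-12
}
```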

pub fn sin(tape: Tape, a: Int) -> #(Int, Tape)

z = sin(a). Local: ∂z/∂a = cos(a).
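A consistency sketch for the sin and cos rules, assuming the signatures on this page: the reverse-mode derivative of sin at x0 should equal cos's forward value, and that of cos should equal −sin(x0). `close`, `derivative`, `forward` and `trig_check` are hypothetical helpers; each evaluation uses its own fresh tape.

```gleam
import gleam/float
import viva_math/autodiff_reverse as ad

fn close(a: Float, b: Float) -> Bool {
  float.absolute_value(a -. b) <. 1.0e-12
}

// Hypothetical helper: reverse-mode derivative of a unary tape op at x0.
fn derivative(op: fn(ad.Tape, Int) -> #(Int, ad.Tape), x0: Float) -> Float {
  let #(x, tape) = ad.input(ad.empty_tape(), x0)
  let #(z, tape) = op(tape, x)
  ad.grad_of(ad.backward(tape, z), x)
}

// Hypothetical helper: forward value of a unary tape op at x0.
fn forward(op: fn(ad.Tape, Int) -> #(Int, ad.Tape), x0: Float) -> Float {
  let #(x, tape) = ad.input(ad.empty_tape(), x0)
  let #(z, tape) = op(tape, x)
  ad.value(tape, z)
}

// sin'(x0) should match cos(x0), and cos'(x0) should match −sin(x0).
pub fn trig_check(x0: Float) -> Bool {
  close(derivative(ad.sin, x0), forward(ad.cos, x0))
  && close(derivative(ad.cos, x0), 0.0 -. forward(ad.sin, x0))
}
```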

pub fn sub(tape: Tape, a: Int, b: Int) -> #(Int, Tape)

z = a − b. Local: ∂z/∂a = 1, ∂z/∂b = −1.
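A sketch of a squared-error loss built from sub and mul, assuming the signatures on this page; `squared_error_example` is a hypothetical name. With loss(p, t) = (p − t)², ∂loss/∂p = 2(p − t) and ∂loss/∂t = −2(p − t), which are 4.0 and −4.0 at p = 3.0, t = 1.0.

```gleam
import viva_math/autodiff_reverse as ad

// Hypothetical example: returns #(∂loss/∂p, ∂loss/∂t).
pub fn squared_error_example() -> #(Float, Float) {
  let tape = ad.empty_tape()
  let #(p, tape) = ad.input(tape, 3.0)
  let #(t, tape) = ad.input(tape, 1.0)
  let #(d, tape) = ad.sub(tape, p, t)
  // d is used twice, so its gradient accumulates to 2d.
  let #(loss, tape) = ad.mul(tape, d, d)
  let grads = ad.backward(tape, loss)
  #(ad.grad_of(grads, p), ad.grad_of(grads, t))
}
```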

pub fn tanh(tape: Tape, a: Int) -> #(Int, Tape)

z = tanh(a). Local: ∂z/∂a = 1 − tanh²(a) = 1 − z².
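A sketch of a single tanh neuron z = tanh(w·x + b), assuming the signatures on this page. Using z = tanh(·), the chain rule gives ∂z/∂w = (1 − z²)·x and ∂z/∂b = 1 − z²; the check reads z back with `value`. `close` and `neuron_check` are hypothetical helpers.

```gleam
import gleam/float
import viva_math/autodiff_reverse as ad

fn close(a: Float, b: Float) -> Bool {
  float.absolute_value(a -. b) <. 1.0e-12
}

// Hypothetical check: reverse-mode gradients of tanh(w·x + b) versus
// the closed form (1 − z²)·x for w and (1 − z²) for b.
pub fn neuron_check(w0: Float, b0: Float, x0: Float) -> Bool {
  let tape = ad.empty_tape()
  let #(w, tape) = ad.input(tape, w0)
  let #(b, tape) = ad.input(tape, b0)
  let #(wx, tape) = ad.scale(tape, w, x0)
  let #(pre, tape) = ad.add(tape, wx, b)
  let #(z, tape) = ad.tanh(tape, pre)
  let grads = ad.backward(tape, z)
  let zv = ad.value(tape, z)
  let local = 1.0 -. zv *. zv
  close(ad.grad_of(grads, w), local *. x0)
  && close(ad.grad_of(grads, b), local)
}
```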

pub fn value(tape: Tape, id: Int) -> Float

Read the forward value of a node.
