mlx_autograd (mlx v0.1.0)

View Source

Summary

Functions

accumulate_gradients(GradList1, GradList2)

backward(Loss)

backward(Loss, Options)

checkpoint(Fun)

checkpoint(Fun, Args)

clip_grad_norm(Gradients, MaxNorm)

clip_grad_value(Gradients, ClipValue)

custom_vjp(Fun, FwdFun, BwdFun)

detach(Array)

disable_grad()

enable_grad()

get_grad(Array)

grad(Fun, Args)

grad(Fun, Args, Argnums)

grad(Fun, Args, Argnums, Options)

grad_grad(Fun, Args, Argnums)

hessian(Fun, Args)

hessian(Fun, Args, Argnums)

jacobian(Fun, Args)

jacobian(Fun, Args, Argnums)

jvp(Fun, Args, Tangents)

jvp(Fun, Args, Tangents, Argnums)

laplacian(Fun, Args)

no_grad(Fun)

set_grad(Array, Gradient)

stop_gradient(Array)

value_and_grad(Fun, Args)

value_and_grad(Fun, Args, Argnums)

vjp(Fun, Args)

vjp(Fun, Args, Argnums)

zero_grad(Arrays)