mlx_advanced (mlx v0.1.0)

View Source

Summary

Functions

angle(Array)

batch_norm(X)

batch_norm(X, Axis)

batch_norm(X, Axis, Eps)

cholesky(Array)

complex(Real, Imag)

conj(Array)

convolve(A, B)

correlate(A, B)

correlation(X, Y)

covariance(X, Y)

cross_entropy(Predictions, Targets)

cross_entropy(Predictions, Targets, Axis)

det(Array)

eig(Array)

einsum(Equation, Arrays)

einsum(Equation, Arrays, Options)

elu(X)

elu(X, Alpha)

fft2(Array)

fft(Array)

focal_loss(Predictions, Targets)

focal_loss(Predictions, Targets, Gamma)

focal_loss(Predictions, Targets, Gamma, Alpha)

gelu(X)

gradient_clip(Gradients, MaxNorm)

gradient_clip(Gradients, MaxNorm, NormType)

gradient_norm(Gradients)

group_norm(X, NumGroups)

group_norm(X, NumGroups, Eps)

histogram(Array, Bins)

histogram(Array, Bins, Options)

huber_loss(Predictions, Targets)

huber_loss(Predictions, Targets, Delta)

ifft2(Array)

ifft(Array)

imag(Array)

js_divergence(P, Q)

kl_divergence(P, Q)

layer_norm(X)

layer_norm(X, Axes)

layer_norm(X, Axes, Eps)

leaky_relu(X)

leaky_relu(X, Alpha)

matrix_rank(Array)

median(Array)

median(Array, Axis)

mish(X)

percentile(Array, Q)

percentile(Array, Q, Axis)

pinv(Array)

qr(Array)

real(Array)

rms_norm(X)

rms_norm(X, Eps)

selu(X)

spectral_norm(Weight)

svd(Array)

swish(X)

trace(Array)

weight_norm(Weight)