mlx_nn (mlx v0.1.0)

View Source

Summary

Functions

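Functions (signatures; where a function accepts extra arguments, each longer form is listed directly after the shorter one)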

adaptive_avg_pool1d(Input, OutputSize)

adaptive_avg_pool2d(Input, OutputSize)

alpha_dropout(Input, Rate)

avg_pool1d(Input, KernelSize)

avg_pool1d(Input, KernelSize, Stride)

avg_pool2d(Input, KernelSize)

avg_pool2d(Input, KernelSize, Stride)

batch_norm(Input, Weight)

batch_norm(Input, Weight, Bias)

batch_norm(Input, Weight, Bias, Eps)

conv1d(Input, Weight, Bias, Stride)

conv1d(Input, Weight, Bias, Stride, Padding)

conv1d(Input, Weight, Bias, Stride, Padding, Dilation)

conv2d(Input, Weight, Bias, Stride)

conv2d(Input, Weight, Bias, Stride, Padding)

conv2d(Input, Weight, Bias, Stride, Padding, Dilation)

conv3d(Input, Weight, Bias, Stride)

conv3d(Input, Weight, Bias, Stride, Padding)

conv3d(Input, Weight, Bias, Stride, Padding, Dilation)

conv_transpose1d(Input, Weight, Bias, Stride)

conv_transpose2d(Input, Weight, Bias, Stride)

conv_transpose3d(Input, Weight, Bias, Stride)

cross_attention(Query, Key, Value, WQ, WK)

cross_attention(Query, Key, Value, WQ, WK, Options)

dense_block(Input, Layers, GrowthRate)

dense_block(Input, Layers, GrowthRate, Options)

dropout2d(Input, Rate)

dropout3d(Input, Rate)

dropout(Input, Rate)

dropout(Input, Rate, Training)

efficientnet_block(Input, ExpandRatio, KernelSize, Stride, SEReduction)

efficientnet_block(Input, ExpandRatio, KernelSize, Stride, SEReduction, Options)

embedding(Input, EmbeddingMatrix, VocabSize)

embedding(Input, EmbeddingMatrix, VocabSize, Options)

feature_alpha_dropout(Input, Rate)

global_avg_pool(Input)

global_max_pool(Input)

group_norm(Input, NumGroups, Weight)

group_norm(Input, NumGroups, Weight, Bias)

gru_cell(Input, Hidden, Weights)

gru_cell(Input, Hidden, Weights, Options)

identity_init(Shape)

instance_norm(Input, Weight)

instance_norm(Input, Weight, Bias)

kaiming_normal(Shape)

kaiming_uniform(Shape)

layer_norm(Input, Weight)

layer_norm(Input, Weight, Bias)

layer_norm(Input, Weight, Bias, Eps)

linear(Input, Weight, Bias)

linear(Input, Weight, Bias, Options)

lstm_cell(Input, Hidden, Cell, Weights)

lstm_cell(Input, Hidden, Cell, Weights, Options)

max_pool1d(Input, KernelSize)

max_pool1d(Input, KernelSize, Stride)

max_pool2d(Input, KernelSize)

max_pool2d(Input, KernelSize, Stride)

multi_head_attention(Query, Key, Value, WQ, WK, WV)

multi_head_attention(Query, Key, Value, WQ, WK, WV, Options)

orthogonal_init(Shape)

positional_encoding(MaxLen, DModel, Base)

positional_encoding(MaxLen, DModel, Base, Dtype)

residual_block(Input, Layers, Activation)

residual_block(Input, Layers, Activation, Options)

rms_norm(Input, Weight)

rms_norm(Input, Weight, Eps)

rnn_cell(Input, Hidden, Weights)

rnn_cell(Input, Hidden, Weights, Options)

scaled_dot_product_attention(Query, Key, Value)

scaled_dot_product_attention(Query, Key, Value, Mask)

self_attention(Input, WQ, WK)

self_attention(Input, WQ, WK, Options)

squeeze_excitation(Input, ReductionRatio)

squeeze_excitation(Input, ReductionRatio, Activation)

swin_transformer_block(Input, WindowSize, Attention, MLP, LayerNorm1)

swin_transformer_block(Input, WindowSize, Attention, MLP, LayerNorm1, LayerNorm2)

transformer_decoder_layer(Input, Memory, SelfAttnW, CrossAttnW, FFNWeights, LayerNorms, Dropout)

transformer_decoder_layer(Input, Memory, SelfAttnW, CrossAttnW, FFNWeights, LayerNorms, Dropout, Options)

transformer_encoder_layer(Input, SelfAttnW, FFNWeights, LayerNorm1, LayerNorm2, Dropout)

transformer_encoder_layer(Input, SelfAttnW, FFNWeights, LayerNorm1, LayerNorm2, Dropout, Options)

vision_transformer_block(Input, Attention, MLP, LayerNorm1)

vision_transformer_block(Input, Attention, MLP, LayerNorm1, LayerNorm2)

xavier_normal(Shape)

xavier_uniform(Shape)