mlx
v0.1.0
Search documentation of mlx
Settings
mlx_nn
(mlx v0.1.0)
View Source
Summary
Functions
adaptive_avg_pool1d(Input, OutputSize)
adaptive_avg_pool2d(Input, OutputSize)
alpha_dropout(Input, Rate)
avg_pool1d(Input, KernelSize)
avg_pool1d(Input, KernelSize, Stride)
avg_pool2d(Input, KernelSize)
avg_pool2d(Input, KernelSize, Stride)
batch_norm(Input, Weight)
batch_norm(Input, Weight, Bias)
batch_norm(Input, Weight, Bias, Eps)
conv1d(Input, Weight, Bias, Stride)
conv1d(Input, Weight, Bias, Stride, Padding)
conv1d(Input, Weight, Bias, Stride, Padding, Dilation)
conv2d(Input, Weight, Bias, Stride)
conv2d(Input, Weight, Bias, Stride, Padding)
conv2d(Input, Weight, Bias, Stride, Padding, Dilation)
conv3d(Input, Weight, Bias, Stride)
conv3d(Input, Weight, Bias, Stride, Padding)
conv3d(Input, Weight, Bias, Stride, Padding, Dilation)
conv_transpose1d(Input, Weight, Bias, Stride)
conv_transpose2d(Input, Weight, Bias, Stride)
conv_transpose3d(Input, Weight, Bias, Stride)
cross_attention(Query, Key, Value, WQ, WK)
cross_attention(Query, Key, Value, WQ, WK, Options)
dense_block(Input, Layers, GrowthRate)
dense_block(Input, Layers, GrowthRate, Options)
dropout2d(Input, Rate)
dropout3d(Input, Rate)
dropout(Input, Rate)
dropout(Input, Rate, Training)
efficientnet_block(Input, ExpandRatio, KernelSize, Stride, SEReduction)
efficientnet_block(Input, ExpandRatio, KernelSize, Stride, SEReduction, Options)
embedding(Input, EmbeddingMatrix, VocabSize)
embedding(Input, EmbeddingMatrix, VocabSize, Options)
feature_alpha_dropout(Input, Rate)
global_avg_pool(Input)
global_max_pool(Input)
group_norm(Input, NumGroups, Weight)
group_norm(Input, NumGroups, Weight, Bias)
gru_cell(Input, Hidden, Weights)
gru_cell(Input, Hidden, Weights, Options)
identity_init(Shape)
instance_norm(Input, Weight)
instance_norm(Input, Weight, Bias)
kaiming_normal(Shape)
kaiming_uniform(Shape)
layer_norm(Input, Weight)
layer_norm(Input, Weight, Bias)
layer_norm(Input, Weight, Bias, Eps)
linear(Input, Weight, Bias)
linear(Input, Weight, Bias, Options)
lstm_cell(Input, Hidden, Cell, Weights)
lstm_cell(Input, Hidden, Cell, Weights, Options)
max_pool1d(Input, KernelSize)
max_pool1d(Input, KernelSize, Stride)
max_pool2d(Input, KernelSize)
max_pool2d(Input, KernelSize, Stride)
multi_head_attention(Query, Key, Value, WQ, WK, WV)
multi_head_attention(Query, Key, Value, WQ, WK, WV, Options)
orthogonal_init(Shape)
positional_encoding(MaxLen, DModel, Base)
positional_encoding(MaxLen, DModel, Base, Dtype)
residual_block(Input, Layers, Activation)
residual_block(Input, Layers, Activation, Options)
rms_norm(Input, Weight)
rms_norm(Input, Weight, Eps)
rnn_cell(Input, Hidden, Weights)
rnn_cell(Input, Hidden, Weights, Options)
scaled_dot_product_attention(Query, Key, Value)
scaled_dot_product_attention(Query, Key, Value, Mask)
self_attention(Input, WQ, WK)
self_attention(Input, WQ, WK, Options)
squeeze_excitation(Input, ReductionRatio)
squeeze_excitation(Input, ReductionRatio, Activation)
swin_transformer_block(Input, WindowSize, Attention, MLP, LayerNorm1)
swin_transformer_block(Input, WindowSize, Attention, MLP, LayerNorm1, LayerNorm2)
transformer_decoder_layer(Input, Memory, SelfAttnW, CrossAttnW, FFNWeights, LayerNorms, Dropout)
transformer_decoder_layer(Input, Memory, SelfAttnW, CrossAttnW, FFNWeights, LayerNorms, Dropout, Options)
transformer_encoder_layer(Input, SelfAttnW, FFNWeights, LayerNorm1, LayerNorm2, Dropout)
transformer_encoder_layer(Input, SelfAttnW, FFNWeights, LayerNorm1, LayerNorm2, Dropout, Options)
vision_transformer_block(Input, Attention, MLP, LayerNorm1)
vision_transformer_block(Input, Attention, MLP, LayerNorm1, LayerNorm2)
xavier_normal(Shape)
xavier_uniform(Shape)
Functions
adaptive_avg_pool1d(Input, OutputSize)
adaptive_avg_pool2d(Input, OutputSize)
alpha_dropout(Input, Rate)
avg_pool1d(Input, KernelSize)
avg_pool1d(Input, KernelSize, Stride)
avg_pool2d(Input, KernelSize)
avg_pool2d(Input, KernelSize, Stride)
batch_norm(Input, Weight)
batch_norm(Input, Weight, Bias)
batch_norm(Input, Weight, Bias, Eps)
conv1d(Input, Weight, Bias, Stride)
conv1d(Input, Weight, Bias, Stride, Padding)
conv1d(Input, Weight, Bias, Stride, Padding, Dilation)
conv2d(Input, Weight, Bias, Stride)
conv2d(Input, Weight, Bias, Stride, Padding)
conv2d(Input, Weight, Bias, Stride, Padding, Dilation)
conv3d(Input, Weight, Bias, Stride)
conv3d(Input, Weight, Bias, Stride, Padding)
conv3d(Input, Weight, Bias, Stride, Padding, Dilation)
conv_transpose1d(Input, Weight, Bias, Stride)
conv_transpose2d(Input, Weight, Bias, Stride)
conv_transpose3d(Input, Weight, Bias, Stride)
cross_attention(Query, Key, Value, WQ, WK)
cross_attention(Query, Key, Value, WQ, WK, Options)
dense_block(Input, Layers, GrowthRate)
dense_block(Input, Layers, GrowthRate, Options)
dropout2d(Input, Rate)
dropout3d(Input, Rate)
dropout(Input, Rate)
dropout(Input, Rate, Training)
efficientnet_block(Input, ExpandRatio, KernelSize, Stride, SEReduction)
efficientnet_block(Input, ExpandRatio, KernelSize, Stride, SEReduction, Options)
embedding(Input, EmbeddingMatrix, VocabSize)
embedding(Input, EmbeddingMatrix, VocabSize, Options)
feature_alpha_dropout(Input, Rate)
global_avg_pool(Input)
global_max_pool(Input)
group_norm(Input, NumGroups, Weight)
group_norm(Input, NumGroups, Weight, Bias)
gru_cell(Input, Hidden, Weights)
gru_cell(Input, Hidden, Weights, Options)
identity_init(Shape)
instance_norm(Input, Weight)
instance_norm(Input, Weight, Bias)
kaiming_normal(Shape)
kaiming_uniform(Shape)
layer_norm(Input, Weight)
layer_norm(Input, Weight, Bias)
layer_norm(Input, Weight, Bias, Eps)
linear(Input, Weight, Bias)
linear(Input, Weight, Bias, Options)
lstm_cell(Input, Hidden, Cell, Weights)
lstm_cell(Input, Hidden, Cell, Weights, Options)
max_pool1d(Input, KernelSize)
max_pool1d(Input, KernelSize, Stride)
max_pool2d(Input, KernelSize)
max_pool2d(Input, KernelSize, Stride)
multi_head_attention(Query, Key, Value, WQ, WK, WV)
multi_head_attention(Query, Key, Value, WQ, WK, WV, Options)
orthogonal_init(Shape)
positional_encoding(MaxLen, DModel, Base)
positional_encoding(MaxLen, DModel, Base, Dtype)
residual_block(Input, Layers, Activation)
residual_block(Input, Layers, Activation, Options)
rms_norm(Input, Weight)
rms_norm(Input, Weight, Eps)
rnn_cell(Input, Hidden, Weights)
rnn_cell(Input, Hidden, Weights, Options)
scaled_dot_product_attention(Query, Key, Value)
scaled_dot_product_attention(Query, Key, Value, Mask)
self_attention(Input, WQ, WK)
self_attention(Input, WQ, WK, Options)
squeeze_excitation(Input, ReductionRatio)
squeeze_excitation(Input, ReductionRatio, Activation)
swin_transformer_block(Input, WindowSize, Attention, MLP, LayerNorm1)
swin_transformer_block(Input, WindowSize, Attention, MLP, LayerNorm1, LayerNorm2)
transformer_decoder_layer(Input, Memory, SelfAttnW, CrossAttnW, FFNWeights, LayerNorms, Dropout)
transformer_decoder_layer(Input, Memory, SelfAttnW, CrossAttnW, FFNWeights, LayerNorms, Dropout, Options)
transformer_encoder_layer(Input, SelfAttnW, FFNWeights, LayerNorm1, LayerNorm2, Dropout)
transformer_encoder_layer(Input, SelfAttnW, FFNWeights, LayerNorm1, LayerNorm2, Dropout, Options)
vision_transformer_block(Input, Attention, MLP, LayerNorm1)
vision_transformer_block(Input, Attention, MLP, LayerNorm1, LayerNorm2)
xavier_normal(Shape)
xavier_uniform(Shape)