mlx_neural_engine (mlx v0.1.0)

View Source

MLX Neural Engine — an AI/ML pipeline for Apple Silicon. This module implements a complete neural network training and inference engine using MLX, combined with Erlang's fault tolerance and concurrency.

Summary

Functions

adam_optimizer(Network, Data, LearningRate, Options)

-spec adam_optimizer(reference(), term(), float(), map()) -> {ok, term()} | {error, term()}.

batch_inference(NetworkRef, Batches)

-spec batch_inference(reference(), [term()]) -> {ok, [term()]} | {error, term()}.

batch_normalization(Input, Options)

-spec batch_normalization(Input :: reference(), Options :: map()) -> {ok, reference()}.

Batch normalization

code_change(OldVsn, State, Extra)

-spec code_change(term(),
                  #state{networks :: term(),
                         optimizers :: term(),
                         training_jobs :: term(),
                         inference_servers :: term(),
                         distributed_nodes :: term()},
                  term()) ->
                     {ok,
                      #state{networks :: term(),
                             optimizers :: term(),
                             training_jobs :: term(),
                             inference_servers :: term(),
                             distributed_nodes :: term()}}.

convolution_2d(Input, Kernel, Options)

-spec convolution_2d(Input :: reference(), Kernel :: reference(), Options :: map()) -> {ok, reference()}.

2D convolution operation

create_network(Architecture, Options)

-spec create_network(Architecture :: map(), Options :: map()) -> {ok, reference()} | {error, term()}.

Create a neural network with specified architecture

distributed_sgd(NetworkRef, Data, Nodes)

-spec distributed_sgd(NetworkRef :: reference(), Data :: term(), Nodes :: [node()]) -> {ok, map()}.

Distributed SGD training

dropout(Input, Rate)

-spec dropout(Input :: reference(), Rate :: float()) -> {ok, reference()}.

Dropout for regularization

federated_learning(Networks, Data, Rounds)

-spec federated_learning([reference()], term(), integer()) -> {ok, term()}.

gradient_descent(Network, Data, LearningRate)

-spec gradient_descent(reference(), term(), float()) -> {ok, term()} | {error, term()}.

handle_call(Request, From, State)

-spec handle_call(term(),
                  {pid(), term()},
                  #state{networks :: term(),
                         optimizers :: term(),
                         training_jobs :: term(),
                         inference_servers :: term(),
                         distributed_nodes :: term()}) ->
                     {reply,
                      term(),
                      #state{networks :: term(),
                             optimizers :: term(),
                             training_jobs :: term(),
                             inference_servers :: term(),
                             distributed_nodes :: term()}}.

handle_cast(Msg, State)

-spec handle_cast(term(),
                  #state{networks :: term(),
                         optimizers :: term(),
                         training_jobs :: term(),
                         inference_servers :: term(),
                         distributed_nodes :: term()}) ->
                     {noreply,
                      #state{networks :: term(),
                             optimizers :: term(),
                             training_jobs :: term(),
                             inference_servers :: term(),
                             distributed_nodes :: term()}}.

handle_info(Info, State)

-spec handle_info(term(),
                  #state{networks :: term(),
                         optimizers :: term(),
                         training_jobs :: term(),
                         inference_servers :: term(),
                         distributed_nodes :: term()}) ->
                     {noreply,
                      #state{networks :: term(),
                             optimizers :: term(),
                             training_jobs :: term(),
                             inference_servers :: term(),
                             distributed_nodes :: term()}}.

init(_)

-spec init([]) ->
              {ok,
               #state{networks :: term(),
                      optimizers :: term(),
                      training_jobs :: term(),
                      inference_servers :: term(),
                      distributed_nodes :: term()}} |
              {stop, term()}.

learning_rate_schedule(Epoch, InitialRate)

-spec learning_rate_schedule(integer(), float()) -> {ok, float()}.

load_model(Path)

-spec load_model(string()) -> {ok, reference()} | {error, term()}.

model_ensemble(Models, Strategy)

-spec model_ensemble([reference()], atom()) -> {ok, reference()} | {error, term()}.

model_serving(NetworkRef, Request, Options)

-spec model_serving(reference(), term(), map()) -> {ok, term()} | {error, term()}.

parameter_server(Data, Nodes)

-spec parameter_server(term(), [node()]) -> {ok, term()}.

predict(NetworkRef, Input)

-spec predict(NetworkRef :: reference(), Input :: term()) -> {ok, term()} | {error, term()}.

Make predictions with a trained network

save_model(ModelRef, Path)

-spec save_model(reference(), string()) -> {ok, term()} | {error, term()}.

start_link()

-spec start_link() -> {ok, pid()} | {error, term()}.

stop()

-spec stop() -> ok.

streaming_inference(NetworkRef, InputStream)

-spec streaming_inference(NetworkRef :: reference(), InputStream :: reference()) -> {ok, reference()}.

Streaming inference for real-time applications

terminate(Reason, State)

-spec terminate(term(),
                #state{networks :: term(),
                       optimizers :: term(),
                       training_jobs :: term(),
                       inference_servers :: term(),
                       distributed_nodes :: term()}) ->
                   ok.

train_network(NetworkRef, TrainingData, Options)

-spec train_network(NetworkRef :: reference(), TrainingData :: term(), Options :: map()) ->
                       {ok, map()} | {error, term()}.

Train a neural network

transformer_attention(Query, Key, Value)

-spec transformer_attention(Query :: reference(), Key :: reference(), Value :: reference()) ->
                               {ok, reference()}.

Transformer attention mechanism