mlx_distributed (mlx v0.2.0)

View Source

Summary

Functions

adaptive_communication(Tensor, CommunicationHistory)

adaptive_communication(Tensor, CommunicationHistory, Options)

all_gather(Tensor, Options)

all_gather(Tensor, CommGroup, Options)

all_reduce(Tensor, ReduceOp)

all_reduce(Tensor, ReduceOp, Options)

batched_communication(CommunicationOps, MaxBatchSize)

batched_communication(CommunicationOps, MaxBatchSize, Options)

broadcast(Tensor, Root, CommGroup)

broadcast(Tensor, Root, CommGroup, Options)

butterfly_all_reduce(Tensor, ReduceOp)

butterfly_all_reduce(Tensor, ReduceOp, Options)

checkpoint_communication(CommState, CheckpointPath)

communication_profiling(CommunicationOp, ProfilingOptions)

compressed_all_reduce(Tensor, ReduceOp, CompressionConfig)

compressed_all_reduce(Tensor, ReduceOp, CompressionConfig, Options)

create_communication_group(Ranks, Backend)

destroy_communication_group(GroupId)

device_aware_communication(Tensor, CommunicationOp, DeviceTopology)

device_aware_communication(Tensor, CommunicationOp, DeviceTopology, Options)

distributed_adam(Gradients, LearningRate, AdamConfig)

distributed_adamw(Gradients, LearningRate, AdamWConfig)

distributed_sgd(Gradients, LearningRate, Momentum)

dynamic_data_parallel(Model, DataLoader)

dynamic_data_parallel(Model, DataLoader, Options)

expert_parallel(Model, InputData, ExpertConfig)

expert_parallel(Model, InputData, ExpertConfig, Options)

fault_tolerant_all_reduce(Tensor, ReduceOp, FaultConfig)

fault_tolerant_all_reduce(Tensor, ReduceOp, FaultConfig, Options)

federated_averaging(LocalModels, AggregationWeights, FedConfig)

gradient_compression_ddp(Model, CompressionConfig, DDPConfig)

heterogeneous_all_reduce(Tensor, ReduceOp, DeviceConfig)

heterogeneous_all_reduce(Tensor, ReduceOp, DeviceConfig, Options)

hierarchical_all_reduce(Tensor, ReduceOp, CommGroup)

hierarchical_all_reduce(Tensor, ReduceOp, CommGroup, Options)

init_gloo_backend(GlooConfig)

init_mpi_backend(MPIConfig)

init_nccl_backend(NCCLConfig)

local_sgd(Gradients, LearningRate, Momentum, LocalSteps)

optimize_communication_plan(CommunicationOps, NetworkTopology)

overlapped_communication(ComputeFun, CommunicationOps, Tensor)

overlapped_communication(ComputeFun, CommunicationOps, Tensor, Options)

pipeline_parallel(Model, InputData, PipelineConfig)

pipeline_parallel(Model, InputData, PipelineConfig, Options)

pipelined_communication(TensorList, CommunicationOp, PipelineStages)

pipelined_communication(TensorList, CommunicationOp, PipelineStages, Options)

quantized_all_reduce(Tensor, ReduceOp, QuantizationConfig)

quantized_all_reduce(Tensor, ReduceOp, QuantizationConfig, Options)

recover_communication(CheckpointPath, RecoveryOptions)

reduce_scatter(Tensor, ReduceOp, CommGroup)

reduce_scatter(Tensor, ReduceOp, CommGroup, Options)

ring_all_reduce(Tensor, ReduceOp)

ring_all_reduce(Tensor, ReduceOp, Options)

sparse_all_reduce(SparseTensor, ReduceOp)

sparse_all_reduce(SparseTensor, ReduceOp, Options)

tensor_parallel(Model, InputData, TensorParallelConfig)

tensor_parallel(Model, InputData, TensorParallelConfig, Options)

torus_all_reduce(Tensor, ReduceOp, TorusTopology)

torus_all_reduce(Tensor, ReduceOp, TorusTopology, Options)

tree_all_reduce(Tensor, ReduceOp)

tree_all_reduce(Tensor, ReduceOp, Options)

zero_redundancy_optimizer(Optimizer, OptimizerConfig, ZeROConfig)