mlx
v0.2.0
Search documentation of mlx
Settings
mlx_distributed
(mlx v0.2.0)
View Source
Summary
Functions
adaptive_communication(Tensor, CommunicationHistory)
adaptive_communication(Tensor, CommunicationHistory, Options)
all_gather(Tensor, Options)
all_gather(Tensor, CommGroup, Options)
all_reduce(Tensor, ReduceOp)
all_reduce(Tensor, ReduceOp, Options)
batched_communication(CommunicationOps, MaxBatchSize)
batched_communication(CommunicationOps, MaxBatchSize, Options)
broadcast(Tensor, Root, CommGroup)
broadcast(Tensor, Root, CommGroup, Options)
butterfly_all_reduce(Tensor, ReduceOp)
butterfly_all_reduce(Tensor, ReduceOp, Options)
checkpoint_communication(CommState, CheckpointPath)
communication_profiling(CommunicationOp, ProfilingOptions)
compressed_all_reduce(Tensor, ReduceOp, CompressionConfig)
compressed_all_reduce(Tensor, ReduceOp, CompressionConfig, Options)
create_communication_group(Ranks, Backend)
destroy_communication_group(GroupId)
device_aware_communication(Tensor, CommunicationOp, DeviceTopology)
device_aware_communication(Tensor, CommunicationOp, DeviceTopology, Options)
distributed_adam(Gradients, LearningRate, AdamConfig)
distributed_adamw(Gradients, LearningRate, AdamWConfig)
distributed_sgd(Gradients, LearningRate, Momentum)
dynamic_data_parallel(Model, DataLoader)
dynamic_data_parallel(Model, DataLoader, Options)
expert_parallel(Model, InputData, ExpertConfig)
expert_parallel(Model, InputData, ExpertConfig, Options)
fault_tolerant_all_reduce(Tensor, ReduceOp, FaultConfig)
fault_tolerant_all_reduce(Tensor, ReduceOp, FaultConfig, Options)
federated_averaging(LocalModels, AggregationWeights, FedConfig)
gradient_compression_ddp(Model, CompressionConfig, DDPConfig)
heterogeneous_all_reduce(Tensor, ReduceOp, DeviceConfig)
heterogeneous_all_reduce(Tensor, ReduceOp, DeviceConfig, Options)
hierarchical_all_reduce(Tensor, ReduceOp, CommGroup)
hierarchical_all_reduce(Tensor, ReduceOp, CommGroup, Options)
init_gloo_backend(GlooConfig)
init_mpi_backend(MPIConfig)
init_nccl_backend(NCCLConfig)
local_sgd(Gradients, LearningRate, Momentum, LocalSteps)
optimize_communication_plan(CommunicationOps, NetworkTopology)
overlapped_communication(ComputeFun, CommunicationOps, Tensor)
overlapped_communication(ComputeFun, CommunicationOps, Tensor, Options)
pipeline_parallel(Model, InputData, PipelineConfig)
pipeline_parallel(Model, InputData, PipelineConfig, Options)
pipelined_communication(TensorList, CommunicationOp, PipelineStages)
pipelined_communication(TensorList, CommunicationOp, PipelineStages, Options)
quantized_all_reduce(Tensor, ReduceOp, QuantizationConfig)
quantized_all_reduce(Tensor, ReduceOp, QuantizationConfig, Options)
recover_communication(CheckpointPath, RecoveryOptions)
reduce_scatter(Tensor, ReduceOp, CommGroup)
reduce_scatter(Tensor, ReduceOp, CommGroup, Options)
ring_all_reduce(Tensor, ReduceOp)
ring_all_reduce(Tensor, ReduceOp, Options)
sparse_all_reduce(SparseTensor, ReduceOp)
sparse_all_reduce(SparseTensor, ReduceOp, Options)
tensor_parallel(Model, InputData, TensorParallelConfig)
tensor_parallel(Model, InputData, TensorParallelConfig, Options)
torus_all_reduce(Tensor, ReduceOp, TorusTopology)
torus_all_reduce(Tensor, ReduceOp, TorusTopology, Options)
tree_all_reduce(Tensor, ReduceOp)
tree_all_reduce(Tensor, ReduceOp, Options)
zero_redundancy_optimizer(Optimizer, OptimizerConfig, ZeROConfig)
Functions
adaptive_communication(Tensor, CommunicationHistory)
adaptive_communication(Tensor, CommunicationHistory, Options)
all_gather(Tensor, Options)
all_gather(Tensor, CommGroup, Options)
all_reduce(Tensor, ReduceOp)
all_reduce(Tensor, ReduceOp, Options)
batched_communication(CommunicationOps, MaxBatchSize)
batched_communication(CommunicationOps, MaxBatchSize, Options)
broadcast(Tensor, Root, CommGroup)
broadcast(Tensor, Root, CommGroup, Options)
butterfly_all_reduce(Tensor, ReduceOp)
butterfly_all_reduce(Tensor, ReduceOp, Options)
checkpoint_communication(CommState, CheckpointPath)
communication_profiling(CommunicationOp, ProfilingOptions)
compressed_all_reduce(Tensor, ReduceOp, CompressionConfig)
compressed_all_reduce(Tensor, ReduceOp, CompressionConfig, Options)
create_communication_group(Ranks, Backend)
destroy_communication_group(GroupId)
device_aware_communication(Tensor, CommunicationOp, DeviceTopology)
device_aware_communication(Tensor, CommunicationOp, DeviceTopology, Options)
distributed_adam(Gradients, LearningRate, AdamConfig)
distributed_adamw(Gradients, LearningRate, AdamWConfig)
distributed_sgd(Gradients, LearningRate, Momentum)
dynamic_data_parallel(Model, DataLoader)
dynamic_data_parallel(Model, DataLoader, Options)
expert_parallel(Model, InputData, ExpertConfig)
expert_parallel(Model, InputData, ExpertConfig, Options)
fault_tolerant_all_reduce(Tensor, ReduceOp, FaultConfig)
fault_tolerant_all_reduce(Tensor, ReduceOp, FaultConfig, Options)
federated_averaging(LocalModels, AggregationWeights, FedConfig)
gradient_compression_ddp(Model, CompressionConfig, DDPConfig)
heterogeneous_all_reduce(Tensor, ReduceOp, DeviceConfig)
heterogeneous_all_reduce(Tensor, ReduceOp, DeviceConfig, Options)
hierarchical_all_reduce(Tensor, ReduceOp, CommGroup)
hierarchical_all_reduce(Tensor, ReduceOp, CommGroup, Options)
init_gloo_backend(GlooConfig)
init_mpi_backend(MPIConfig)
init_nccl_backend(NCCLConfig)
local_sgd(Gradients, LearningRate, Momentum, LocalSteps)
optimize_communication_plan(CommunicationOps, NetworkTopology)
overlapped_communication(ComputeFun, CommunicationOps, Tensor)
overlapped_communication(ComputeFun, CommunicationOps, Tensor, Options)
pipeline_parallel(Model, InputData, PipelineConfig)
pipeline_parallel(Model, InputData, PipelineConfig, Options)
pipelined_communication(TensorList, CommunicationOp, PipelineStages)
pipelined_communication(TensorList, CommunicationOp, PipelineStages, Options)
quantized_all_reduce(Tensor, ReduceOp, QuantizationConfig)
quantized_all_reduce(Tensor, ReduceOp, QuantizationConfig, Options)
recover_communication(CheckpointPath, RecoveryOptions)
reduce_scatter(Tensor, ReduceOp, CommGroup)
reduce_scatter(Tensor, ReduceOp, CommGroup, Options)
ring_all_reduce(Tensor, ReduceOp)
ring_all_reduce(Tensor, ReduceOp, Options)
sparse_all_reduce(SparseTensor, ReduceOp)
sparse_all_reduce(SparseTensor, ReduceOp, Options)
tensor_parallel(Model, InputData, TensorParallelConfig)
tensor_parallel(Model, InputData, TensorParallelConfig, Options)
torus_all_reduce(Tensor, ReduceOp, TorusTopology)
torus_all_reduce(Tensor, ReduceOp, TorusTopology, Options)
tree_all_reduce(Tensor, ReduceOp)
tree_all_reduce(Tensor, ReduceOp, Options)
zero_redundancy_optimizer(Optimizer, OptimizerConfig, ZeROConfig)