Bagua¶
This is the Bagua API documentation. See the tutorials if you need step-by-step instructions on how to use Bagua.
bagua¶
Bagua is a communication library developed by Kuaishou Technology and DS3 Lab for deep learning.
See tutorials for Bagua’s rationale and benchmark.
Subpackages¶
bagua.torch_api¶
The Bagua communication library PyTorch interface.
Subpackages¶
bagua.torch_api.contrib¶
- class bagua.torch_api.contrib.fused_optimizer.FusedOptimizer(optimizer, do_flatten=False)¶
Bases:
torch.optim.Optimizer
Convert any optimizer into a fused optimizer.
This fused optimizer fuses multiple module parameter update kernel launches into one or a few, by flattening parameter tensors into one or more contiguous buckets.
It can be used in conjunction with bagua.torch_api.bagua_init. In this case, Bagua will do the fusion automatically; otherwise, you need to explicitly pass do_flatten=True.
- Parameters
optimizer (torch.optim.Optimizer) – Any PyTorch optimizer.
do_flatten (bool) – Whether to flatten the parameters. Default: False.
- Returns
Fused optimizer.
- Example::
To use in conjunction with bagua.torch_api.bagua_init:
>>> optimizer = torch.optim.Adadelta(model.parameters(), ...)
>>> optimizer = bagua.torch_api.FusedOptimizer(optimizer)
>>> model, optimizer = bagua.bagua_init(model, optimizer, ...)
To use alone or with torch.nn.parallel.DistributedDataParallel, set do_flatten to True:
>>> optimizer = torch.optim.Adadelta(model.parameters(), ...)
>>> optimizer = bagua.torch_api.FusedOptimizer(optimizer, do_flatten=True)
- step(self, closure=None)¶
Performs a single optimization step (parameter update).
- Parameters
closure (callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.
Note
Unless otherwise specified, this function should not modify the .grad field of the parameters.
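For optimizers whose update requires re-evaluating the model, a closure can be supplied. Below is a minimal sketch, assuming model, loss_fn, inputs and targets are already defined in the training loop (these names are placeholders, not part of the Bagua API):
>>> def closure():
...     # Re-evaluate the model and return the loss for the optimizer.
...     optimizer.zero_grad()
...     loss = loss_fn(model(inputs), targets)
...     loss.backward()
...     return loss
>>> optimizer.step(closure)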
- class bagua.torch_api.contrib.load_balancing_data_loader.LoadBalancingDistributedSampler(dataset, complexity_fn, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False, random_level=0)¶
Bases:
torch.utils.data.sampler.Sampler
Sampler that restricts data loading to a subset of the dataset.
This sampler uses a complexity_fn to calculate each sample's computational complexity, so that each batch gets a similar total computational complexity.
This is useful in scenarios like speech and NLP, where each batch has variable length and distributed training suffers from the straggler problem.
The usage is similar to torch.utils.data.DistributedSampler, where each process loads a subset of the original dataset that is exclusive to it.
Note
Dataset is assumed to be of constant size.
- Parameters
dataset – Dataset used for sampling.
complexity_fn (Callable) – A function whose input is a sample and whose output is an integer as a measure of the computational complexity of the sample.
num_replicas (int, optional) – Number of processes participating in distributed training. By default, world_size is retrieved from the current distributed group.
rank (int, optional) – Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.
shuffle (bool, optional) – If True (default), the sampler will shuffle the indices.
seed (int, optional) – Random seed used to shuffle the sampler if shuffle=True. This number should be identical across all processes in the distributed group. Default: 0.
drop_last (bool, optional) – If True, the sampler will drop the tail of the data to make it evenly divisible across the number of replicas. If False, the sampler will add extra indices to make the data evenly divisible across the replicas. Default: False.
random_level (float, optional) – A float between 0 and 1 that controls the extent of load balancing. 0 means the best load balance, while 1 means the opposite.
Warning
In distributed mode, calling the set_epoch method at the beginning of each epoch, before creating the DataLoader iterator, is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will always be used.
- Example::
Define your complexity_fn, which accepts a dataset sample as its input and produces an integer as the sample’s computational complexity.
>>> dataset = torch.utils.data.TensorDataset(torch.randn(n, 2), torch.randperm(n))
>>> complexity_fn = lambda x: x[1]
Below is the usage of LoadBalancingDistributedSampler and DataLoader:
>>> sampler = bagua.torch_api.contrib.data.LoadBalancingDistributedSampler(
...     dataset,
...     complexity_fn=complexity_fn) if is_distributed else None
>>> loader = torch.utils.data.DataLoader(dataset,
...     shuffle=(sampler is None),
...     sampler=sampler)
>>>
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)
- shuffle_chunks(self)¶
- set_epoch(self, epoch)¶
Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.
- Parameters
epoch (int) – Epoch number.
- Return type
None
- class bagua.torch_api.contrib.load_balancing_data_loader.LoadBalancingDistributedBatchSampler(sampler, batch_fn, drop_last=False)¶
Bases:
torch.utils.data.sampler.Sampler
Wraps another load balance sampler to yield variable-sized mini-batches.
- Parameters
sampler (LoadBalancingDistributedSampler) – Load balance sampler.
batch_fn (Callable) – Callable to yield mini-batch indices.
drop_last (bool) – If True, the sampler will drop the last few batches that exceed the least number of batches among replicas; otherwise, the number of batches on each replica will be padded to the same number. Default: False.
batch_fn will have the signature of def batch_fn(indices: List[int]) -> List[List[int]].
Example:
>>> from bagua.torch_api.contrib.data import LoadBalancingDistributedSampler, \
...     LoadBalancingDistributedBatchSampler
>>>
>>> sampler = LoadBalancingDistributedSampler(dataset, complexity_fn=complexity_fn)
>>> batch_sampler = LoadBalancingDistributedBatchSampler(sampler, batch_fn=batch_fn)
>>> loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)
>>>
>>> for epoch in range(start_epoch, n_epochs):
...     batch_sampler.set_epoch(epoch)
...     train(loader)
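As a rough illustration of the batch_fn signature above, here is a minimal sketch that greedily packs indices until a per-batch complexity budget is reached. The budget value and the reuse of dataset and complexity_fn from the sampler example are assumptions made for this sketch only:
>>> def batch_fn(indices):
...     # Greedily pack sample indices so each batch stays under a
...     # hypothetical complexity budget.
...     budget, batches, batch, total = 1024, [], [], 0
...     for i in indices:
...         cost = complexity_fn(dataset[i])
...         if batch and total + cost > budget:
...             batches.append(batch)
...             batch, total = [], 0
...         batch.append(i)
...         total += cost
...     if batch:
...         batches.append(batch)
...     return batches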
- generate_batches(self)¶
- set_epoch(self, epoch)¶
Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.
- Parameters
epoch (int) – Epoch number.
- Return type
None
Submodules¶
bagua.torch_api.communication¶
- bagua.torch_api.communication.init_process_group()¶
Initializes the PyTorch builtin distributed process group. This will also initialize the distributed package, and it should be executed before all other Bagua APIs.
- Raises
RepeatedInitializationError – If you run this function repeatedly
- Examples::
>>> import bagua.torch_api as bagua
>>> bagua.init_process_group()
>>> model = torch.nn.Sequential(
...     torch.nn.Linear(D_in, H),
...     torch.nn.ReLU(),
...     torch.nn.Linear(H, D_out),
... )
>>> optimizer = torch.optim.SGD(
...     model.parameters(),
...     lr=0.01,
...     momentum=0.9
... )
>>> model, optimizer = bagua_init(model, optimizer)
- bagua.torch_api.communication.broadcast(tensor, root=0, comm=None)¶
Broadcasts the tensor to the whole communicator.
tensor must have the same number of elements in all processes participating in the collective.
- Parameters
tensor (torch.Tensor) – Data to be sent if root is the rank of current process, and tensor to be used to save received data otherwise.
root (int, optional) – Source rank. Defaults to 0.
comm (B.BaguaSingleCommunicatorPy, optional) – The bagua communicator to work on. If None, the global bagua communicator will be used. Defaults to None.
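A minimal usage sketch, mirroring the allreduce examples below and assuming two ranks with an already initialized process group (rank is a placeholder for the current process rank):
>>> from bagua.torch_api import broadcast
>>> # Each rank starts with different data; rank 0 is the source.
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> broadcast(tensor, root=0)
>>> tensor
tensor([1, 2]) # identical on every rank after the call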
- bagua.torch_api.communication.allreduce(tensor, average=True, comm=None)¶
Reduces the tensor data across all machines in such a way that all get the final result. After the call, the tensor is going to be bitwise identical in all processes.
- Parameters
tensor (torch.Tensor) – Input and output of the collective. The function operates in-place.
average (bool, optional) – Average the reduced tensor or not. Defaults to True.
comm (B.BaguaSingleCommunicatorPy, optional) – The bagua communicator to work on. If None the global bagua communicator will be used. Defaults to None.
Examples
>>> from bagua.torch_api import allreduce
>>> # All tensors below are of torch.int64 type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> allreduce(tensor)
>>> tensor
tensor([4, 6]) # Rank 0
tensor([4, 6]) # Rank 1
>>> # All tensors below are of torch.cfloat type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
>>> tensor
tensor([1.+1.j, 2.+2.j]) # Rank 0
tensor([3.+3.j, 4.+4.j]) # Rank 1
>>> allreduce(tensor)
>>> tensor
tensor([4.+4.j, 6.+6.j]) # Rank 0
tensor([4.+4.j, 6.+6.j]) # Rank 1
bagua.torch_api.distributed¶
- class bagua.torch_api.distributed.DistributedModule(module)¶
Bases:
torch.nn.Module
A base class for distributed module.
- unwrap(self)¶
Return the unwrapped module.
- forward(self, *inputs, **kwargs)¶
Execute the forward process and return the output.
- class bagua.torch_api.distributed.Reducer(module, optimizers, bucket_type, hierarchical_reduce, align_bytes, chunking, fusion, decentralize_reduce=False, buckets=[], **kwargs)¶
Bases:
object
In order to improve communication efficiency, the distributed algorithm chunks parameters into many buckets. A bucket is the minimum unit of communication between devices in bagua. This module is the bucket manager, providing bucket operation methods.
The process mainly consists of the following two cases:
- bucket_initialized is False:
1.1 add_param
1.2 initialize_buckets -> register_models
1.3 mark_bucket_ready
1.4 mark_on_complete
- bucket_initialized is True:
2.1 mark_tensor_ready
2.2 mark_on_complete
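The following is a rough, non-authoritative sketch of that call order, assuming reducer is an already-constructed Reducer and param is a module parameter (both names are placeholders for illustration):
>>> # Case 1: bucket_initialized is False (first backward pass).
>>> reducer.add_param(param)                  # 1.1
>>> buckets = reducer.initialize_buckets()    # 1.2
>>> reducer.mark_bucket_ready(buckets[0], 0)  # 1.3
>>> reducer.mark_on_complete()                # 1.4
>>> # Case 2: bucket_initialized is True (later backward passes).
>>> reducer.mark_tensor_ready(param)          # 2.1
>>> reducer.mark_on_complete()                # 2.2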
- Parameters
module (DistributedModule) – Module to be parallelized.
optimizers (torch.optim.Optimizer or list of torch.optim.Optimizer) – Optimizer(s) for the module. It can contain one or more PyTorch optimizers.
bucket_type (BucketType) – Type of elements in a communication bucket, could be either module parameters, weights or gradients.
hierarchical_reduce (bool) – Enable hierarchical reduce, which will perform an intra-node allreduce, followed by an inter-node reduce defined by different module, and an intra-node broadcast at the end.
align_bytes (int) – Number of bytes to be aligned for each communication bucket.
chunking (bool) – For alltoall communication pattern, set chunking to True.
fusion (bool) – To reset parameter data pointer so that they can use faster code paths, set fusion to True.
decentralize_reduce (bool) – Whether to execute the decentralized communication. Default: False.
buckets (List[List[TensorDeclaration]]) – Parameter buckets.
- fill_slot(self, param)¶
Get the value of parameters.
- initialize_buckets(self)¶
Initialize parameter buckets.
Note
initialize_buckets MUST be executed after the first round of backward.
- Returns
parameter buckets.
- Return type
List[List[torch.Tensor]]
- register_bagua_buckets(self)¶
Register bagua buckets.
- add_param(self, param)¶
Add parameter into tensor_list.
- mark_bucket_ready(self, bucket, bucket_idx)¶
Mark all tensors in the bucket ready.
- mark_tensor_ready(self, param)¶
Mark the tensor ready when its gradient has been computed.
- mark_on_complete(self)¶
Mark that all buckets have finished their reduce process.
- class bagua.torch_api.distributed.OverlappingWrapper(module, optimizers, delay_reduce=False, bucket_type=BucketType.Gradient, hierarchical_reduce=False, decentralize_reduce=False, parameter_manager=None, align_bytes=8, chunking=False, fusion=True, **kwargs)¶
Bases:
torch.nn.Module
This class defines the process of communication-computation overlap.
- Parameters
module (torch.nn.Module) – A distributed module to be overlapped.
optimizers (torch.optim.Optimizer or list of torch.optim.Optimizer) – Optimizer(s) for the module. It can contain one or more PyTorch optimizers.
delay_reduce (bool) – Delay all communication to the end of the backward pass. This disables overlapping communication with computation. Default value is False.
bucket_type (BucketType) – Type of elements in a communication bucket, could be either module parameters, weights or gradients.
hierarchical_reduce (bool) – Enable hierarchical reduce, which will perform an intra-node allreduce, followed by an inter-node reduce defined by different module, and an intra-node broadcast at the end.
decentralize_reduce (bool) – For decentralize training, set decentralize_reduce to True.
align_bytes (int) – Number of bytes to be aligned for each communication bucket.
chunking (bool) – For alltoall communication pattern, set chunking to True.
fusion (bool) – To reset parameter data pointer so that they can use faster code paths, set fusion to True.
Note
This implementation benefits a lot from apex.parallel.DistributedDataParallel.
- reset_reducer(self, hierarchical_reduce=None, buckets=None)¶
Reset the parameter reducer.
- Parameters
hierarchical_reduce (bool) – Enable hierarchical reduce.
buckets (List[List[TensorDeclaration]]) – Parameter buckets.
- create_hooks(self)¶
Defines a number of hooks used to reduce communication buckets in the backward process.
- forward(self, *inputs, **kwargs)¶
Overwrite the forward process for a distributed module with communication-computation overlap.
- class bagua.torch_api.distributed.ModelSwitchWrapper(module, optimizer, broadcast_buffers=True, delay_reduce=False, hierarchical_reduce=None, message_size=10000000, intra_comm_root_rank=0, **kwargs)¶
Bases:
torch.nn.Module
ModelSwitchWrapper is designed to switch distributed algorithms during the training process. It serves two purposes: first, it transforms the original module into a distributed module; second, it can switch the distributed algorithm to another one during training.
- Parameters
module (torch.nn.Module) – Network definition to be run in multi-gpu/distributed mode.
optimizer (torch.optim.Optimizer or list of torch.optim.Optimizer) – Optimizer(s) for the module. It can contain one or more PyTorch optimizers.
broadcast_buffers (bool) – Flag that enables syncing (broadcasting) buffers of the module at the first iteration of the forward function. Default: True.
delay_reduce (bool) – Delay all communication to the end of the backward pass. This disables overlapping communication with computation. Default: False.
hierarchical_reduce (bool) – Enable hierarchical reduce. For the GradientAllReduce algorithm, the default value is False; otherwise, the default value is True.
message_size (int) – Minimum bytes in a communication bucket. Default: 10_000_000.
intra_comm_root_rank (int) – Root rank of intra communication. Default: 0.
- Returns
Distributed module.
Examples:
>>> model = torch.nn.Sequential(
...     torch.nn.Linear(D_in, H),
...     torch.nn.ReLU(),
...     torch.nn.Linear(H, D_out),
... )
>>> optimizer = torch.optim.SGD(
...     model.parameters(),
...     lr=0.01,
...     momentum=0.9
... )
>>> model = ModelSwitchWrapper(
...     module=model,
...     optimizer=optimizer,
...     broadcast_buffers=broadcast_buffers,
...     delay_reduce=delay_reduce,
...     hierarchical_reduce=hierarchical_reduce,
...     message_size=message_size,
...     **kwargs,
... ).switch_to(DistributedAlgorithm.GradientAllReduce)
>>> # train A epochs
>>> model.switch_to(DistributedAlgorithm.Decentralize)
>>> # train B epochs
>>> model.switch_to(DistributedAlgorithm.GradientAllReduce)
>>> # continue training ...
- switch_to(self, distributed_algorithm)¶
Switch the initial module to a distributed module.
- Parameters
distributed_algorithm (DistributedAlgorithm) – Distributed algorithm used to average gradients or weights across all workers. Default: DistributedAlgorithm.GradientAllReduce.
- Returns
The distributed module that replaces the initial one.
- state_dict(self, **kwargs)¶
Fetch the module’s state_dict.
- report_metrics(self, score_record_list)¶
Log the metrics of the auto_tune algorithm.
- ask_and_update_hyperparameters(self)¶
Execute the environment search process by auto_tune and update the hyper-parameters.
- Return type
bool
- forward(self, *inputs, **kwargs)¶
Overwrite the forward process and return the output.
- bagua.torch_api.distributed.broadcast_parameters(module, broadcast_buffers=True)¶
Broadcast the parameters (and buffers) for synchronization at the beginning of training. If broadcast_buffers is False, the buffers won't be synchronized (broadcast) at the beginning.
- bagua.torch_api.distributed.allreduce_parameters(module)¶
Allreduce the parameters and buffers for synchronization each time the distributed algorithm is switched.
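A minimal sketch of how these two helpers might be used around algorithm switching, assuming model is the module being trained (illustrative only, not a prescribed workflow):
>>> from bagua.torch_api.distributed import broadcast_parameters, allreduce_parameters
>>> # Synchronize all workers at the start of training.
>>> broadcast_parameters(model, broadcast_buffers=True)
>>> # ... train with the current distributed algorithm ...
>>> # Re-synchronize parameters and buffers when switching algorithms.
>>> allreduce_parameters(model)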
- bagua.torch_api.distributed.bagua_init(module, optimizer, distributed_algorithm=DistributedAlgorithm.GradientAllReduce, broadcast_buffers=True, delay_reduce=False, hierarchical_reduce=None, message_size=10000000, **kwargs)¶
bagua_init is a module wrapper that enables easy multiprocess distributed data parallel training using different distributed algorithms.
- Parameters
module (torch.nn.Module) – Network definition to be run in multi-gpu/distributed mode.
optimizer (torch.optim.Optimizer or list of torch.optim.Optimizer) – Optimizer(s) for the module. It can contain one or more PyTorch optimizers.
distributed_algorithm (DistributedAlgorithm) – Distributed algorithm used to average gradients or weights across all workers. Default: DistributedAlgorithm.GradientAllReduce.
broadcast_buffers (bool) – Flag that enables syncing (broadcasting) buffers of the module at the first iteration of the forward function. Default: True.
delay_reduce (bool) – Delay all communication to the end of the backward pass. This disables overlapping communication with computation. Default value is False.
hierarchical_reduce (bool) – Enable hierarchical reduce. For the GradientAllReduce algorithm, the default value is False; otherwise, the default value is True.
message_size (int) – Minimum bytes in a communication bucket. Default: 10_000_000.
- Returns
Distributed module.
Examples:
>>> model = torch.nn.Sequential(
...     torch.nn.Linear(D_in, H),
...     torch.nn.ReLU(),
...     torch.nn.Linear(H, D_out),
... )
>>> optimizer = torch.optim.SGD(
...     model.parameters(),
...     lr=0.01,
...     momentum=0.9
... )
>>> model, optimizer = bagua_init(
...     model,
...     optimizer,
...     broadcast_buffers=True
... )
bagua.torch_api.env¶
- bagua.torch_api.env.get_world_size()¶
Get the number of processes in the current process group.
- Returns
The world size of the process group.
- bagua.torch_api.env.get_rank()¶
Get the rank of the current process in the process group.
Rank is a unique identifier assigned to each process within a distributed process group. They are always consecutive integers ranging from 0 to world_size - 1.
- Returns
The rank of the current process.
- bagua.torch_api.env.get_local_rank()¶
Get the local rank of the current process within the node.
Local rank is a unique identifier assigned to each process within a node. They are always consecutive integers ranging from 0 to local_size - 1.
- Returns
The local rank of the current process.
- bagua.torch_api.env.get_local_size()¶
Get the number of processes in the node.
- Returns
The local size of the node.
- bagua.torch_api.env.get_default_bucket_size()¶
Get default communication bucket byte size.
- Returns
default bucket size
- Return type
int
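A minimal sketch of a common way to use these helpers, assuming one process per GPU and an already initialized process group (this pattern is an assumption for illustration, not mandated by the API):
>>> import torch
>>> import bagua.torch_api as bagua
>>> # Pin each process to its own GPU using the local rank.
>>> torch.cuda.set_device(bagua.get_local_rank())
>>> print(f"rank {bagua.get_rank()}/{bagua.get_world_size()}, "
...       f"local {bagua.get_local_rank()}/{bagua.get_local_size()}")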
Package Contents¶
- bagua.torch_api.init_process_group()¶
Initializes the PyTorch builtin distributed process group. This will also initialize the distributed package, and it should be executed before all other Bagua APIs.
- Raises
RepeatedInitializationError – If you run this function repeatedly
- Examples::
>>> import bagua.torch_api as bagua
>>> bagua.init_process_group()
>>> model = torch.nn.Sequential(
...     torch.nn.Linear(D_in, H),
...     torch.nn.ReLU(),
...     torch.nn.Linear(H, D_out),
... )
>>> optimizer = torch.optim.SGD(
...     model.parameters(),
...     lr=0.01,
...     momentum=0.9
... )
>>> model, optimizer = bagua_init(model, optimizer)
- bagua.torch_api.allreduce(tensor, average=True, comm=None)¶
Reduces the tensor data across all machines in such a way that all get the final result. After the call, the tensor is going to be bitwise identical in all processes.
- Parameters
tensor (torch.Tensor) – Input and output of the collective. The function operates in-place.
average (bool, optional) – Average the reduced tensor or not. Defaults to True.
comm (B.BaguaSingleCommunicatorPy, optional) – The bagua communicator to work on. If None the global bagua communicator will be used. Defaults to None.
Examples
>>> from bagua.torch_api import allreduce
>>> # All tensors below are of torch.int64 type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> allreduce(tensor)
>>> tensor
tensor([4, 6]) # Rank 0
tensor([4, 6]) # Rank 1
>>> # All tensors below are of torch.cfloat type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
>>> tensor
tensor([1.+1.j, 2.+2.j]) # Rank 0
tensor([3.+3.j, 4.+4.j]) # Rank 1
>>> allreduce(tensor)
>>> tensor
tensor([4.+4.j, 6.+6.j]) # Rank 0
tensor([4.+4.j, 6.+6.j]) # Rank 1
- bagua.torch_api.broadcast(tensor, root=0, comm=None)¶
Broadcasts the tensor to the whole communicator.
tensor must have the same number of elements in all processes participating in the collective.
- Parameters
tensor (torch.Tensor) – Data to be sent if root is the rank of current process, and tensor to be used to save received data otherwise.
root (int, optional) – Source rank. Defaults to 0.
comm (B.BaguaSingleCommunicatorPy, optional) – The bagua communicator to work on. If None, the global bagua communicator will be used. Defaults to None.
- bagua.torch_api.bagua_init(module, optimizer, distributed_algorithm=DistributedAlgorithm.GradientAllReduce, broadcast_buffers=True, delay_reduce=False, hierarchical_reduce=None, message_size=10000000, **kwargs)¶
bagua_init is a module wrapper that enables easy multiprocess distributed data parallel training using different distributed algorithms.
- Parameters
module (torch.nn.Module) – Network definition to be run in multi-gpu/distributed mode.
optimizer (torch.optim.Optimizer or list of torch.optim.Optimizer) – Optimizer(s) for the module. It can contain one or more PyTorch optimizers.
distributed_algorithm (DistributedAlgorithm) – Distributed algorithm used to average gradients or weights across all workers. Default: DistributedAlgorithm.GradientAllReduce.
broadcast_buffers (bool) – Flag that enables syncing (broadcasting) buffers of the module at the first iteration of the forward function. Default: True.
delay_reduce (bool) – Delay all communication to the end of the backward pass. This disables overlapping communication with computation. Default value is False.
hierarchical_reduce (bool) – Enable hierarchical reduce. For the GradientAllReduce algorithm, the default value is False; otherwise, the default value is True.
message_size (int) – Minimum bytes in a communication bucket. Default: 10_000_000.
- Returns
Distributed module.
Examples:
>>> model = torch.nn.Sequential(
...     torch.nn.Linear(D_in, H),
...     torch.nn.ReLU(),
...     torch.nn.Linear(H, D_out),
... )
>>> optimizer = torch.optim.SGD(
...     model.parameters(),
...     lr=0.01,
...     momentum=0.9
... )
>>> model, optimizer = bagua_init(
...     model,
...     optimizer,
...     broadcast_buffers=True
... )
- bagua.torch_api.get_rank()¶
Get the rank of the current process in the process group.
Rank is a unique identifier assigned to each process within a distributed process group. They are always consecutive integers ranging from 0 to world_size - 1.
- Returns
The rank of the current process.
- bagua.torch_api.get_world_size()¶
Get the number of processes in the current process group.
- Returns
The world size of the process group.
- bagua.torch_api.get_local_rank()¶
Get the local rank of the current process within the node.
Local rank is a unique identifier assigned to each process within a node. They are always consecutive integers ranging from 0 to local_size - 1.
- Returns
The local rank of the current process.
- bagua.torch_api.get_local_size()¶
Get the number of processes in the node.
- Returns
The local size of the node.
- class bagua.torch_api.DistributedAlgorithm¶
Bases:
enum.Enum
An enum-like class of available distributed algorithms: allreduce, sg-allreduce, quantize and decentralize.
The values of this class are lowercase strings, e.g., "allreduce". They can be accessed as attributes, e.g., DistributedAlgorithm.GradientAllReduce. This class can be directly called to parse the string, e.g., DistributedAlgorithm(algor_str).
- GradientAllReduce = allreduce¶
- ScatterGatherAllReduce = sg-allreduce¶
- Decentralize = decentralize¶
- QuantizeAllReduce = quantize¶
- static from_str(val)¶
- Parameters
val (str) –
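A short sketch of the string parsing described above; the variable algor_str is a placeholder for illustration:
>>> from bagua.torch_api import DistributedAlgorithm
>>> algor_str = "decentralize"
>>> # Calling the enum with a lowercase string returns the matching member.
>>> DistributedAlgorithm(algor_str) is DistributedAlgorithm.Decentralize
True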
Submodules¶
bagua.bagua_define¶
Module Contents¶
- class bagua.bagua_define.DistributedAlgorithm¶
Bases:
enum.Enum
An enum-like class of available distributed algorithms: allreduce, sg-allreduce, quantize and decentralize.
The values of this class are lowercase strings, e.g., "allreduce". They can be accessed as attributes, e.g., DistributedAlgorithm.GradientAllReduce. This class can be directly called to parse the string, e.g., DistributedAlgorithm(algor_str).
- GradientAllReduce = allreduce¶
- ScatterGatherAllReduce = sg-allreduce¶
- Decentralize = decentralize¶
- QuantizeAllReduce = quantize¶