- bagua.torch_api.data_parallel.distributed.DistributedDataParallel(module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=True, optimizers=, algorithm=GradientAllReduceAlgorithm())
This function provides a PyTorch DDP-compatible interface plus several Bagua-specific parameters.
module (Module) – module to be parallelized
device_ids (Optional[List[Union[int, torch.device]]], optional) –
1) For single-device modules, device_ids can contain exactly one device id, which represents the only CUDA device where the input module corresponding to this process resides. Alternatively, device_ids can also be None.
2) For multi-device modules and CPU modules, device_ids must be None.
When device_ids is None for both cases, both the input data for the forward pass and the actual module must be placed on the correct device. (default: None)
output_device (Union[int, torch.device], optional) – Device location of output for single-device CUDA modules. For multi-device modules and CPU modules, it must be None, and the module itself dictates the output location. (default: the device in device_ids, for single-device modules)
dim (int, optional) – Kept for compatibility with the PyTorch DDP interface, where it is the dimension along which inputs are scattered across devices. (default: 0)
broadcast_buffers (bool, optional) – Flag that enables syncing (broadcasting) buffers of the module at the beginning of the forward function. (default: True)
process_group (Union[None, TorchProcessGroup], optional) – The process group to be used for distributed data all-reduction. If None, the default process group, which is created by torch.distributed.init_process_group, will be used. (default: None)
bucket_cap_mb (int, optional) – DistributedDataParallel will bucket parameters into multiple buckets so that gradient reduction of each bucket can potentially overlap with backward computation. bucket_cap_mb controls the bucket size in megabytes (MB). (default: 25)
find_unused_parameters (bool, optional) – Traverse the autograd graph from all tensors contained in the return value of the wrapped module's forward function. Parameters that don't receive gradients as part of this graph are preemptively marked as being ready to be reduced. In addition, parameters that may have been used in the wrapped module's forward function but were not part of loss computation, and thus would also not receive gradients, are preemptively marked as ready to be reduced. (default: False)
check_reduction (bool, optional) – This argument is deprecated.
gradient_as_bucket_view (bool, optional) – When set to True, gradients will be views pointing to different offsets of allreduce communication buckets. This can reduce peak memory usage, with the saved memory equal to the total size of the gradients. Moreover, it avoids the overhead of copying between gradients and allreduce communication buckets. When gradients are views, detach_() cannot be called on them. If you hit such errors, refer to torch/optim/optimizer.py for a solution. (default: True)
optimizers (List[torch.optim.Optimizer], optional) – Optimizer(s) used by the module. It can contain one or more PyTorch optimizers.
algorithm (bagua.torch_api.algorithms.Algorithm, optional) – Data-parallel distributed algorithm, which decides the communication mode and how the model is updated. Defaults to GradientAllReduceAlgorithm().
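To make bucket_cap_mb and gradient_as_bucket_view more concrete, here is a simplified, pure-Python sketch that greedily packs parameters (given as float32 element counts) into buckets no larger than the cap, then exposes each parameter's gradient as a zero-copy view into its bucket's flat buffer. The helpers assign_buckets and make_grad_views are hypothetical and are not part of Bagua or PyTorch; the real bucketing order and rules in either library may differ.

```python
from array import array
from typing import Dict, List, Tuple

def assign_buckets(param_sizes: List[int], bucket_cap_mb: int,
                   bytes_per_elem: int = 4) -> List[List[int]]:
    """Hypothetical helper: greedily pack parameter indices, in order, into
    buckets of at most bucket_cap_mb megabytes. A parameter larger than the
    cap gets a bucket of its own."""
    cap_bytes = bucket_cap_mb * 1024 * 1024
    buckets: List[List[int]] = []
    current: List[int] = []
    current_bytes = 0
    for idx, numel in enumerate(param_sizes):
        nbytes = numel * bytes_per_elem
        # Close the current bucket if this parameter would push it over the cap
        if current and current_bytes + nbytes > cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(idx)
        current_bytes += nbytes
    if current:
        buckets.append(current)
    return buckets

def make_grad_views(param_sizes: List[int],
                    bucket: List[int]) -> Tuple[array, Dict[int, memoryview]]:
    """Hypothetical helper: allocate one flat float buffer for a bucket and
    return per-parameter memoryview slices into it. Writing to a view writes
    straight into the buffer, with no copy."""
    flat = array("f", [0.0] * sum(param_sizes[i] for i in bucket))
    mv = memoryview(flat)
    views: Dict[int, memoryview] = {}
    offset = 0
    for i in bucket:
        views[i] = mv[offset:offset + param_sizes[i]]
        offset += param_sizes[i]
    return flat, views

# Usage: element counts of four float32 parameters
sizes = [1_000_000, 3_000_000, 4_000_000, 500_000]
print(assign_buckets(sizes, bucket_cap_mb=25))  # → [[0, 1], [2, 3]]

flat, views = make_grad_views([2, 3], [0, 1])
views[1][0] = 5.0  # "gradient" of parameter 1 written in place
print(flat[2])     # → 5.0
```

With the default 25 MB cap, the first two parameters (4,000,000 and 12,000,000 bytes) share a bucket, while adding the 16,000,000-byte parameter would overflow it, so it starts a new one. Writing through a view updates the flat buffer directly, which is the gradient-to-bucket copy that gradient_as_bucket_view=True is meant to avoid.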
Bagua distributed data parallel instance used for distributed training.
- Return type
Example:
>>> import bagua.torch_api as bagua
>>> bagua.init_process_group()
>>> net = bagua.data_parallel.DistributedDataParallel(model)
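As a rough mental model of the default GradientAllReduceAlgorithm(), the sketch below simulates, inside a single process, an all-reduce that averages the workers' gradients element-wise and then applies the same SGD step on every replica. It assumes averaging semantics as in PyTorch DDP; allreduce_mean and sgd_step are hypothetical names, and Bagua's real communication runs across processes and GPUs rather than over Python lists.

```python
from typing import List

def allreduce_mean(worker_grads: List[List[float]]) -> List[List[float]]:
    """Simulate an averaging all-reduce: every worker ends up holding the
    element-wise mean of all workers' gradients."""
    n = len(worker_grads)
    length = len(worker_grads[0])
    if any(len(g) != length for g in worker_grads):
        raise ValueError("all workers must hold gradients of the same length")
    mean = [sum(g[i] for g in worker_grads) / n for i in range(length)]
    # Each worker receives its own copy of the same averaged result
    return [list(mean) for _ in range(n)]

def sgd_step(params: List[float], grads: List[float], lr: float) -> List[float]:
    """Plain SGD update applied locally on one replica."""
    return [p - lr * g for p, g in zip(params, grads)]

# Usage: two workers start from identical parameters but compute different gradients
params = [[0.0, 0.0], [0.0, 0.0]]
local_grads = [[1.0, 2.0], [3.0, 4.0]]
synced = allreduce_mean(local_grads)
params = [sgd_step(p, g, lr=0.5) for p, g in zip(params, synced)]
print(params)  # → [[-1.0, -1.5], [-1.0, -1.5]]
```

Because every replica applies the same averaged gradient, replicas that start identical remain identical after each update.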
Example using faster algorithms in Bagua:
>>> from bagua.torch_api.algorithms import bytegrad
>>> bagua.init_process_group()
>>> net = bagua.data_parallel.DistributedDataParallel(model, algorithm=bytegrad.ByteGradAlgorithm())
>>> # For more possible algorithms, see https://tutorials.baguasys.com/algorithms/.
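ByteGrad is a gradient compression algorithm that communicates 8-bit quantized gradients instead of full-precision ones. The sketch below shows one plausible scheme, per-tensor min-max quantization to integers in [0, 255], so each float32 value (4 bytes) travels as 1 byte. The exact scheme and the helpers quantize_uint8 and dequantize_uint8 are assumptions for illustration, not necessarily how ByteGradAlgorithm quantizes.

```python
from typing import List, Tuple

def quantize_uint8(values: List[float]) -> Tuple[bytes, float, float]:
    """Min-max quantize floats to one byte each; returns (codes, offset, scale)."""
    lo, hi = min(values), max(values)
    scale = (hi - lo) / 255 if hi > lo else 1.0
    # Map [lo, hi] onto [0, 255], clamping to guard against rounding drift
    codes = bytes(min(255, max(0, round((v - lo) / scale))) for v in values)
    return codes, lo, scale

def dequantize_uint8(codes: bytes, lo: float, scale: float) -> List[float]:
    """Approximately reconstruct the original floats from their byte codes."""
    return [lo + c * scale for c in codes]

# Usage
grads = [-1.0, 0.25, 0.5, 1.0]
codes, lo, scale = quantize_uint8(grads)
restored = dequantize_uint8(codes, lo, scale)
print(len(codes), "bytes instead of", 4 * len(grads))  # → 4 bytes instead of 16
max_err = max(abs(r - g) for r, g in zip(restored, grads))
print(f"max error {max_err:.5f}, half step {scale / 2:.5f}")
```

Each reconstructed value lies within about half a quantization step (scale / 2) of the original; that lost precision is the price of a 4x smaller payload.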
Convert a PyTorch process group to a Bagua process group.
process_group (Union[TorchProcessGroup, BaguaProcessGroup, None], optional) – PyTorch process group or Bagua process group. The default PyTorch process group is used if None is passed in.
Exception – Raised if the input is not a PyTorch process group, a Bagua process group, or None.
Process group for communication in Bagua.
- Return type
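The documented contract of this conversion (accept a PyTorch process group, a Bagua process group, or None, and raise on anything else) can be sketched with stand-in classes. TorchGroup, BaguaGroup, DEFAULT_TORCH_GROUP and to_bagua_group_sketch are all hypothetical, and returning a Bagua group unchanged is an assumption rather than confirmed behavior. The sketch raises TypeError, which is a subclass of Exception.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union

@dataclass(frozen=True)
class TorchGroup:
    """Hypothetical stand-in for a PyTorch process group."""
    ranks: Tuple[int, ...]

@dataclass(frozen=True)
class BaguaGroup:
    """Hypothetical stand-in for a Bagua process group."""
    ranks: Tuple[int, ...]

# Hypothetical default group, playing the role of the one created at init time
DEFAULT_TORCH_GROUP = TorchGroup(ranks=(0, 1, 2, 3))

def to_bagua_group_sketch(
    process_group: Optional[Union[TorchGroup, BaguaGroup]] = None,
) -> BaguaGroup:
    """Sketch of the documented contract; not Bagua's implementation."""
    if process_group is None:
        # None selects the default PyTorch process group
        process_group = DEFAULT_TORCH_GROUP
    if isinstance(process_group, BaguaGroup):
        # Assumption: a Bagua group is already usable and is returned unchanged
        return process_group
    if isinstance(process_group, TorchGroup):
        # Build a Bagua group over the same ranks
        return BaguaGroup(ranks=process_group.ranks)
    raise TypeError(f"unexpected process group type: {type(process_group).__name__}")

print(to_bagua_group_sketch())  # → BaguaGroup(ranks=(0, 1, 2, 3))
```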