Module Contents

bagua.torch_api.data_parallel.distributed.DistributedDataParallel(module, device_ids=None, output_device=None, dim=0, broadcast_buffers=True, process_group=None, bucket_cap_mb=25, find_unused_parameters=False, check_reduction=False, gradient_as_bucket_view=True, optimizers=[], algorithm=GradientAllReduceAlgorithm())

This function provides a PyTorch DDP-compatible interface plus several Bagua-specific parameters.

Parameters

  • module (Module) – module to be parallelized

  • device_ids (Optional[List[Union[int, torch.device]]], optional) –

    CUDA devices.

    1) For single-device modules, device_ids can contain exactly one device id, which represents the only CUDA device where the input module corresponding to this process resides. Alternatively, device_ids can also be None.

    2) For multi-device modules and CPU modules, device_ids must be None.

    In both cases, when device_ids is None, both the input data for the forward pass and the actual module must be placed on the correct device. (default: None)

  • output_device (Union[int, torch.device], optional) – Device location of output for single-device CUDA modules. For multi-device modules and CPU modules, it must be None, and the module itself dictates the output location. (default: device_ids[0] for single-device modules)

  • dim (int, optional) – Dimension along which input tensors are scattered to the device(s) in the forward pass. (default: 0)

  • broadcast_buffers (bool, optional) – Flag that enables syncing (broadcasting) buffers of the module at beginning of the forward function. (default: True)

  • process_group (Union[None, TorchProcessGroup], optional) – The process group to be used for distributed data all-reduction. If None, the default process group, which is created by torch.distributed.init_process_group, will be used. (default: None)

  • bucket_cap_mb (int, optional) – DistributedDataParallel will bucket parameters into multiple buckets so that gradient reduction of each bucket can potentially overlap with backward computation. bucket_cap_mb controls the bucket size in megabytes (MB). (default: 25)

  • find_unused_parameters (bool, optional) – Traverse the autograd graph from all tensors contained in the return value of the wrapped module’s forward function. Parameters that don’t receive gradients as part of this graph are preemptively marked as being ready to be reduced. In addition, parameters that may have been used in the wrapped module’s forward function but were not part of loss computation and thus would also not receive gradients are preemptively marked as ready to be reduced. (default: False)

  • check_reduction (bool, optional) – This argument is deprecated.

  • gradient_as_bucket_view (bool, optional) – When set to True, gradients will be views pointing to different offsets of the allreduce communication buckets. This can reduce peak memory usage, with the saved memory equal to the total size of the gradients. It also avoids the overhead of copying between gradients and allreduce communication buckets. Because gradients are views, detach_() cannot be called on them. If you encounter such errors, refer to the zero_grad function in torch/optim/ for a fix.

  • optimizers (List[torch.optim.Optimizer], optional) – Optimizer(s) used by the module. It can contain one or more PyTorch optimizers. Defaults to [].

  • algorithm (bagua.torch_api.algorithms.Algorithm, optional) – Data parallel distributed algorithm. It determines the communication mode and how the model is updated. Defaults to GradientAllReduceAlgorithm.
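The device_ids and output_device rules above are easier to follow when written as a check. The sketch below is a minimal illustration under stated assumptions. `check_device_args` is a hypothetical helper, not part of Bagua or PyTorch. An int stands for a CUDA device index and the string "cpu" stands for the CPU.

```python
from typing import List, Optional, Sequence, Union

# Hypothetical device representation: an int is a CUDA device index, "cpu" is the CPU.
Device = Union[int, str]


def check_device_args(
    module_devices: Sequence[Device],
    device_ids: Optional[List[Device]] = None,
    output_device: Optional[Device] = None,
) -> Optional[Device]:
    """Hypothetical validator for the documented device_ids/output_device rules.

    module_devices lists the devices holding the module's parameters.
    Returns the effective output device, or None when the module itself
    (or the caller) decides where outputs live.
    """
    is_cpu_module = all(d == "cpu" for d in module_devices)
    is_multi_device = len(set(module_devices)) > 1

    if is_cpu_module or is_multi_device:
        # Rule 2): multi-device and CPU modules require device_ids=None,
        # and output_device must be None as well.
        if device_ids is not None or output_device is not None:
            raise ValueError("device_ids and output_device must be None "
                             "for multi-device or CPU modules")
        return None

    if device_ids is None:
        # Rule 1), second option: the caller must already have placed the
        # module and its inputs on the correct device.
        return output_device

    # Rule 1), first option: exactly one device id for a single-device module.
    if len(device_ids) != 1:
        raise ValueError("device_ids must contain exactly one device id "
                         "for a single-device module")

    # Documented default: output_device falls back to device_ids[0].
    return device_ids[0] if output_device is None else output_device


print(check_device_args([0], device_ids=[0]))  # 0
```

The real wrapper performs its own validation internally. This sketch only encodes the rules as the parameter list states them.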
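To show what bucket_cap_mb controls, here is a simplified sketch of greedy bucketing. `assign_buckets` and the parameter names and sizes are hypothetical, and this is not Bagua's actual bucketing code. The sketch assumes parameters are visited in reverse registration order, on the heuristic that gradients of later layers become ready first during backward. It also assumes 1 MB = 1024 × 1024 bytes. A parameter larger than the cap ends up in a bucket of its own.

```python
from typing import Dict, List

MB = 1024 * 1024  # assumption: 1 MB = 2**20 bytes


def assign_buckets(param_sizes_bytes: Dict[str, int], bucket_cap_mb: float = 25) -> List[List[str]]:
    """Hypothetical greedy bucketing: close a bucket once adding the next
    parameter would exceed bucket_cap_mb megabytes."""
    cap_bytes = bucket_cap_mb * MB
    buckets: List[List[str]] = []
    current: List[str] = []
    current_bytes = 0
    # Visit parameters in reverse registration order (dicts keep insertion order).
    for name in reversed(list(param_sizes_bytes)):
        size = param_sizes_bytes[name]
        if current and current_bytes + size > cap_bytes:
            buckets.append(current)  # current bucket is full; start a new one
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += size
    if current:
        buckets.append(current)
    return buckets


# Hypothetical gradient sizes for a small model.
sizes = {"embed.weight": 20 * MB, "fc1.weight": 10 * MB, "fc1.bias": 1 * MB, "fc2.weight": 12 * MB}
print(assign_buckets(sizes, bucket_cap_mb=25))
# [['fc2.weight', 'fc1.bias', 'fc1.weight'], ['embed.weight']]
```

Smaller buckets let reduction start earlier and overlap more with backward, but each bucket adds per-call communication overhead. That trade-off is what the bucket_cap_mb knob tunes.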
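The find_unused_parameters description amounts to a reachability search from the forward outputs. The toy sketch below uses a plain dict in place of a real autograd graph. The graph, its node names, and `find_unused` are hypothetical illustrations, not Bagua or PyTorch internals. Any parameter the search never reaches would be the kind of parameter marked ready for reduction ahead of time.

```python
from collections import deque
from typing import Dict, Iterable, List, Set


def find_unused(graph: Dict[str, List[str]], outputs: Iterable[str], params: Iterable[str]) -> Set[str]:
    """Return parameters not reachable from any output.

    graph maps each node to the nodes it was computed from (its inputs),
    mimicking a backward traversal through an autograd graph.
    """
    seen: Set[str] = set()
    queue = deque(outputs)
    while queue:  # breadth-first walk from the outputs toward the leaves
        node = queue.popleft()
        if node in seen:
            continue
        seen.add(node)
        queue.extend(graph.get(node, []))
    return set(params) - seen


# Hypothetical two-head model where only head_a contributes to the loss.
graph = {
    "loss": ["logits_a"],
    "logits_a": ["hidden", "head_a.weight"],
    "logits_b": ["hidden", "head_b.weight"],
    "hidden": ["input", "encoder.weight"],
}
params = ["encoder.weight", "head_a.weight", "head_b.weight"]
print(find_unused(graph, ["loss"], params))  # {'head_b.weight'}
```

Without this search, the reducer would wait for a gradient on head_b.weight that never arrives. The traversal costs extra work on every iteration, which is why the flag defaults to False.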
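The gradient_as_bucket_view description above can be illustrated with plain Python buffers. This is a minimal sketch, assuming a single flat float32 bucket. Python's `array` and `memoryview` stand in for the framework's tensors and tensor views, and the gradient names and sizes are hypothetical. Each gradient is a view at its own offset inside the bucket, so writing a gradient writes the bucket directly and no second copy of the gradients exists.

```python
from array import array

# Hypothetical gradient sizes, in float32 elements.
grad_numels = {"w1": 4, "b1": 2, "w2": 3}

# One flat communication bucket holding all gradients back to back.
bucket = array("f", [0.0] * sum(grad_numels.values()))
bucket_view = memoryview(bucket)

# Each gradient is a view into the bucket at its own offset (no separate storage).
grads = {}
offset = 0
for name, numel in grad_numels.items():
    grads[name] = bucket_view[offset:offset + numel]
    offset += numel

# Writing a gradient writes straight into the bucket, so no copy is needed before allreduce.
grads["b1"][0] = 1.5
grads["w2"][2] = -2.0
print(bucket.tolist())  # [0.0, 0.0, 0.0, 0.0, 1.5, 0.0, 0.0, 0.0, -2.0]

# Memory saved versus separately allocated gradients: the total gradient size.
saved_bytes = bucket.itemsize * len(bucket)
print(saved_bytes)  # 36
```

The same aliasing explains the detach_() restriction. A view cannot be detached from the buffer it points into without breaking that shared storage.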

Returns

Bagua distributed data parallel instance used for distributed training.

Return type

Union[TorchDistributedDataParallel, DistributedDataParallel_V1_9_0]

Example:

>>> bagua.init_process_group()
>>> net = bagua.data_parallel.DistributedDataParallel(model)

Example using faster algorithms in Bagua:

>>> from bagua.torch_api.algorithms import bytegrad
>>> bagua.init_process_group()
>>> net = bagua.data_parallel.DistributedDataParallel(model, algorithm=bytegrad.ByteGradAlgorithm())
>>> # For more possible algorithms, see

Convert a PyTorch process group to a Bagua process group.
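The conversion behavior described by the parameter and exception entries can be sketched as a simple dispatch. The classes below are minimal stand-ins named after the documented types, not the real PyTorch or Bagua classes. `convert_process_group` and `DEFAULT_TORCH_GROUP` are hypothetical names, and a real conversion does much more than copy a rank list.

```python
from typing import Iterable, Optional, Union


class TorchProcessGroup:
    """Hypothetical stand-in for a PyTorch process group."""
    def __init__(self, ranks: Iterable[int]):
        self.ranks = list(ranks)


class BaguaProcessGroup:
    """Hypothetical stand-in for a Bagua process group."""
    def __init__(self, ranks: Iterable[int]):
        self.ranks = list(ranks)


# Hypothetical default group, playing the role of the one created by init_process_group.
DEFAULT_TORCH_GROUP = TorchProcessGroup(range(4))


def convert_process_group(
    process_group: Optional[Union[TorchProcessGroup, BaguaProcessGroup]] = None,
) -> BaguaProcessGroup:
    if process_group is None:
        # None selects the default PyTorch process group.
        process_group = DEFAULT_TORCH_GROUP
    if isinstance(process_group, BaguaProcessGroup):
        # Already a Bagua group: return it unchanged.
        return process_group
    if isinstance(process_group, TorchProcessGroup):
        # Build a Bagua group covering the same ranks.
        return BaguaProcessGroup(process_group.ranks)
    # Any other input is rejected, matching the documented Exception.
    raise Exception(f"unexpected input {process_group!r}: expected "
                    "TorchProcessGroup, BaguaProcessGroup or None")


print(convert_process_group().ranks)  # [0, 1, 2, 3]
```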

Parameters

process_group (Union[TorchProcessGroup, BaguaProcessGroup, None], optional) – PyTorch process group or Bagua process group. The default PyTorch process group is used if None is passed in.

Raises

Exception – Raised on unexpected input, i.e., if the input is not a TorchProcessGroup, BaguaProcessGroup or None.

Returns

Process group for communication in Bagua.

Return type