Module Contents

class bagua.torch_api.distributed.BaguaModule

This class patches torch.nn.Module with several methods to enable Bagua functionalities.
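
Conceptually, patching means attaching extra methods to an existing class so that every instance gains them. The toy sketch below illustrates that pattern under stated assumptions: ToyModule stands in for torch.nn.Module, and BaguaModuleSketch and patch are hypothetical names, not Bagua's actual implementation. It also mirrors the documented behaviour that with_bagua returns the original module.

```python
# Toy stand-ins: ToyModule plays the role of torch.nn.Module and
# BaguaModuleSketch plays the role of BaguaModule. Neither is Bagua's real code.


class ToyModule:
    """Stand-in for torch.nn.Module."""


class BaguaModuleSketch:
    """Hypothetical mixin whose methods get attached to ToyModule."""

    def with_bagua(self, optimizers, algorithm, process_group=None, do_flatten=True):
        # Record the arguments as attributes named after the documented ones.
        self.bagua_optimizers = optimizers
        # Stored as-is here; the real attribute holds an AlgorithmImpl.
        self.bagua_algorithm = algorithm
        self.process_group = process_group
        self.bagua_train_step_counter = 0
        # Return the same object, so `m = m.with_bagua(...)` keeps the module.
        return self


def patch(target_cls, mixin_cls):
    """Copy every public function defined on mixin_cls onto target_cls."""
    for name, attr in vars(mixin_cls).items():
        if callable(attr) and not name.startswith("__"):
            setattr(target_cls, name, attr)


patch(ToyModule, BaguaModuleSketch)

m = ToyModule()
m2 = m.with_bagua(["opt"], "algo")
```

Because the patched method returns self, reassigning the result keeps the very same object, which is why the example at the end of this section can write model = model.with_bagua(...).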

Variables

  • bagua_optimizers (List[torch.optim.Optimizer]) – The optimizers passed to with_bagua.

  • bagua_algorithm (bagua.torch_api.algorithms.AlgorithmImpl) – The algorithm implementation used by the module, reified from the algorithm passed to with_bagua.

  • process_group (bagua.torch_api.communication.BaguaProcessGroup) – The process group used by the module.

  • bagua_module_name – The module’s name. Bagua uses the module name to distinguish different modules.

  • parameters_to_ignore (List[str]) – The parameter names in "{module_name}.{param_name}" format to ignore when calling self.bagua_build_params().

  • bagua_train_step_counter (int) – Number of iterations in training mode.

  • bagua_buckets (List[bagua.torch_api.bucket.BaguaBucket]) – All Bagua buckets in a list.

with_bagua(optimizers, algorithm, process_group=None, do_flatten=True)

with_bagua enables easy distributed data parallel training on a torch.nn.Module.
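
In data parallel training each process computes gradients on its own slice of the data, and the communication step must make the replicas agree before they update. The single-process simulation below shows what an averaging all-reduce produces, assuming gradients are plain Python lists; allreduce_mean is a hypothetical helper, and real Bagua algorithms perform this communication across processes (some use other schemes entirely).

```python
from typing import List

Vector = List[float]


def allreduce_mean(worker_grads: List[Vector]) -> List[Vector]:
    """Simulate an averaging all-reduce over in-memory gradient vectors.

    Every worker contributes one gradient vector of the same length; after
    the collective, every worker holds the element-wise mean.
    """
    n_workers = len(worker_grads)
    length = len(worker_grads[0])
    if any(len(g) != length for g in worker_grads):
        raise ValueError("all workers must hold gradients of the same shape")
    # Element-wise sum across workers, then divide by the number of workers.
    mean = [sum(g[i] for g in worker_grads) / n_workers for i in range(length)]
    # Each worker receives its own copy of the reduced result.
    return [list(mean) for _ in range(n_workers)]


grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
reduced = allreduce_mean(grads)
```

Since every replica ends up with identical gradients, applying the same optimizer step keeps the model copies in sync.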

Parameters

  • optimizers (List[torch.optim.Optimizer]) – Optimizer(s) used by the module. It can contain one or more PyTorch optimizers.

  • algorithm (bagua.torch_api.algorithms.Algorithm) – Distributed algorithm used to do the actual communication and update.

  • process_group (Optional[bagua.torch_api.communication.BaguaProcessGroup]) – The process group to be used for distributed data all-reduction. If None, the default process group, which is created by bagua.torch_api.init_process_group, will be used. (default: None)

  • do_flatten (bool) – Whether to flatten the Bagua buckets. Flattening resets the data pointers of the bucket tensors so that they can use faster code paths. (default: True)
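
The sketch below illustrates the idea behind flattening using only the standard library: several separately stored arrays are copied into one contiguous buffer, and each is replaced by a zero-copy view into that buffer. It is an analogy under stated assumptions (array and memoryview stand in for tensors and their storage; flatten is a hypothetical helper), not how Bagua resets tensor data pointers. One plausible benefit of the contiguous layout is that an operation can cover the whole bucket in one call instead of one call per tensor.

```python
from array import array

# Three small "tensors", each with its own separate storage.
tensors = [array("f", [1.0, 2.0]), array("f", [3.0]), array("f", [4.0, 5.0, 6.0])]


def flatten(parts):
    """Copy parts into one contiguous buffer and return (buffer, views).

    Each view is a zero-copy window into the shared buffer, so the views
    refer to contiguous memory instead of their original storage.
    """
    flat = array("f")
    for p in parts:
        flat.extend(p)
    mv = memoryview(flat)
    views, offset = [], 0
    for p in parts:
        views.append(mv[offset:offset + len(p)])
        offset += len(p)
    return flat, views


flat, views = flatten(tensors)
# A single pass over the flat buffer is visible through every view.
for i in range(len(flat)):
    flat[i] *= 2
```

After the loop, each view reflects the doubled values while the original arrays are untouched, because the views now alias the flat buffer rather than the old storage.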

Returns

The original module, with Bagua-related environments initialized.

Return type



To exclude some layers from communication, first look up their corresponding keys in the module’s state_dict (they are in "{module_name}.{param_name}" format), then assign the list of keys to your_module._bagua_params_and_buffers_to_ignore.
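
Selecting the keys can be done by filtering the state_dict key names on their module-name prefix. The helper below is hypothetical and works on plain strings so it runs without PyTorch; the sample keys are the ones torch.nn.Sequential assigns to a Linear, ReLU, Linear stack.

```python
from typing import Iterable, List


def keys_to_ignore(state_dict_keys: Iterable[str], module_names: Iterable[str]) -> List[str]:
    """Return keys of the form "{module_name}.{param_name}" whose module_name
    is in module_names (hypothetical helper, not part of Bagua)."""
    wanted = set(module_names)
    result = []
    for key in state_dict_keys:
        # Split on the last dot: everything before it is the module name.
        module_name, _, _param_name = key.rpartition(".")
        if module_name in wanted:
            result.append(key)
    return result


# Keys as torch.nn.Sequential(Linear, ReLU, Linear).state_dict() names them;
# the ReLU at index 1 has no parameters, so it contributes no keys.
sd_keys = ["0.weight", "0.bias", "2.weight", "2.bias"]
ignored = keys_to_ignore(sd_keys, ["2"])
```

With a real model, applying the same filter to model.state_dict().keys() yields the list to assign to model._bagua_params_and_buffers_to_ignore.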

Example

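>>> import torch
>>> import bagua.torch_api as bagua
>>> from bagua.torch_api.algorithms import gradient_allreduce
>>> # Run in every process started by the distributed launcher.
>>> torch.cuda.set_device(bagua.get_local_rank())
>>> bagua.init_process_group()
>>> D_in, H, D_out = 64, 100, 10  # example layer sizes
>>> model = torch.nn.Sequential(
...      torch.nn.Linear(D_in, H),
...      torch.nn.ReLU(),
...      torch.nn.Linear(H, D_out),
...    ).cuda()
>>> optimizer = torch.optim.SGD(
...      model.parameters(),
...      lr=0.01,
...      momentum=0.9
...    )
>>> model = model.with_bagua(
...      [optimizer],
...      gradient_allreduce.GradientAllReduceAlgorithm()
...    )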