class bagua.torch_api.distributed.BaguaModule

This class patches torch.nn.Module with several methods that enable Bagua functionality.

  • bagua_optimizers (List[torch.optim.Optimizer]) – The optimizers passed in by with_bagua.

  • bagua_algorithm (bagua.torch_api.algorithms.Algorithm) – The algorithm passed in by with_bagua.

  • parameters_to_ignore (List[str]) – The parameter names in "{module_name}.{param_name}" format to ignore when calling self.bagua_build_params().

  • bagua_train_step_counter (int) – Number of iterations in training mode.

  • bagua_buckets (List[bagua.torch_api.bucket.BaguaBucket]) – All Bagua buckets in a list.


bagua_build_params(self)

Build a list of (parameter_name, parameter) tuples for all parameters that require gradients and are not listed in the _bagua_params_and_buffers_to_ignore attribute.

Return type

List[Tuple[str, torch.nn.Parameter]]
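
The filtering rule described for this method can be sketched in plain Python. This is not Bagua's implementation: FakeParam and build_params are hypothetical stand-ins that only mirror the rule of keeping parameters that require gradients and whose name is not in the ignore list.

```python
from typing import Iterable, List, NamedTuple, Tuple


class FakeParam(NamedTuple):
    """Hypothetical stand-in for torch.nn.Parameter (illustration only)."""
    requires_grad: bool


def build_params(
    named_params: Iterable[Tuple[str, FakeParam]],
    ignore: Iterable[str],
) -> List[Tuple[str, FakeParam]]:
    """Keep (name, param) pairs that require grads and are not ignored."""
    ignored = set(ignore)
    return [
        (name, param)
        for name, param in named_params
        if param.requires_grad and name not in ignored
    ]


# Names follow the "{module_name}.{param_name}" format used by state_dict.
named = [
    ("0.weight", FakeParam(True)),
    ("0.bias", FakeParam(False)),   # frozen, so it is dropped
    ("2.weight", FakeParam(True)),
    ("2.bias", FakeParam(True)),    # explicitly ignored below
]
kept = build_params(named, ignore=["2.bias"])
kept_names = [name for name, _ in kept]
print(kept_names)
```

With a real module, the (name, parameter) pairs would come from the module itself rather than from a hand-written list.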

with_bagua(self, optimizers, algorithm, process_group=None)

with_bagua enables easy distributed data parallel training on a torch.nn.Module.

  • optimizers (List[torch.optim.Optimizer]) – Optimizer(s) used by the module. It can contain one or more PyTorch optimizers.

  • algorithm (bagua.torch_api.algorithms.Algorithm) – Distributed algorithm used to do the actual communication and update.

  • process_group (Optional[bagua.torch_api.communication.BaguaProcessGroup]) – The process group to be used for distributed data all-reduction. If None, the default process group, which is created by bagua.torch_api.init_process_group, will be used. (default: None)


Returns the original module, with the Bagua-related environment initialized.

Return type



To exclude some layers from communication, first look up the corresponding keys in the module's state_dict (they are in "{module_name}.{param_name}" format), then assign the list of keys to your_module._bagua_params_and_buffers_to_ignore.
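
As a minimal sketch of the key-selection step, the helper below picks the state_dict keys that belong to given submodules. keys_to_ignore is a hypothetical helper, not part of Bagua, and the key list is written by hand to match what a Sequential of Linear, ReLU, Linear reports from state_dict().

```python
from typing import Iterable, List


def keys_to_ignore(state_dict_keys: Iterable[str], module_names: Iterable[str]) -> List[str]:
    """Return the "{module_name}.{param_name}" keys whose module name is listed."""
    wanted = set(module_names)
    # rsplit separates the trailing param name from the (possibly dotted) module name.
    return [key for key in state_dict_keys if key.rsplit(".", 1)[0] in wanted]


# Keys as reported by the state_dict of Sequential(Linear, ReLU, Linear);
# the ReLU at index 1 has no parameters, so it contributes no keys.
keys = ["0.weight", "0.bias", "2.weight", "2.bias"]
ignored = keys_to_ignore(keys, ["2"])
print(ignored)

# With a real module, the result would then be assigned as described above:
# your_module._bagua_params_and_buffers_to_ignore = ignored
```

In practice the keys would come from list(your_module.state_dict().keys()), and the assignment would be made before calling with_bagua.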


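>>> import torch
>>> # GradientAllReduce is a Bagua algorithm from bagua.torch_api.algorithms;
>>> # the exact import path depends on the installed Bagua version.
>>> D_in, H, D_out = 784, 100, 10  # illustrative layer sizes
>>> model = torch.nn.Sequential(
...     torch.nn.Linear(D_in, H),
...     torch.nn.ReLU(),
...     torch.nn.Linear(H, D_out),
... )
>>> optimizer = torch.optim.SGD(
...     model.parameters(),
...     lr=0.01,
...     momentum=0.9,
... )
>>> # Wrap the module for distributed training with the chosen algorithm.
>>> model = model.with_bagua(
...     [optimizer],
...     GradientAllReduce(),
... )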