bagua.torch_api.distributed

Module Contents

class bagua.torch_api.distributed.BaguaModule

This class patches torch.nn.Module with several methods to enable Bagua functionalities.

Variables:
  • bagua_optimizers (List[torch.optim.Optimizer]) – The optimizers passed in by with_bagua.

  • bagua_algorithm (bagua.torch_api.algorithms.AlgorithmImpl) – The algorithm implementation used by the module, reified by the algorithm passed in by with_bagua.

  • process_group (bagua.torch_api.communication.BaguaProcessGroup) – The process group used by the module.

  • bagua_module_name – The module’s name. Bagua uses the module name to distinguish different modules.

  • parameters_to_ignore (List[str]) – The parameter names in "{module_name}.{param_name}" format to ignore when calling self.bagua_build_params().

  • bagua_train_step_counter (int) – Number of iterations in training mode.

  • bagua_buckets (List[bagua.torch_api.bucket.BaguaBucket]) – All Bagua buckets in a list.
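The bagua_buckets attribute holds the module's parameters grouped into buckets. This section does not describe the grouping policy Bagua uses. The sketch below assumes a simple greedy rule, purely to show what a list of buckets could look like: parameters are taken in order, and a new bucket starts once a byte cap would be exceeded. The function `greedy_buckets` and its `cap_bytes` parameter are hypothetical, not part of Bagua.

```python
from typing import List, Tuple

# (name, size in bytes) pairs, named in "{module_name}.{param_name}" style.
Param = Tuple[str, int]


def greedy_buckets(params: List[Param], cap_bytes: int) -> List[List[str]]:
    """Hypothetical bucketing rule: append parameters in order and start a
    new bucket when the next one would push the current bucket past
    cap_bytes. A parameter larger than the cap gets a bucket of its own."""
    buckets: List[List[str]] = []
    current: List[str] = []
    current_size = 0
    for name, size in params:
        if current and current_size + size > cap_bytes:
            buckets.append(current)
            current, current_size = [], 0
        current.append(name)
        current_size += size
    if current:
        buckets.append(current)
    return buckets


params = [("0.weight", 400), ("0.bias", 40), ("2.weight", 800), ("2.bias", 20)]
print(greedy_buckets(params, cap_bytes=500))
# [['0.weight', '0.bias'], ['2.weight'], ['2.bias']]
```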

with_bagua(optimizers, algorithm, process_group=None, do_flatten=True)

with_bagua enables easy distributed data parallel training on a torch.nn.Module.
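In data-parallel training, each worker holds its own replica of the model. After every backward pass, the workers' gradients are combined so that all replicas apply the same update. The single-process sketch below simulates the averaging step that a gradient all-reduce algorithm performs. It uses plain Python lists in place of real communication, and `allreduce_mean` is a hypothetical name, not a Bagua function.

```python
from typing import List


def allreduce_mean(worker_grads: List[List[float]]) -> List[List[float]]:
    """Simulate an all-reduce with averaging: every worker ends up holding
    the element-wise mean of all workers' gradients."""
    n_workers = len(worker_grads)
    summed = [sum(vals) for vals in zip(*worker_grads)]
    mean = [s / n_workers for s in summed]
    # Each worker receives its own copy of the same averaged gradient.
    return [list(mean) for _ in range(n_workers)]


# Two simulated workers, each holding a gradient of length 3.
grads = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]
print(allreduce_mean(grads))  # [[2.0, 3.0, 4.0], [2.0, 3.0, 4.0]]
```

The replicas start from the same weights and apply the same averaged gradient, so they stay synchronized from one step to the next.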

Parameters:
  • optimizers (List[torch.optim.Optimizer]) – Optimizer(s) used by the module. It can contain one or more PyTorch optimizers.

  • algorithm (bagua.torch_api.algorithms.Algorithm) – Distributed algorithm used to do the actual communication and update.

  • process_group (Optional[bagua.torch_api.communication.BaguaProcessGroup]) – The process group to be used for distributed data all-reduction. If None, the default process group, which is created by bagua.torch_api.init_process_group, will be used. (default: None)

  • do_flatten (bool) – Whether to flatten the Bagua buckets. The flatten operation resets the data pointers of the bucket tensors so that they can use faster code paths. Default: True.
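Plain PyTorch can illustrate the idea behind the flatten operation. The sketch below copies a bucket's tensors into one contiguous buffer, then re-points each tensor at its own slice of that buffer. After that, the tensors' data pointers lie inside the buffer, and any operation on the buffer touches every tensor at once. This is an illustration under that assumption, not Bagua's implementation, and `flatten_bucket` is a hypothetical helper.

```python
import torch


def flatten_bucket(tensors):
    """Copy tensors into one contiguous 1-D buffer, then make each tensor a
    view into that buffer, so its data pointer now lies inside the buffer."""
    flat = torch.cat([t.detach().reshape(-1) for t in tensors])
    offset = 0
    for t in tensors:
        n = t.numel()
        # Swap the tensor's storage for a same-shaped view of the flat buffer.
        t.data = flat[offset:offset + n].view_as(t)
        offset += n
    return flat


a = torch.ones(2, 3)
b = torch.full((4,), 2.0)
flat = flatten_bucket([a, b])

# An in-place op on the flat buffer is visible through the original tensors.
flat.mul_(10)
print(a[0, 0].item(), b[0].item())  # 10.0 20.0
```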

Returns:

The original module, with Bagua-related environments initialized.

Return type:

BaguaModule

Note

To exclude some layers from communication, first look up their keys in the module's state_dict (the keys are in "{module_name}.{param_name}" format), then assign the list of those keys to your_module._bagua_params_and_buffers_to_ignore.
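For a torch.nn.Sequential model, the state_dict keys take the form "{index}.{param_name}". The sketch below lists those keys, selects the ones that belong to the first Linear layer, and assigns them to _bagua_params_and_buffers_to_ignore, following the note above. Only PyTorch is needed to run it, and the layer sizes are arbitrary. Two things are assumptions here: that the list should be set before with_bagua is called, and that the `communicated` list, computed locally for illustration, matches what Bagua would actually communicate.

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 4),
)

# Sequential names its children "0", "1", "2"; ReLU has no parameters.
keys = list(model.state_dict().keys())
print(keys)  # ['0.weight', '0.bias', '2.weight', '2.bias']

# Exclude the first Linear layer (child "0") from communication.
ignored = [k for k in keys if k.startswith("0.")]
# Assumption: set this before calling model.with_bagua(...).
model._bagua_params_and_buffers_to_ignore = ignored

# Keys that would remain for communication (computed here for illustration).
communicated = [k for k in keys if k not in model._bagua_params_and_buffers_to_ignore]
print(communicated)  # ['2.weight', '2.bias']
```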

Examples:

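>>> import torch
>>> import bagua.torch_api as bagua
>>> from bagua.torch_api.algorithms import gradient_allreduce
>>> torch.cuda.set_device(bagua.get_local_rank())
>>> bagua.init_process_group()
>>> D_in, H, D_out = 64, 128, 10  # example layer sizes
>>> model = torch.nn.Sequential(
...      torch.nn.Linear(D_in, H),
...      torch.nn.ReLU(),
...      torch.nn.Linear(H, D_out),
...    ).cuda()
>>> optimizer = torch.optim.SGD(
...      model.parameters(),
...      lr=0.01,
...      momentum=0.9
...    )
>>> model = model.with_bagua(
...      [optimizer],
...      gradient_allreduce.GradientAllReduceAlgorithm()
...    )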