bagua.torch_api.distributed

Module Contents

class bagua.torch_api.distributed.BaguaModule

This class patches torch.nn.Module with several methods that enable Bagua functionality.

bagua_build_params(self)

Builds a list of (parameter_name, parameter) tuples covering every parameter that requires gradients and is not listed in the _bagua_params_and_buffers_to_ignore attribute.

Returns

List[(str, torch.nn.Parameter)]
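The filtering rule described above can be illustrated with plain PyTorch. The sketch below is a hypothetical re-implementation written only to show the rule, not Bagua's actual code. The helper name build_params_sketch is invented for illustration, and it assumes parameter names come from named_parameters(), which for an ordinary module match the corresponding state_dict keys.

```python
from typing import List, Tuple

import torch


def build_params_sketch(module: torch.nn.Module) -> List[Tuple[str, torch.nn.Parameter]]:
    """Hypothetical helper mirroring the documented rule of bagua_build_params.

    Keeps (name, param) pairs for parameters that require gradients and whose
    names are not listed in the module's ignore attribute.
    """
    # Names the user asked Bagua to skip; empty if the attribute was never set.
    ignored = set(getattr(module, "_bagua_params_and_buffers_to_ignore", []))
    return [
        (name, param)
        for name, param in module.named_parameters()
        if param.requires_grad and name not in ignored
    ]


model = torch.nn.Sequential(
    torch.nn.Linear(4, 3),
    torch.nn.ReLU(),
    torch.nn.Linear(3, 2),
)
model[0].bias.requires_grad_(False)  # frozen: dropped by the requires_grad check
model._bagua_params_and_buffers_to_ignore = ["2.weight"]  # dropped by name

print([name for name, _ in build_params_sketch(model)])  # ['0.weight', '2.bias']
```

Both exclusion paths are visible here: 0.bias is skipped because it no longer requires gradients, and 2.weight is skipped because its name appears in the ignore list.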

with_bagua(self, optimizers, algorithm)

with_bagua enables easy distributed data parallel training on a torch.nn.Module.

Parameters
  • optimizers (List[torch.optim.Optimizer]) – Optimizer(s) used by the module. It can contain one or more PyTorch optimizers.

  • algorithm (bagua.torch_api.algorithm.Algorithm) – Distributed algorithm used to do the actual communication and update.

Returns

The original module, with the Bagua-related environment initialized.

Note

To exclude some layers from communication, first look up the keys of those layers in the module's state_dict (they have the form "{module_name}.{param_name}"), then assign the list of keys to your_module._bagua_params_and_buffers_to_ignore.
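The lookup-then-assign procedure in the note can be sketched with plain PyTorch as follows. The layer sizes and the choice of ignoring the last Linear layer are illustrative only, and the sketch assumes the attribute is set before with_bagua is called, since that call is what initializes the Bagua environment.

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(8, 4),
    torch.nn.ReLU(),
    torch.nn.Linear(4, 2),
)

# Step 1: inspect the state_dict keys; each has the form "{module_name}.{param_name}".
# ReLU (module name "1") has no parameters, so it contributes no keys.
keys = list(model.state_dict().keys())
print(keys)  # ['0.weight', '0.bias', '2.weight', '2.bias']

# Step 2: select the keys of the layer(s) to keep out of communication,
# here every entry belonging to the last Linear layer (module name "2").
to_ignore = [k for k in keys if k.startswith("2.")]

# Step 3: assign the list to the module (assumed to happen before with_bagua).
model._bagua_params_and_buffers_to_ignore = to_ignore
```

Selecting keys by their module-name prefix keeps the list correct even if the layer later gains more parameters or buffers.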

Examples:

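>>> import torch
>>> D_in, H, D_out = 1000, 100, 10  # illustrative layer sizes
>>> # GradientAllReduce is a Bagua algorithm; import it from Bagua's algorithm package for your version
>>> model = torch.nn.Sequential(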
...      torch.nn.Linear(D_in, H),
...      torch.nn.ReLU(),
...      torch.nn.Linear(H, D_out),
...    )
>>> optimizer = torch.optim.SGD(
...      model.parameters(),
...      lr=0.01,
...      momentum=0.9
...    )
>>> model = model.with_bagua(
...      [optimizer],
...      GradientAllReduce()
...    )