bagua.torch_api.distributed
Module Contents
- class bagua.torch_api.distributed.BaguaModule
This class patches torch.nn.Module with several methods to enable Bagua functionalities.
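As a rough illustration of what "patches" means here (an assumption about the mechanism, not Bagua's actual code), one can think of the methods defined on BaguaModule being copied onto torch.nn.Module, so that every module instance gains them:

import torch
from bagua.torch_api.distributed import BaguaModule

# Hypothetical sketch of the patching idea: copy BaguaModule's methods
# onto torch.nn.Module so any module can call them, e.g. model.with_bagua(...).
for name in ("with_bagua", "bagua_build_params"):
    setattr(torch.nn.Module, name, getattr(BaguaModule, name))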
- bagua_build_params(self)
Build a list of (parameter_name, parameter) tuples for all parameters that require gradients and are not listed in the _bagua_params_and_buffers_to_ignore attribute.
- Returns
List[(str, torch.nn.Parameter)]
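A minimal sketch of the filtering this method performs (not Bagua's actual implementation): keep the parameters that require gradients and whose state_dict-style names are not in the ignore list.

import torch

# Sketch: mirror bagua_build_params' filtering over named parameters.
def build_params_sketch(module: torch.nn.Module):
    ignore = getattr(module, "_bagua_params_and_buffers_to_ignore", [])
    return [
        (name, param)
        for name, param in module.named_parameters()
        if param.requires_grad and name not in ignore
    ]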
- with_bagua(self, optimizers, algorithm)
with_bagua enables easy distributed data parallel training on a torch.nn.Module.
- Parameters
optimizers (List[torch.optim.Optimizer]) – Optimizer(s) used by the module. It can contain one or more PyTorch optimizers.
algorithm (bagua.torch_api.algorithm.Algorithm) – Distributed algorithm used to do the actual communication and update.
- Returns
The original module, with Bagua-related environments initialized.
Note
If you want to ignore some layers for communication, first check those layers' corresponding keys in the module's state_dict (they are in "{module_name}.{param_name}" format), then assign the list of keys to your_module._bagua_params_and_buffers_to_ignore.
Examples:
>>> model = torch.nn.Sequential(
...     torch.nn.Linear(D_in, H),
...     torch.nn.ReLU(),
...     torch.nn.Linear(H, D_out),
... )
>>> optimizer = torch.optim.SGD(
...     model.parameters(),
...     lr=0.01,
...     momentum=0.9
... )
>>> model = model.with_bagua(
...     [optimizer],
...     GradientAllReduce()
... )
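To use the ignore list described in the note above, one might write something like the following (a hedged sketch: the keys "2.weight" and "2.bias" assume the Sequential model from the example, where index 2 is the second Linear layer, and the attribute must be set before calling with_bagua):

>>> # Sketch (assumed usage): exclude the second Linear layer from
>>> # communication. In the Sequential model above its state_dict keys
>>> # are "2.weight" and "2.bias" (index 1 is the ReLU).
>>> model._bagua_params_and_buffers_to_ignore = ["2.weight", "2.bias"]
>>> model = model.with_bagua(
...     [optimizer],
...     GradientAllReduce()
... )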