bagua.torch_api.distributed¶
Module Contents¶
- class bagua.torch_api.distributed.BaguaModule¶
This class patches torch.nn.Module with several methods to enable Bagua functionalities.
- Variables
bagua_optimizers (List[torch.optim.Optimizer]) – The optimizers passed in by with_bagua.
bagua_algorithm (bagua.torch_api.algorithms.Algorithm) – The algorithm passed in by with_bagua.
parameters_to_ignore (List[str]) – The parameter names in "{module_name}.{param_name}" format to ignore when calling self.bagua_build_params().
bagua_train_step_counter (int) – Number of iterations in training mode.
bagua_buckets (List[bagua.torch_api.bucket.BaguaBucket]) – All Bagua buckets in a list.
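A minimal sketch of inspecting these attributes, assuming with_bagua has already been called with the model and optimizer from the example further below (the values shown are illustrative; the number of buckets depends on the algorithm):
>>> model = model.with_bagua([optimizer], GradientAllReduce())
>>> model.bagua_optimizers[0] is optimizer
True
>>> model.bagua_train_step_counter  # number of iterations in training mode so far
0
>>> len(model.bagua_buckets)  # buckets are created by the algorithm during initialization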
- bagua_build_params(self)¶
Build a list of (parameter_name, parameter) tuples for all parameters that require grads and are not in the _bagua_params_and_buffers_to_ignore attribute.
- Return type
List[Tuple[str, torch.nn.Parameter]]
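A minimal sketch of the returned structure, assuming the Sequential model from the example further below (the parameter names shown are illustrative):
>>> [name for name, _ in model.bagua_build_params()]
['0.weight', '0.bias', '2.weight', '2.bias']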
- with_bagua(self, optimizers, algorithm, process_group=None)¶
with_bagua enables easy distributed data parallel training on a torch.nn.Module.
- Parameters
optimizers (List[torch.optim.Optimizer]) – Optimizer(s) used by the module. It can contain one or more PyTorch optimizers.
algorithm (bagua.torch_api.algorithms.Algorithm) – Distributed algorithm used to do the actual communication and update.
process_group (Optional[bagua.torch_api.communication.BaguaProcessGroup]) – The process group to be used for distributed data all-reduction. If None, the default process group, which is created by bagua.torch_api.init_process_group, will be used. (default: None)
- Returns
The original module, with Bagua related environments initialized.
- Return type
BaguaModule
Note
If we want to ignore some layers for communication, we can first check these layers' corresponding keys in the module's state_dict (they are in "{module_name}.{param_name}" format), then assign the list of keys to your_module._bagua_params_and_buffers_to_ignore, as sketched below.
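A minimal sketch of that workflow, reusing the Sequential model and optimizer from the example below (the keys and the choice to ignore the first layer are illustrative; the assignment is made before calling with_bagua so the ignored parameters are excluded when Bagua builds its parameter list):
>>> keys = list(model.state_dict().keys())  # e.g. ['0.weight', '0.bias', '2.weight', '2.bias']
>>> model._bagua_params_and_buffers_to_ignore = keys[:2]  # ignore the first Linear layer
>>> model = model.with_bagua([optimizer], GradientAllReduce())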
Examples:
>>> model = torch.nn.Sequential(
...     torch.nn.Linear(D_in, H),
...     torch.nn.ReLU(),
...     torch.nn.Linear(H, D_out),
... )
>>> optimizer = torch.optim.SGD(
...     model.parameters(),
...     lr=0.01,
...     momentum=0.9
... )
>>> model = model.with_bagua(
...     [optimizer],
...     GradientAllReduce()
... )
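A fuller setup sketch in the same vein, hedged: it assumes bagua.torch_api exposes get_local_rank and init_process_group as in the Bagua tutorials, that GradientAllReduce is importable from bagua.torch_api.algorithms (the import path may differ across versions), and that the script is launched with one process per GPU:
>>> import torch
>>> import bagua.torch_api as bagua
>>> from bagua.torch_api.algorithms import GradientAllReduce
>>> torch.cuda.set_device(bagua.get_local_rank())  # bind this process to its GPU
>>> bagua.init_process_group()  # creates the default process group
>>> model = torch.nn.Linear(D_in, D_out).cuda()
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
>>> model = model.with_bagua([optimizer], GradientAllReduce())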