bagua.torch_api.algorithms.gradient_allreduce
Module Contents
- class bagua.torch_api.algorithms.gradient_allreduce.GradientAllReduceAlgorithm(hierarchical=False, average=True)
Bases: bagua.torch_api.algorithms.Algorithm
Create an instance of the GradientAllReduce algorithm.
- Parameters
hierarchical (bool) – Enable hierarchical communication.
average (bool) – If True, the gradients on each worker are averaged. Otherwise, they are summed.
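The two constructor flags control how gradients are combined across workers. The sketch below is a minimal pure-Python simulation of that reduction, not Bagua code. `flat_allreduce` and `hierarchical_allreduce` are hypothetical helpers, each worker's gradient is assumed to be a plain list of floats, and the hierarchical variant assumes the common pattern of intra-node reduce, inter-node allreduce, then intra-node broadcast. Bagua's internal implementation may differ.

```python
# Hypothetical simulation of allreduce semantics; it does not use Bagua's API.
# Assumption: each worker's gradient is a list of floats of equal length.
from typing import List

Grad = List[float]


def flat_allreduce(grads: List[Grad], average: bool = True) -> List[Grad]:
    """Every worker ends up with the elementwise sum (or mean) of all gradients."""
    n = len(grads)
    reduced = [sum(vals) for vals in zip(*grads)]
    if average:
        reduced = [v / n for v in reduced]
    # Each worker receives its own copy of the reduced result.
    return [list(reduced) for _ in range(n)]


def hierarchical_allreduce(nodes: List[List[Grad]], average: bool = True) -> List[List[Grad]]:
    """Assumed scheme: reduce within each node, allreduce across nodes, broadcast back."""
    total_workers = sum(len(node) for node in nodes)
    # Step 1: intra-node reduce (sum) to one partial result per node.
    node_sums = [[sum(vals) for vals in zip(*node)] for node in nodes]
    # Step 2: inter-node allreduce (sum) among the per-node partial results.
    global_sum = [sum(vals) for vals in zip(*node_sums)]
    if average:
        # Divide by the total number of workers, not the number of nodes.
        global_sum = [v / total_workers for v in global_sum]
    # Step 3: intra-node broadcast so every worker holds the same result.
    return [[list(global_sum) for _ in node] for node in nodes]


if __name__ == "__main__":
    grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
    print(flat_allreduce(grads, average=True)[0])   # [4.0, 5.0]
    print(flat_allreduce(grads, average=False)[0])  # [16.0, 20.0]
    nodes = [grads[:2], grads[2:]]                  # two nodes, two workers each
    print(hierarchical_allreduce(nodes)[1][0])      # [4.0, 5.0]
```

In this assumed scheme, both paths leave every worker with the same result (up to floating-point rounding, since the summation order differs). Hierarchical communication changes the route the data takes, with only one partial result per node crossing node boundaries, rather than the arithmetic outcome.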
- init_operations(self, bagua_module, bucket)
- Parameters
bagua_module (bagua.torch_api.distributed.BaguaModule)
bucket (bagua.torch_api.bucket.BaguaBucket)