bagua.torch_api.algorithms.gradient_allreduce
Module Contents
- class bagua.torch_api.algorithms.gradient_allreduce.GradientAllReduceAlgorithm(hierarchical=False, average=True)
Bases: bagua.torch_api.algorithms.base.Algorithm
Create an instance of the GradientAllReduce algorithm.
- Parameters
hierarchical (bool) – Enable hierarchical communication.
average (bool) – If True, the gradients on each worker are averaged. Otherwise, they are summed.
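The effect of the `average` parameter can be illustrated with a single-process simulation of an allreduce. This is a minimal sketch, not Bagua's implementation: the function `allreduce_gradients` and the plain-list gradient representation are hypothetical, chosen only to show that every worker ends up holding the same gradient, either the mean (`average=True`) or the sum (`average=False`) of all workers' gradients.

```python
from typing import List

# Hypothetical representation: one flat gradient vector per worker.
Gradient = List[float]


def allreduce_gradients(worker_grads: List[Gradient], average: bool = True) -> List[Gradient]:
    """Simulate an allreduce over per-worker gradients in a single process.

    Every worker receives the element-wise sum of all workers' gradients,
    divided by the number of workers when ``average`` is True.
    """
    num_workers = len(worker_grads)
    if num_workers == 0:
        raise ValueError("need at least one worker")
    length = len(worker_grads[0])
    if any(len(g) != length for g in worker_grads):
        raise ValueError("all workers must hold gradients of the same shape")

    # "Reduce" half: element-wise sum across workers.
    reduced = [sum(g[i] for g in worker_grads) for i in range(length)]
    if average:
        reduced = [x / num_workers for x in reduced]

    # "All" half: every worker gets its own copy of the reduced gradient.
    return [list(reduced) for _ in range(num_workers)]


grads = [[1.0, 2.0], [3.0, 4.0]]
print(allreduce_gradients(grads, average=True))   # [[2.0, 3.0], [2.0, 3.0]]
print(allreduce_gradients(grads, average=False))  # [[4.0, 6.0], [4.0, 6.0]]
```

Averaging keeps the size of the gradient independent of the number of workers, whereas summing scales it with the worker count, so a learning rate tuned for one mode may need adjusting for the other.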
- class bagua.torch_api.algorithms.gradient_allreduce.GradientAllReduceAlgorithmImpl(process_group, hierarchical=False, average=True)
Bases: bagua.torch_api.algorithms.base.AlgorithmImpl
Implementation of the GradientAllReduce algorithm.
- Parameters
process_group (BaguaProcessGroup) – The process group to work on.
hierarchical (bool) – Enable hierarchical communication.
average (bool) – If True, the gradients on each worker are averaged. Otherwise, they are summed.
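The `hierarchical` parameter is documented here only as enabling hierarchical communication. A common two-level scheme first reduces gradients within each node, then runs an allreduce among one leader per node, and finally broadcasts the result back within each node. The sketch below assumes that scheme; this page does not confirm Bagua's exact steps, and the names `hierarchical_allreduce` and `grads_by_node` are hypothetical. It simulates the three steps in one process to show that only the communication pattern changes, not the result (up to floating-point rounding from a different summation order).

```python
from typing import Dict, List

# Hypothetical representation: one flat gradient vector per worker.
Gradient = List[float]


def _sum(vectors: List[Gradient]) -> Gradient:
    # Element-wise sum of equally sized vectors.
    return [sum(column) for column in zip(*vectors)]


def hierarchical_allreduce(
    grads_by_node: Dict[str, List[Gradient]], average: bool = True
) -> Dict[str, List[Gradient]]:
    """Simulate an assumed two-level allreduce: intra-node, then inter-node."""
    # Step 1: intra-node reduce. Each node leader sums its local workers' gradients.
    node_partials = {node: _sum(grads) for node, grads in grads_by_node.items()}

    # Step 2: inter-node allreduce among the node leaders only.
    total = _sum(list(node_partials.values()))
    if average:
        # Divide by the total worker count, not the node count, so the result
        # matches a flat allreduce even when nodes have different worker counts.
        num_workers = sum(len(grads) for grads in grads_by_node.values())
        total = [x / num_workers for x in total]

    # Step 3: intra-node broadcast. Each leader sends the result to its local workers.
    return {node: [list(total) for _ in grads] for node, grads in grads_by_node.items()}


grads = {"node0": [[1.0, 2.0], [3.0, 4.0]], "node1": [[5.0, 6.0], [7.0, 8.0]]}
print(hierarchical_allreduce(grads)["node1"][0])  # [4.0, 5.0]
```

The usual motivation for such a scheme is that only the node leaders take part in the slower inter-node step, while the remaining traffic stays on faster intra-node links.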