bagua.torch_api.algorithms.gradient_allreduce
Module Contents
- class bagua.torch_api.algorithms.gradient_allreduce.GradientAllReduceAlgorithm(hierarchical=False, average=True)
Bases: bagua.torch_api.algorithms.base.Algorithm
Create an instance of the GradientAllReduce algorithm.
- Parameters:
hierarchical (bool) – Enable hierarchical communication.
average (bool) – If True, the gradients from all workers are averaged. Otherwise, they are summed.
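To make the effect of the average flag concrete, the snippet below is a minimal pure-Python simulation of an allreduce over per-worker gradient vectors. The allreduce helper is hypothetical and uses neither Bagua nor any communication library. It only models the semantics: every worker ends up with the element-wise sum of all gradients, divided by the number of workers when average is True.

```python
from typing import List


def allreduce(worker_grads: List[List[float]], average: bool = True) -> List[List[float]]:
    """Hypothetical simulation of an allreduce (not Bagua's implementation).

    ``worker_grads[i]`` is the gradient vector held by worker ``i``.
    Every worker receives the same reduced vector.
    """
    n_workers = len(worker_grads)
    # Element-wise sum across workers: zip groups the same parameter index.
    reduced = [sum(values) for values in zip(*worker_grads)]
    if average:
        # average=True: divide the sum by the number of workers.
        reduced = [v / n_workers for v in reduced]
    # Each worker gets its own copy of the reduced vector.
    return [list(reduced) for _ in range(n_workers)]


grads = [[1.0, 2.0], [3.0, 4.0]]
print(allreduce(grads, average=True))   # [[2.0, 3.0], [2.0, 3.0]]
print(allreduce(grads, average=False))  # [[4.0, 6.0], [4.0, 6.0]]
```

With two workers, averaging yields half of the summed result; summing yields a gradient whose magnitude grows with the number of workers, which may call for adjusting the learning rate.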
- class bagua.torch_api.algorithms.gradient_allreduce.GradientAllReduceAlgorithmImpl(process_group, hierarchical=False, average=True)
Bases: bagua.torch_api.algorithms.base.AlgorithmImpl
Implementation of the GradientAllReduce algorithm.
- Parameters:
process_group (BaguaProcessGroup) – The process group to work on.
hierarchical (bool) – Enable hierarchical communication.
average (bool) – If True, the gradients from all workers are averaged. Otherwise, they are summed.
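The hierarchical parameter is documented only as enabling hierarchical communication. The snippet below assumes the common two-level scheme: reduce within each node, allreduce across one leader per node, then broadcast back within each node. It is a pure-Python simulation, not Bagua's implementation. The hierarchical_allreduce helper and the node names are hypothetical, and the node-to-workers mapping stands in for the process_group argument.

```python
from typing import Dict, List


def hierarchical_allreduce(
    node_grads: Dict[str, List[List[float]]], average: bool = True
) -> Dict[str, List[List[float]]]:
    """Hypothetical two-level allreduce simulation (assumed scheme).

    ``node_grads`` maps a node name to the gradient vectors of the
    workers running on that node.
    """
    total_workers = sum(len(workers) for workers in node_grads.values())

    # Step 1: reduce within each node onto a single leader.
    node_sums = {
        node: [sum(values) for values in zip(*workers)]
        for node, workers in node_grads.items()
    }

    # Step 2: allreduce among node leaders only, so the inter-node step
    # has one participant per node instead of one per worker.
    global_sum = [sum(values) for values in zip(*node_sums.values())]
    if average:
        global_sum = [v / total_workers for v in global_sum]

    # Step 3: each leader broadcasts the result to the workers on its node.
    return {
        node: [list(global_sum) for _ in workers]
        for node, workers in node_grads.items()
    }


grads = {
    "node0": [[1.0, 2.0], [3.0, 4.0]],
    "node1": [[5.0, 6.0], [7.0, 8.0]],
}
print(hierarchical_allreduce(grads))
# {'node0': [[4.0, 5.0], [4.0, 5.0]], 'node1': [[4.0, 5.0], [4.0, 5.0]]}
```

Under this assumed scheme the final values match a flat allreduce over all four workers. The typical motivation is that fewer participants take part in the slower inter-node step.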