bagua.torch_api.algorithms.gradient_allreduce

Module Contents

class bagua.torch_api.algorithms.gradient_allreduce.GradientAllReduceAlgorithm(hierarchical=False, average=True)

Bases: bagua.torch_api.algorithms.base.Algorithm

Create an instance of the GradientAllReduce algorithm.

Parameters
  • hierarchical (bool) – Enable hierarchical communication.

  • average (bool) – If True, the gradients are averaged across all workers. Otherwise, they are summed.
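
The effect of the average flag can be illustrated with a toy, single-process simulation. This is only a sketch of the reduction semantics described above, not Bagua's implementation: simulated_allreduce is a hypothetical helper, and the workers are plain Python lists rather than processes or GPU tensors.

```python
from typing import List


def simulated_allreduce(
    worker_grads: List[List[float]], average: bool = True
) -> List[List[float]]:
    """Toy all-reduce (hypothetical helper, not part of Bagua).

    Every simulated worker ends up with the same reduced gradient.
    """
    num_workers = len(worker_grads)
    # Element-wise sum of the gradients held by each worker.
    reduced = [sum(values) for values in zip(*worker_grads)]
    if average:
        # Divide by the number of workers to turn the sum into a mean.
        reduced = [value / num_workers for value in reduced]
    # Each worker receives its own copy of the reduced gradient.
    return [list(reduced) for _ in range(num_workers)]


# Two workers, each holding a gradient with two elements.
grads = [[1.0, 2.0], [3.0, 4.0]]
averaged = simulated_allreduce(grads, average=True)
summed = simulated_allreduce(grads, average=False)
print(averaged)  # [[2.0, 3.0], [2.0, 3.0]]
print(summed)    # [[4.0, 6.0], [4.0, 6.0]]
```

With average=False the result grows with the number of workers, so the learning rate usually has to be chosen with that scaling in mind.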

class bagua.torch_api.algorithms.gradient_allreduce.GradientAllReduceAlgorithmImpl(process_group, hierarchical=False, average=True)

Bases: bagua.torch_api.algorithms.base.AlgorithmImpl

Implementation of the GradientAllReduce algorithm.

Parameters
  • process_group (BaguaProcessGroup) – The process group to work on.

  • hierarchical (bool) – Enable hierarchical communication.

  • average (bool) – If True, the gradients are averaged across all workers. Otherwise, they are summed.
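
The hierarchical option is not described in detail here. A common reading, assumed for this sketch, is a staged reduction: gradients are first reduced among the workers of one node, then across nodes, and the result is sent back to every worker. The toy simulation below follows that assumption; hierarchical_allreduce and elementwise_sum are hypothetical helpers, and nothing here reflects how GradientAllReduceAlgorithmImpl is actually written.

```python
from typing import List


def elementwise_sum(grads: List[List[float]]) -> List[float]:
    """Sum a list of equally sized gradient vectors element by element."""
    return [sum(values) for values in zip(*grads)]


def hierarchical_allreduce(
    node_grads: List[List[List[float]]], average: bool = True
) -> List[List[List[float]]]:
    """Toy staged all-reduce (hypothetical helper, not part of Bagua).

    node_grads[n][w] is the gradient held by worker w on node n.
    """
    # Stage 1: reduce among the workers of each node.
    node_sums = [elementwise_sum(workers) for workers in node_grads]
    # Stage 2: reduce the per-node results across nodes.
    total = elementwise_sum(node_sums)
    if average:
        # Average over all workers, not over nodes.
        world_size = sum(len(workers) for workers in node_grads)
        total = [value / world_size for value in total]
    # Stage 3: hand a copy of the result back to every worker.
    return [[list(total) for _ in workers] for workers in node_grads]


# Two nodes with two workers each, one gradient element per worker.
nodes = [[[1.0], [2.0]], [[3.0], [6.0]]]
result_avg = hierarchical_allreduce(nodes, average=True)
result_sum = hierarchical_allreduce(nodes, average=False)
print(result_avg)  # [[[3.0], [3.0]], [[3.0], [3.0]]]
print(result_sum)  # [[[12.0], [12.0]], [[12.0], [12.0]]]
```

Under this assumption the staged reduction gives the same values as a flat all-reduce over all four workers; only the communication pattern differs.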