bagua.torch_api.algorithms.gradient_allreduce

Module Contents

class bagua.torch_api.algorithms.gradient_allreduce.GradientAllReduceAlgorithm(hierarchical_reduce=False, average=True)

Bases: bagua.torch_api.algorithms.Algorithm

Create an instance of the GradientAllReduce algorithm.

Parameters
  • hierarchical_reduce (bool) – If True, enable hierarchical communication. Defaults to False.

  • average (bool) – If True, the gradients are averaged across workers. Otherwise, they are summed. Defaults to True.
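
The effect of the average flag can be illustrated with a small pure-Python sketch. It does not use Bagua or any communication backend: simulate_allreduce is a hypothetical helper written only for this illustration, and each inner list stands in for the gradient held by one worker.

```python
from typing import List


def simulate_allreduce(worker_grads: List[List[float]], average: bool = True) -> List[float]:
    """Hypothetical helper (not part of Bagua).

    Combines per-worker gradients element-wise, mimicking the result
    every worker would hold after an all-reduce.
    """
    num_workers = len(worker_grads)
    # Element-wise sum across workers.
    reduced = [sum(values) for values in zip(*worker_grads)]
    if average:
        # Divide by the number of workers to turn the sum into a mean.
        reduced = [value / num_workers for value in reduced]
    return reduced


# Three simulated workers, each holding a two-element gradient.
grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

averaged = simulate_allreduce(grads, average=True)
summed = simulate_allreduce(grads, average=False)

print(averaged)  # [3.0, 4.0]
print(summed)    # [9.0, 12.0]
```

With average=False the reduced gradient grows with the number of workers, so the learning rate would usually need to be chosen with that scaling in mind.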

init_operations(self, bagua_module, bucket)
Parameters