bagua.torch_api.data_parallel.functional

Module Contents

class bagua.torch_api.data_parallel.functional.ReduceOp

Bases: enum.IntEnum

An enum-like class for available reduction operations: SUM, PRODUCT, MIN, MAX, BAND, BOR, BXOR and AVG.

AVG = 10
BAND = 8
BOR = 7
BXOR = 9
MAX = 3
MIN = 2
PRODUCT = 1
SUM = 0
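The sketch below shows what each ReduceOp member computes when per-process values are combined element-wise. It is a plain-Python model that assumes integer inputs for the bitwise ops. The `ReduceOp` class here is a local copy of the values listed above, and `reduce_across_ranks` is a hypothetical helper, not part of bagua.

```python
from enum import IntEnum
from functools import reduce


class ReduceOp(IntEnum):
    # Local copy of the enum values listed above (illustrative, not imported from bagua).
    SUM = 0
    PRODUCT = 1
    MIN = 2
    MAX = 3
    BOR = 7
    BAND = 8
    BXOR = 9
    AVG = 10


# How each op combines two values (AVG is handled separately below).
_BINARY = {
    ReduceOp.SUM: lambda a, b: a + b,
    ReduceOp.PRODUCT: lambda a, b: a * b,
    ReduceOp.MIN: min,
    ReduceOp.MAX: max,
    ReduceOp.BOR: lambda a, b: a | b,
    ReduceOp.BAND: lambda a, b: a & b,
    ReduceOp.BXOR: lambda a, b: a ^ b,
}


def reduce_across_ranks(op, per_rank):
    """Hypothetical helper (not a bagua API): reduce equal-length
    per-process lists element-wise with `op`."""
    columns = list(zip(*per_rank))  # one tuple per element position
    if op == ReduceOp.AVG:
        return [sum(col) / len(col) for col in columns]
    fn = _BINARY[op]
    return [reduce(fn, col) for col in columns]
```

For two processes holding [1, 6] and [3, 5], SUM yields [4, 11], BXOR yields [2, 3], and AVG yields [2.0, 5.5].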

bagua.torch_api.data_parallel.functional.all_reduce(tensor, op=dist.ReduceOp.SUM, group=dist.group.WORLD)

Reduces the tensor data across all processes in group so that every process receives the final result.

After the call, the returned tensor is bitwise identical in all processes.

Parameters:
  • tensor (Tensor) – Input of the collective.

  • op (optional) – One of the values from the torch.distributed.ReduceOp enum, specifying the operation used for element-wise reductions. Defaults to dist.ReduceOp.SUM.

  • group (ProcessGroup, optional) – The process group to work on. Defaults to dist.group.WORLD.

Returns:

Output of the collective, i.e. the reduced tensor.

Return type:

Tensor
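Conceptually, the collective behaves as if every process contributed its tensor, the contributions were reduced element-wise, and each process received its own copy of the result. The sketch below models that contract for the default SUM op, with plain Python lists standing in for tensors. `simulated_all_reduce` is a hypothetical name, and the sketch says nothing about how bagua actually moves data between processes.

```python
def simulated_all_reduce(per_rank_tensors):
    """Model a SUM all-reduce over len(per_rank_tensors) simulated processes.

    Hypothetical helper, not a bagua API. Entry i is the list that process i
    passes as `tensor`; entry i of the result is what process i gets back.
    """
    # Element-wise sum across all processes' inputs.
    reduced = [sum(values) for values in zip(*per_rank_tensors)]
    # Every process receives its own copy of the same reduced values.
    return [list(reduced) for _ in per_rank_tensors]


inputs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # three simulated processes
outputs = simulated_all_reduce(inputs)
```

Each entry of `outputs` is [9.0, 12.0], mirroring the guarantee that the returned tensor is identical in all processes.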