bagua.torch_api.communication

Module Contents

class bagua.torch_api.communication.BaguaProcessGroup(ranks, stream, group_name)

Definition of a Bagua process group.

get_global_communicator()

Returns the global communicator of the current process group.

Return type:

bagua_core.BaguaSingleCommunicatorPy

get_inter_node_communicator()

Returns the inter-node communicator of the current process group.

Return type:

bagua_core.BaguaSingleCommunicatorPy

get_intra_node_communicator()

Returns the intra-node communicator of the current process group.

Return type:

bagua_core.BaguaSingleCommunicatorPy
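
Example (a minimal sketch, assuming the default process group has been initialized with init_process_group on a 2-process job; the output shown is illustrative):

>>> from bagua.torch_api.communication import new_group
>>>
>>> group = new_group()  # all ranks of the default group
>>> comm = group.get_global_communicator()
>>> comm.rank(), comm.nranks()
(0, 2) # Rank 0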

class bagua.torch_api.communication.ReduceOp

Bases: enum.IntEnum

An enum-like class for available reduction operations: SUM, PRODUCT, MIN, MAX, BAND, BOR, BXOR and AVG.


AVG = 10
BAND = 8
BOR = 7
BXOR = 9
MAX = 3
MIN = 2
PRODUCT = 1
SUM = 0

bagua.torch_api.communication.allgather(send_tensor, recv_tensor, comm=None)

Gathers send tensors from all processes associated with the communicator into recv_tensor.

Parameters:
  • send_tensor (torch.Tensor) – Input of the collective.

  • recv_tensor (torch.Tensor) – Output of the collective, must have a size of comm.nranks * send_tensor.size() elements.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.
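
Example (a minimal sketch on 2 ranks, assuming the default process group is initialized and each process is bound to its CUDA device; output is illustrative):

>>> import torch
>>> import bagua.torch_api as bagua
>>> from bagua.torch_api.communication import allgather
>>>
>>> rank = bagua.get_rank()
>>> send_tensor = torch.tensor([rank], dtype=torch.int64, device="cuda")
>>> recv_tensor = torch.zeros(2, dtype=torch.int64, device="cuda")
>>> allgather(send_tensor, recv_tensor)
>>> recv_tensor
tensor([0, 1], device='cuda:0') # same result on every rank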

bagua.torch_api.communication.allgather_inplace(tensor, comm=None)

The in-place version of allgather.

Parameters:
  • tensor (torch.Tensor) –

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) –

bagua.torch_api.communication.allreduce(send_tensor, recv_tensor, op=ReduceOp.SUM, comm=None)

Reduces the tensor data across all processes associated with the communicator in such a way that all get the final result. After the call recv_tensor is going to be bitwise identical in all processes.

Parameters:
  • send_tensor (torch.Tensor) – Input of the collective.

  • recv_tensor (torch.Tensor) – Output of the collective, must have the same size as send_tensor.

  • op (ReduceOp) – One of the values from ReduceOp enum. Specifies an operation used for element-wise reductions.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.

Examples:

>>> import torch
>>> from bagua.torch_api import allreduce, get_rank
>>>
>>> rank = get_rank()
>>> device = torch.device("cuda")
>>> # All tensors below are of torch.int64 type.
>>> # We have 2 ranks.
>>> send_tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
>>> recv_tensor = torch.zeros(2, dtype=torch.int64, device=device)
>>> send_tensor
tensor([1, 2], device='cuda:0') # Rank 0
tensor([3, 4], device='cuda:1') # Rank 1
>>> allreduce(send_tensor, recv_tensor)
>>> recv_tensor
tensor([4, 6], device='cuda:0') # Rank 0
tensor([4, 6], device='cuda:1') # Rank 1

>>> # All tensors below are of torch.cfloat type.
>>> # We have 2 ranks.
>>> send_tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j)
>>> recv_tensor = torch.zeros(2, dtype=torch.cfloat, device=device)
>>> send_tensor
tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
>>> allreduce(send_tensor, recv_tensor)
>>> recv_tensor
tensor([4.+4.j, 6.+6.j], device='cuda:0') # Rank 0
tensor([4.+4.j, 6.+6.j], device='cuda:1') # Rank 1

bagua.torch_api.communication.allreduce_inplace(tensor, op=ReduceOp.SUM, comm=None)

The in-place version of allreduce.

Parameters:
  • tensor (torch.Tensor) –

  • op (ReduceOp) –

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) –
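
Example (a minimal sketch on 2 ranks showing a non-default reduction op; the tensor is reduced in place and the result replaces the input on every rank; output is illustrative):

>>> import torch
>>> import bagua.torch_api as bagua
>>> from bagua.torch_api.communication import ReduceOp, allreduce_inplace
>>>
>>> rank = bagua.get_rank()
>>> tensor = torch.tensor([rank + 1, rank + 2], dtype=torch.int64, device="cuda")
>>> allreduce_inplace(tensor, op=ReduceOp.MAX)
>>> tensor
tensor([2, 3], device='cuda:0') # element-wise maximum, identical on every rank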

bagua.torch_api.communication.alltoall(send_tensor, recv_tensor, comm=None)

Each process scatters send_tensor to all processes associated with the communicator and returns the gathered data in recv_tensor.

Parameters:
  • send_tensor (torch.Tensor) – Input of the collective, the size must be divisible by comm.nranks.

  • recv_tensor (torch.Tensor) – Output of the collective, must have the same size as send_tensor.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.
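
Example (a minimal sketch on 2 ranks; each rank splits send_tensor into comm.nranks equal chunks and the i-th chunk ends up on rank i; output is illustrative):

>>> import torch
>>> import bagua.torch_api as bagua
>>> from bagua.torch_api.communication import alltoall
>>>
>>> rank = bagua.get_rank()
>>> send_tensor = torch.arange(4, dtype=torch.int64, device="cuda") + 4 * rank
>>> recv_tensor = torch.zeros(4, dtype=torch.int64, device="cuda")
>>> send_tensor
tensor([0, 1, 2, 3], device='cuda:0') # Rank 0
tensor([4, 5, 6, 7], device='cuda:1') # Rank 1
>>> alltoall(send_tensor, recv_tensor)
>>> recv_tensor
tensor([0, 1, 4, 5], device='cuda:0') # Rank 0
tensor([2, 3, 6, 7], device='cuda:1') # Rank 1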

bagua.torch_api.communication.alltoall_inplace(tensor, comm=None)

The in-place version of alltoall.

Parameters:
  • tensor (torch.Tensor) –

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) –

bagua.torch_api.communication.barrier(comm=None)

Synchronizes all processes. This collective blocks processes until all processes associated with the communicator enter this function.

Parameters:

comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.
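
Example (a minimal sketch, assuming the default process group is initialized):

>>> from bagua.torch_api.communication import barrier
>>>
>>> barrier() # returns only after every process in the group has reached this call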

bagua.torch_api.communication.broadcast(tensor, src=0, comm=None)

Broadcasts the tensor to all processes associated with the communicator.

tensor must have the same number of elements in all processes participating in the collective.

Parameters:
  • tensor (torch.Tensor) – Data to be sent if src is the rank of current process, and tensor to be used to save received data otherwise.

  • src (int) – Source rank. Default: 0.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.
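
Example (a minimal sketch on 2 ranks; output is illustrative):

>>> import torch
>>> import bagua.torch_api as bagua
>>> from bagua.torch_api.communication import broadcast
>>>
>>> rank = bagua.get_rank()
>>> tensor = torch.tensor([rank, rank], dtype=torch.int64, device="cuda")
>>> broadcast(tensor, src=0)
>>> tensor
tensor([0, 0], device='cuda:0') # Rank 0
tensor([0, 0], device='cuda:1') # Rank 1, overwritten by rank 0's data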

bagua.torch_api.communication.broadcast_object(obj, src=0, comm=None)

Serializes and broadcasts an object from root rank to all other processes. Typical usage is to broadcast the optimizer.state_dict(), for example:

>>> state_dict = broadcast_object(optimizer.state_dict(), 0)
>>> if get_rank() > 0:
...     optimizer.load_state_dict(state_dict)

Parameters:
  • obj (object) – An object capable of being serialized without losing any context.

  • src (int) – The source rank from which the object is broadcast to all other processes.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.

Returns:

The object broadcast from rank src.

Return type:

object

Note

This operation moves the data to the GPU before communication and back to the CPU afterwards, and it requires CPU-GPU synchronization.

bagua.torch_api.communication.from_torch_group(group, stream=None)

Converts a PyTorch process group to its equivalent Bagua process group.

Parameters:
  • group – A handle of the PyTorch process group.

  • stream (Optional[torch.cuda.Stream]) – A CUDA stream used to execute NCCL operations. If None, the CUDA stream of the default group will be used. See new_group for more information.

Returns:

A handle of the Bagua process group.

Return type:

BaguaProcessGroup
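
Example (a minimal sketch, assuming Bagua and torch.distributed have been initialized via init_process_group; the subgroup ranks are illustrative):

>>> import torch.distributed as dist
>>> from bagua.torch_api.communication import from_torch_group
>>>
>>> torch_group = dist.new_group(ranks=[0, 1]) # PyTorch process group
>>> bagua_group = from_torch_group(torch_group) # equivalent Bagua process group
>>> comm = bagua_group.get_global_communicator()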

bagua.torch_api.communication.gather(send_tensor, recv_tensor, dst, comm=None)

Gathers send tensors from all processes associated with the communicator to recv_tensor in a single process.

Parameters:
  • send_tensor (torch.Tensor) – Input of the collective.

  • recv_tensor (torch.Tensor) – Output of the collective, must have a size of comm.nranks * send_tensor.size() elements.

  • dst (int) – Destination rank.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.
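
Example (a minimal sketch on 2 ranks; recv_tensor is only meaningful on the destination rank; output is illustrative):

>>> import torch
>>> import bagua.torch_api as bagua
>>> from bagua.torch_api.communication import gather
>>>
>>> rank = bagua.get_rank()
>>> send_tensor = torch.tensor([rank + 1], dtype=torch.int64, device="cuda")
>>> recv_tensor = torch.zeros(2, dtype=torch.int64, device="cuda")
>>> gather(send_tensor, recv_tensor, dst=0)
>>> recv_tensor
tensor([1, 2], device='cuda:0') # Rank 0 (destination)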

bagua.torch_api.communication.gather_inplace(tensor, count, dst, comm=None)

The in-place version of gather.

Parameters:
  • tensor (torch.Tensor) – Input and output of the collective. On the dst rank, it must have a size of comm.nranks * count elements; on non-dst ranks, its size must be equal to count.

  • count (int) – The per-rank data count to gather.

  • dst (int) – Destination rank.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.

bagua.torch_api.communication.init_process_group(store=None)

Initializes the PyTorch builtin distributed process group and the Bagua distributed package. It should be executed before any other Bagua API.

Parameters:

store (Optional[torch.distributed.Store]) – Key/value store accessible to all workers, used to exchange connection/address information. If None, a TCP-based store will be created. Default: None.

Examples:
>>> import torch
>>> import bagua.torch_api as bagua
>>>
>>> torch.cuda.set_device(bagua.get_local_rank()) # THIS LINE IS IMPORTANT. See the notes below.
>>> bagua.init_process_group()
>>>
>>> model = torch.nn.Sequential(
...    torch.nn.Linear(D_in, H),
...    torch.nn.ReLU(),
...    torch.nn.Linear(H, D_out),
...    )
>>> optimizer = torch.optim.SGD(
...    model.parameters(),
...    lr=0.01,
...    momentum=0.9
...    )
>>> model = model.with_bagua([optimizer], ...)

Note

Each process should be associated with a CUDA device using torch.cuda.set_device() before calling init_process_group. Otherwise you may encounter the fatal runtime error: Rust cannot catch foreign exceptions error.

bagua.torch_api.communication.is_initialized()

Checks whether the default process group has been initialized.

bagua.torch_api.communication.new_group(ranks=None, stream=None)

Creates a new process group.

This function requires that all processes in the default group (i.e. all processes that are part of the distributed job) enter this function, even if they are not going to be members of the group. Additionally, groups should be created in the same order in all processes.

Each process group will create three communicators on request: a global communicator, an inter-node communicator and an intra-node communicator. Users can access them through group.get_global_communicator(), group.get_inter_node_communicator() and group.get_intra_node_communicator() respectively.

Parameters:
  • ranks (Optional[List[int]]) – List of ranks of group members. If None, will be set to all ranks. Default is None.

  • stream (Optional[torch.cuda.Stream]) – A CUDA stream used to execute NCCL operations. If None, the CUDA stream of the default group will be used. See CUDA semantics for details.

Returns:

A handle of the process group that can be given to collective calls.

Return type:

BaguaProcessGroup

Note

The global communicator is used for global communications involving all ranks in the process group. The inter-node communicator and the intra-node communicator are used for hierarchical communications in this process group.

Note

For a specific communicator comm, comm.rank() returns the rank of the current process and comm.nranks() returns the size of the communicator.
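
Example (a minimal sketch on a 4-process job; the chosen ranks are illustrative). Every process must call new_group, even those not included in ranks:

>>> import torch
>>> import bagua.torch_api as bagua
>>> from bagua.torch_api.communication import new_group, allreduce
>>>
>>> group = new_group(ranks=[0, 1]) # all processes must enter this call
>>> if bagua.get_rank() in [0, 1]:
...     comm = group.get_global_communicator()
...     send_tensor = torch.ones(2, device="cuda")
...     recv_tensor = torch.zeros(2, device="cuda")
...     allreduce(send_tensor, recv_tensor, comm=comm)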

bagua.torch_api.communication.recv(tensor, src, comm=None)

Receives a tensor synchronously.

Parameters:
  • tensor (torch.Tensor) – Tensor to fill with received data.

  • src (int) – Source rank.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.

bagua.torch_api.communication.reduce(send_tensor, recv_tensor, dst, op=ReduceOp.SUM, comm=None)

Reduces the tensor data across all processes.

Only the process with rank dst is going to receive the final result.

Parameters:
  • send_tensor (torch.Tensor) – Input of the collective.

  • recv_tensor (torch.Tensor) – Output of the collective, must have the same size as send_tensor.

  • dst (int) – Destination rank.

  • op (ReduceOp) – One of the values from ReduceOp enum. Specifies an operation used for element-wise reductions.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.
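
Example (a minimal sketch on 2 ranks; recv_tensor is only meaningful on the destination rank; output is illustrative):

>>> import torch
>>> import bagua.torch_api as bagua
>>> from bagua.torch_api.communication import reduce
>>>
>>> rank = bagua.get_rank()
>>> send_tensor = torch.tensor([rank + 1, rank + 2], dtype=torch.int64, device="cuda")
>>> recv_tensor = torch.zeros(2, dtype=torch.int64, device="cuda")
>>> reduce(send_tensor, recv_tensor, dst=0)
>>> recv_tensor
tensor([3, 5], device='cuda:0') # Rank 0 (destination) holds the element-wise sum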

bagua.torch_api.communication.reduce_inplace(tensor, dst, op=ReduceOp.SUM, comm=None)

The in-place version of reduce.

Parameters:
  • tensor (torch.Tensor) –

  • dst (int) –

  • op (ReduceOp) –

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) –

bagua.torch_api.communication.reduce_scatter(send_tensor, recv_tensor, op=ReduceOp.SUM, comm=None)

Reduces, then scatters send_tensor to all processes associated with the communicator.

Parameters:
  • send_tensor (torch.Tensor) – Input of the collective, must have a size of comm.nranks * recv_tensor.size() elements.

  • recv_tensor (torch.Tensor) – Output of the collective.

  • op (ReduceOp, optional) – One of the values from ReduceOp enum. Specifies an operation used for element-wise reductions.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.
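
Example (a minimal sketch on 2 ranks; send_tensor is reduced element-wise across ranks, then the i-th chunk of the result is delivered to rank i; output is illustrative):

>>> import torch
>>> import bagua.torch_api as bagua
>>> from bagua.torch_api.communication import reduce_scatter
>>>
>>> rank = bagua.get_rank()
>>> send_tensor = torch.tensor([1., 2., 3., 4.], device="cuda") * (rank + 1)
>>> recv_tensor = torch.zeros(2, device="cuda")
>>> reduce_scatter(send_tensor, recv_tensor)
>>> recv_tensor
tensor([3., 6.], device='cuda:0') # Rank 0: reduced first half
tensor([9., 12.], device='cuda:1') # Rank 1: reduced second half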

bagua.torch_api.communication.reduce_scatter_inplace(tensor, op=ReduceOp.SUM, comm=None)

The in-place version of reduce_scatter.

Parameters:
  • tensor (torch.Tensor) – Input and output of the collective, the size must be divisible by comm.nranks.

  • op (ReduceOp, optional) – One of the values from ReduceOp enum. Specifies an operation used for element-wise reductions.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.

bagua.torch_api.communication.scatter(send_tensor, recv_tensor, src, comm=None)

Scatters send tensor to all processes associated with the communicator.

Parameters:
  • send_tensor (torch.Tensor) – Input of the collective, must have a size of comm.nranks * recv_tensor.size() elements.

  • recv_tensor (torch.Tensor) – Output of the collective.

  • src (int) – Source rank.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.
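
Example (a minimal sketch on 2 ranks; send_tensor is only read on the source rank; output is illustrative):

>>> import torch
>>> import bagua.torch_api as bagua
>>> from bagua.torch_api.communication import scatter
>>>
>>> send_tensor = torch.arange(4, dtype=torch.int64, device="cuda")
>>> recv_tensor = torch.zeros(2, dtype=torch.int64, device="cuda")
>>> scatter(send_tensor, recv_tensor, src=0)
>>> recv_tensor
tensor([0, 1], device='cuda:0') # Rank 0
tensor([2, 3], device='cuda:1') # Rank 1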

bagua.torch_api.communication.scatter_inplace(tensor, count, src, comm=None)

The in-place version of scatter.

Parameters:
  • tensor (torch.Tensor) – Input and output of the collective. On the src rank, it must have a size of comm.nranks * count elements; on non-src ranks, its size must be equal to count.

  • count (int) – The per-rank data count to scatter.

  • src (int) – Source rank.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.

bagua.torch_api.communication.send(tensor, dst, comm=None)

Sends a tensor to dst synchronously.

Parameters:
  • tensor (torch.Tensor) – Tensor to send.

  • dst (int) – Destination rank.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.
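
Example (a minimal point-to-point sketch on 2 ranks, pairing send with recv documented above; output is illustrative):

>>> import torch
>>> import bagua.torch_api as bagua
>>> from bagua.torch_api.communication import send, recv
>>>
>>> rank = bagua.get_rank()
>>> tensor = torch.zeros(2, dtype=torch.int64, device="cuda")
>>> if rank == 0:
...     tensor += 42
...     send(tensor, dst=1) # blocks until the tensor has been sent
... elif rank == 1:
...     recv(tensor, src=0) # blocks until the tensor has been received
>>> tensor
tensor([42, 42], device='cuda:0') # Rank 0
tensor([42, 42], device='cuda:1') # Rank 1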