bagua.torch_api

The Bagua communication library PyTorch interface.

Package Contents

class bagua.torch_api.BaguaModule

This class patches torch.nn.Module with several methods to enable Bagua functionalities.

Variables:
  • bagua_optimizers (List[torch.optim.Optimizer]) – The optimizers passed in by with_bagua.

  • bagua_algorithm (bagua.torch_api.algorithms.AlgorithmImpl) – The algorithm implementation used by the module, reified by the algorithm passed in by with_bagua.

  • process_group (bagua.torch_api.communication.BaguaProcessGroup) – The process group used by the module.

  • bagua_module_name (str) – The module’s name. Bagua uses the module name to distinguish different modules.

  • parameters_to_ignore (List[str]) – The parameter names in "{module_name}.{param_name}" format to ignore when calling self.bagua_build_params().

  • bagua_train_step_counter (int) – Number of iterations in training mode.

  • bagua_buckets (List[bagua.torch_api.bucket.BaguaBucket]) – All Bagua buckets in a list.

with_bagua(optimizers, algorithm, process_group=None, do_flatten=True)

with_bagua enables easy distributed data parallel training on a torch.nn.Module.

Parameters:
  • optimizers (List[torch.optim.Optimizer]) – Optimizer(s) used by the module. It can contain one or more PyTorch optimizers.

  • algorithm (bagua.torch_api.algorithms.Algorithm) – Distributed algorithm used to do the actual communication and update.

  • process_group (Optional[bagua.torch_api.communication.BaguaProcessGroup]) – The process group to be used for distributed data all-reduction. If None, the default process group, which is created by bagua.torch_api.init_process_group, will be used. (default: None)

  • do_flatten (bool) – Whether to flatten the Bagua buckets. The flatten operation will reset data pointer of bucket tensors so that they can use faster code paths. Default: True.

Returns:

The original module, with Bagua related environments initialized.

Return type:

BaguaModule

Note

If we want to ignore some layers for communication, we can first check these layers’ corresponding keys in the module’s state_dict (they are in "{module_name}.{param_name}" format), then assign the list of keys to your_module._bagua_params_and_buffers_to_ignore.
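
For instance, with a torch.nn.Sequential model like the one in the example below, the keys look like "0.weight" and "0.bias". The following is a minimal sketch of skipping the first linear layer; the exact key names depend on your model and should always be checked against its state_dict:

>>> list(model.state_dict().keys())
['0.weight', '0.bias', '2.weight', '2.bias']
>>> model._bagua_params_and_buffers_to_ignore = ["0.weight", "0.bias"]
>>> # These parameters will now be skipped by self.bagua_build_params().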

Examples:

>>> model = torch.nn.Sequential(
...      torch.nn.Linear(D_in, H),
...      torch.nn.ReLU(),
...      torch.nn.Linear(H, D_out),
...    )
>>> optimizer = torch.optim.SGD(
...      model.parameters(),
...      lr=0.01,
...      momentum=0.9
...    )
>>> from bagua.torch_api.algorithms.gradient_allreduce import GradientAllReduce
>>> model = model.with_bagua(
...      [optimizer],
...      GradientAllReduce()
...    )
class bagua.torch_api.BaguaTensor

This class patches torch.Tensor with additional methods.

A Bagua tensor is required to use Bagua’s communication algorithms. Users can convert a PyTorch tensor to a Bagua tensor with ensure_bagua_tensor.

Bagua tensor features a proxy structure, where the actual tensor used by the backend is accessed via a “Proxy Tensor”. The proxy tensor is registered in Bagua; whenever the Bagua backend needs a tensor (for example, to use it for communication), it calls the bagua_getter_closure on the proxy tensor to get the tensor that is actually worked on. We call this tensor the “Effective Tensor”. The bagua_setter_closure is also provided to replace the effective tensor at runtime. It is intended to be used to replace the effective tensor with a customized workflow.

Their relation can be seen in the following diagram:

https://user-images.githubusercontent.com/18649508/139179394-51d0c0f5-e233-4ada-8e5e-0e70a889540d.png

For example, in the gradient allreduce algorithm, the effective tensor that needs to be exchanged between machines is the gradient. In this case, we register the model parameters as proxy tensors and register bagua_getter_closure to be lambda proxy_tensor: proxy_tensor.grad. In this way, even if the gradient tensor is recreated or changed at runtime, Bagua can still identify the correct tensor and use it for communication, since the proxy tensor serves as the root for access and is never replaced.
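
As a hedged sketch of this pattern (not taken from the library’s shipped examples; it assumes model has already been wrapped by with_bagua and that param is one of its parameters with a populated .grad):

>>> param.ensure_bagua_tensor(
...     name="my_param",
...     module_name=model.bagua_module_name,
...     getter_closure=lambda proxy: proxy.grad,
... )
>>> # Bagua now communicates param.grad (the effective tensor), while param
>>> # itself remains the registered proxy tensor.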

bagua_backend_tensor()
Returns:

The raw Bagua backend tensor.

Return type:

bagua_core.BaguaTensorPy

bagua_ensure_grad()

Create a zero gradient tensor for the current parameter if it does not exist.

Returns:

The original tensor.

Return type:

torch.Tensor

bagua_getter_closure()

Returns the tensor that will be used at runtime.

Return type:

torch.Tensor

bagua_mark_communication_ready()

Mark a Bagua tensor ready for the execution of scheduled operations.

bagua_mark_communication_ready_without_synchronization()

Mark a Bagua tensor ready immediately, without CUDA event synchronization.

bagua_set_storage(storage, storage_offset=0)

Sets the underlying storage of the effective tensor returned by bagua_getter_closure to an existing torch.Storage.

Parameters:
  • storage (torch.Storage) – The storage to use.

  • storage_offset (int) – The offset in the storage.

bagua_setter_closure(tensor)

Sets the tensor that will be used at runtime to a new PyTorch tensor tensor.

Parameters:

tensor (torch.Tensor) – The new tensor to be set to.

ensure_bagua_tensor(name=None, module_name=None, getter_closure=None, setter_closure=None)

Convert a PyTorch tensor or parameter to a Bagua tensor in place and return it. A Bagua tensor is required to use Bagua’s communication algorithms.

This operation will register self as a proxy tensor to the Bagua backend. getter_closure takes the proxy tensor as input and returns a PyTorch tensor. Whenever the Bagua tensor is used, getter_closure will be called and return the effective tensor, which is used for communication and other operations. For example, if one of a model’s parameters param is registered as a proxy tensor and getter_closure is lambda x: x.grad, its gradient will be used at runtime.

setter_closure takes the proxy tensor and another tensor as inputs and returns nothing. It is mainly used for changing the effective tensor used at runtime. For example, when one of a model’s parameters param is registered as a proxy tensor and getter_closure is lambda x: x.grad, the setter_closure can be lambda param, new_grad_tensor: setattr(param, "grad", new_grad_tensor). When the setter_closure is called, the effective tensor used in later operations will be changed to new_grad_tensor.

Parameters:
  • name (Optional[str]) – The unique name of the tensor.

  • module_name (Optional[str]) – The name of the model to which the tensor belongs. The model name can be acquired using model.bagua_module_name. This is required to call bagua_mark_communication_ready related methods.

  • getter_closure (Optional[Callable[[torch.Tensor], torch.Tensor]]) – A function that accepts a PyTorch tensor as its input and returns a PyTorch tensor as its output. Could be None, which means an identity mapping lambda x: x is used. Default: None.

  • setter_closure (Optional[Callable[[torch.Tensor, torch.Tensor], None]]) – A function that accepts two PyTorch tensors as its inputs and returns nothing. Could be None, which is a no-op. Default: None.

Returns:

The original tensor with Bagua tensor attributes initialized.
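
A hedged sketch combining both closures, following the prose above (it assumes param is a parameter of a module already wrapped by with_bagua and that its gradient exists):

>>> param.ensure_bagua_tensor(
...     name="my_param",
...     module_name=model.bagua_module_name,
...     getter_closure=lambda p: p.grad,
...     setter_closure=lambda p, new_grad: setattr(p, "grad", new_grad),
... )
>>> new_grad = torch.zeros_like(param)
>>> param.bagua_setter_closure(new_grad)
>>> # Later operations now use new_grad as the effective tensor.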

is_bagua_tensor()

Check whether this is a Bagua tensor.

Return type:

bool

to_bagua_tensor(name=None, module_name=None, getter_closure=None, setter_closure=None)

Create a new Bagua tensor from a PyTorch tensor or parameter and return it. The new Bagua tensor will share the same storage with the input PyTorch tensor. A Bagua tensor is required to use Bagua’s communication algorithms. See ensure_bagua_tensor for more information.

Caveat: Be aware that if the original tensor changes to use a different storage, for example via torch.Tensor.set_(...), the new Bagua tensor will still use the old storage.

Parameters:
  • name (Optional[str]) – The unique name of the tensor.

  • module_name (Optional[str]) – The name of the model to which the tensor belongs. The model name can be acquired using model.bagua_module_name. This is required to call bagua_mark_communication_ready related methods.

  • getter_closure (Optional[Callable[[torch.Tensor], torch.Tensor]]) – A function that accepts a PyTorch tensor as its input and returns a PyTorch tensor as its output. See ensure_bagua_tensor.

  • setter_closure (Optional[Callable[[torch.Tensor, torch.Tensor], None]]) – A function that accepts two PyTorch tensors as its inputs and returns nothing. See ensure_bagua_tensor.

Returns:

The new Bagua tensor sharing the same storage with the original tensor.
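
A minimal hedged sketch (the tensor and its name are arbitrary placeholders):

>>> t = torch.zeros(4, device="cuda")
>>> bagua_t = t.to_bagua_tensor(name="my_buffer")
>>> bagua_t.is_bagua_tensor()
True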

class bagua.torch_api.ReduceOp

Bases: enum.IntEnum

An enum-like class for available reduction operations: SUM, PRODUCT, MIN, MAX, BAND, BOR, BXOR and AVG.


AVG = 10
BAND = 8
BOR = 7
BXOR = 9
MAX = 3
MIN = 2
PRODUCT = 1
SUM = 0
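
Because ReduceOp is an IntEnum, its members compare equal to the integer values listed above, for example:

>>> from bagua.torch_api import ReduceOp
>>> ReduceOp.SUM == 0
True
>>> int(ReduceOp.AVG)
10
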
bagua.torch_api.allgather(send_tensor, recv_tensor, comm=None)

Gathers send_tensor from all processes associated with the communicator into recv_tensor.

Parameters:
  • send_tensor (torch.Tensor) – Input of the collective.

  • recv_tensor (torch.Tensor) – Output of the collective, must have a size of comm.nranks * send_tensor.size() elements.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.
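
A minimal sketch, assuming the default process group has been initialized with init_process_group and each process has set its CUDA device; the names bagua and device introduced here are reused in the sketches for the other collectives below:

>>> import torch
>>> import bagua.torch_api as bagua
>>> from bagua.torch_api import allgather
>>>
>>> device = torch.device("cuda", bagua.get_local_rank())
>>> send_tensor = torch.full((2,), float(bagua.get_rank()), device=device)
>>> recv_tensor = torch.zeros(2 * bagua.get_world_size(), device=device)
>>> allgather(send_tensor, recv_tensor)
>>> # recv_tensor now holds every rank's send_tensor, ordered by rank.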

bagua.torch_api.allgather_inplace(tensor, comm=None)

The in-place version of allgather.

Parameters:
  • tensor (torch.Tensor) –

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) –

bagua.torch_api.allreduce(send_tensor, recv_tensor, op=ReduceOp.SUM, comm=None)

Reduces the tensor data across all processes associated with the communicator in such a way that all get the final result. After the call recv_tensor is going to be bitwise identical in all processes.

Parameters:
  • send_tensor (torch.Tensor) – Input of the collective.

  • recv_tensor (torch.Tensor) – Output of the collective, must have the same size with send_tensor.

  • op (ReduceOp) – One of the values from ReduceOp enum. Specifies an operation used for element-wise reductions.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.

Examples:

>>> import bagua.torch_api as bagua
>>> from bagua.torch_api import allreduce
>>>
>>> rank = bagua.get_rank()
>>> device = torch.device("cuda", bagua.get_local_rank())
>>> # All tensors below are of torch.int64 type.
>>> # We have 2 ranks.
>>> send_tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
>>> recv_tensor = torch.zeros(2, dtype=torch.int64, device=device)
>>> send_tensor
tensor([1, 2], device='cuda:0') # Rank 0
tensor([3, 4], device='cuda:1') # Rank 1
>>> allreduce(send_tensor, recv_tensor)
>>> recv_tensor
tensor([4, 6], device='cuda:0') # Rank 0
tensor([4, 6], device='cuda:1') # Rank 1

>>> # All tensors below are of torch.cfloat type.
>>> # We have 2 ranks.
>>> send_tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat, device=device) + 2 * rank * (1+1j)
>>> recv_tensor = torch.zeros(2, dtype=torch.cfloat, device=device)
>>> send_tensor
tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
>>> allreduce(send_tensor, recv_tensor)
>>> recv_tensor
tensor([4.+4.j, 6.+6.j], device='cuda:0') # Rank 0
tensor([4.+4.j, 6.+6.j], device='cuda:1') # Rank 1
bagua.torch_api.allreduce_inplace(tensor, op=ReduceOp.SUM, comm=None)

The in-place version of allreduce.

Parameters:
  • tensor (torch.Tensor) –

  • op (ReduceOp) –

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) –

bagua.torch_api.alltoall(send_tensor, recv_tensor, comm=None)

Each process scatters send_tensor to all processes associated with the communicator and returns the gathered data in recv_tensor.

Parameters:
  • send_tensor (torch.Tensor) – Input of the collective; its size must be divisible by comm.nranks.

  • recv_tensor (torch.Tensor) – Output of the collective, must have equal size with send_tensor.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.
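
A hedged sketch, reusing bagua and device from the allgather sketch above: with n ranks, chunk i of each rank's send_tensor goes to rank i.

>>> from bagua.torch_api import alltoall
>>> n = bagua.get_world_size()
>>> send_tensor = torch.arange(2 * n, dtype=torch.float32, device=device) + 100 * bagua.get_rank()
>>> recv_tensor = torch.empty_like(send_tensor)
>>> alltoall(send_tensor, recv_tensor)
>>> # Chunk i of recv_tensor came from rank i's send_tensor.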

bagua.torch_api.alltoall_inplace(tensor, comm=None)

The in-place version of alltoall.

Parameters:
  • tensor (torch.Tensor) –

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) –

bagua.torch_api.alltoall_v(send_tensor, send_counts, send_displs, recv_tensor, recv_counts, recv_displs, comm=None)

Each process scatters send_tensor to all processes associated with the communicator and returns the gathered data in recv_tensor. Each process may send a different amount of data and provide displacements for the input and output data.

Parameters:
  • send_tensor (torch.Tensor) – Input of the collective; its size must be divisible by comm.nranks.

  • send_counts (int) – Integer array (of length group size) specifying the number of elements to send to each process.

  • send_displs (int) – Integer array (of length group size). Entry j specifies the displacement (relative to send_tensor) from which to take the outgoing data destined for process j.

  • recv_tensor (torch.Tensor) – Output of the collective, must have equal size with send_tensor.

  • recv_counts (int) – Integer array (of length group size) specifying the maximum number of elements that can be received from each process.

  • recv_displs (int) – Integer array (of length group size). Entry i specifies the displacement (relative to recv_tensor) at which to place the incoming data from process i.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.
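
A heavily hedged sketch for 2 ranks, reusing bagua and device from the allgather sketch above. Every rank sends 1 element to rank 0 and 3 elements to rank 1. The counts and displacements are written as plain Python integer lists, which is an assumption about the accepted container type and may need to be adapted to your backend version.

>>> from bagua.torch_api import alltoall_v
>>> rank = bagua.get_rank()
>>> send_counts, send_displs = [1, 3], [0, 1]
>>> recv_counts = [1, 1] if rank == 0 else [3, 3]
>>> recv_displs = [0, 1] if rank == 0 else [0, 3]
>>> send_tensor = torch.arange(4, dtype=torch.float32, device=device)
>>> recv_tensor = torch.zeros(sum(recv_counts), dtype=torch.float32, device=device)
>>> alltoall_v(send_tensor, send_counts, send_displs, recv_tensor, recv_counts, recv_displs)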

bagua.torch_api.alltoall_v_inplace(tensor, counts, displs, comm=None)

The in-place version of alltoall_v.

Parameters:
  • tensor (torch.Tensor) –

  • counts (int) –

  • displs (int) –

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) –

bagua.torch_api.barrier(comm=None)

Synchronizes all processes. This collective blocks processes until all processes associated with the communicator enter this function.

Parameters:

comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.

bagua.torch_api.broadcast(tensor, src=0, comm=None)

Broadcasts the tensor to all processes associated with the communicator.

tensor must have the same number of elements in all processes participating in the collective.

Parameters:
  • tensor (torch.Tensor) – Data to be sent if src is the rank of current process, and tensor to be used to save received data otherwise.

  • src (int) – Source rank. Default: 0.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.
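
A minimal hedged sketch, reusing bagua and device from the allgather sketch above:

>>> from bagua.torch_api import broadcast
>>> x = torch.arange(3, dtype=torch.float32, device=device) + 10 * bagua.get_rank()
>>> broadcast(x, src=0)
>>> # x now equals rank 0's values, i.e. tensor([0., 1., 2.]), on every rank.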

bagua.torch_api.gather(send_tensor, recv_tensor, dst, comm=None)

Gathers send_tensor from all processes associated with the communicator into recv_tensor on a single process.

Parameters:
  • send_tensor (torch.Tensor) – Input of the collective.

  • recv_tensor (torch.Tensor) – Output of the collective, must have a size of comm.nranks * send_tensor.size() elements.

  • dst (int) – Destination rank.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.
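
A minimal hedged sketch, reusing bagua and device from the allgather sketch above:

>>> from bagua.torch_api import gather
>>> send_tensor = torch.full((2,), float(bagua.get_rank()), device=device)
>>> recv_tensor = torch.zeros(2 * bagua.get_world_size(), device=device)
>>> gather(send_tensor, recv_tensor, dst=0)
>>> # Only rank 0's recv_tensor receives the gathered data, ordered by rank.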

bagua.torch_api.gather_inplace(tensor, count, dst, comm=None)

The in-place version of gather.

Parameters:
  • tensor (torch.Tensor) – Input and output of the collective. On the dst rank, it must have a size of comm.nranks * count elements; on non-dst ranks, its size must be equal to count.

  • count (int) – The per-rank data count to gather.

  • dst (int) – Destination rank.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.

bagua.torch_api.get_backend(model_name)
Parameters:

model_name (str) –

bagua.torch_api.get_local_rank()

Get the local rank of the current process.

Local rank is a unique identifier assigned to each process within a node. They are always consecutive integers ranging from 0 to local_size - 1.

Returns:

The local rank of the current process.

Return type:

int

bagua.torch_api.get_local_size()

Get the number of processes in the node.

Returns:

The local size of the node.

Return type:

int

bagua.torch_api.get_rank()

Get the rank of the current process in the default process group.

Rank is a unique identifier assigned to each process within the default process group. They are always consecutive integers ranging from 0 to world_size - 1.

Returns:

The rank of the current process in the default process group.

Return type:

int

bagua.torch_api.get_world_size()

Get the number of processes in the default process group.

Returns:

The world size of the default process group.

Return type:

int

bagua.torch_api.init_process_group(store=None, rank=-1, world_size=-1, local_world_size=-1)

Initializes the PyTorch builtin distributed process group. This also initializes the distributed package, and it should be executed before any other Bagua API is used.

Parameters:
  • store (Optional[torch.distributed.Store]) – Key/value store accessible to all workers, used to exchange connection/address information. If None, a TCP-based store will be created. Default: None.

  • rank (int) – Rank of the current process (it should be a number between 0 and world_size-1). Required if store is specified.

  • world_size (int) – Number of processes participating in the job. Required if store is specified.

  • local_world_size (int) – Number of processes per node. Required if store is specified.

Examples:
>>> import torch
>>> import bagua.torch_api as bagua
>>>
>>> torch.cuda.set_device(bagua.get_local_rank()) # THIS LINE IS IMPORTANT. See the notes below.
>>> bagua.init_process_group()
>>>
>>> model = torch.nn.Sequential(
...    torch.nn.Linear(D_in, H),
...    torch.nn.ReLU(),
...    torch.nn.Linear(H, D_out),
...    )
>>> optimizer = torch.optim.SGD(
...    model.parameters(),
...    lr=0.01,
...    momentum=0.9
...    )
>>> model = model.with_bagua([optimizer], ...)

Note

Each process should be associated with a CUDA device using torch.cuda.set_device() before calling init_process_group. Otherwise, you may encounter the fatal runtime error: Rust cannot catch foreign exceptions.

bagua.torch_api.recv(tensor, src, comm=None)

Receives a tensor synchronously.

Parameters:
  • tensor (torch.Tensor) – Tensor to fill with received data.

  • src (int) – Source rank.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.

bagua.torch_api.reduce(send_tensor, recv_tensor, dst, op=ReduceOp.SUM, comm=None)

Reduces the tensor data across all processes.

Only the process with rank dst is going to receive the final result.

Parameters:
  • send_tensor (torch.Tensor) – Input of the collective.

  • recv_tensor (torch.Tensor) – Output of the collective, must have the same size with send_tensor.

  • dst (int) – Destination rank.

  • op (ReduceOp) – One of the values from ReduceOp enum. Specifies an operation used for element-wise reductions.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.
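
A minimal hedged sketch, reusing bagua and device from the allgather sketch above:

>>> from bagua.torch_api import reduce
>>> send_tensor = torch.ones(3, device=device)
>>> recv_tensor = torch.zeros(3, device=device)
>>> reduce(send_tensor, recv_tensor, dst=0)
>>> # On rank 0, recv_tensor holds the element-wise sum over all ranks.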

bagua.torch_api.reduce_inplace(tensor, dst, op=ReduceOp.SUM, comm=None)

The in-place version of reduce.

Parameters:
  • tensor (torch.Tensor) –

  • dst (int) –

  • op (ReduceOp) –

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) –

bagua.torch_api.reduce_scatter(send_tensor, recv_tensor, op=ReduceOp.SUM, comm=None)

Reduces, then scatters send_tensor to all processes associated with the communicator.

Parameters:
  • send_tensor (torch.Tensor) – Input of the collective, must have a size of comm.nranks * recv_tensor.size() elements.

  • recv_tensor (torch.Tensor) – Output of the collective.

  • op (ReduceOp, optional) – One of the values from ReduceOp enum. Specifies an operation used for element-wise reductions.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.
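
A minimal hedged sketch, reusing bagua and device from the allgather sketch above:

>>> from bagua.torch_api import reduce_scatter
>>> n = bagua.get_world_size()
>>> send_tensor = torch.ones(4 * n, device=device)
>>> recv_tensor = torch.zeros(4, device=device)
>>> reduce_scatter(send_tensor, recv_tensor)
>>> # Each rank receives one reduced chunk; here every element of recv_tensor equals n.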

bagua.torch_api.reduce_scatter_inplace(tensor, op=ReduceOp.SUM, comm=None)

The in-place version of reduce_scatter.

Parameters:
  • tensor (torch.Tensor) – Input and output of the collective; its size must be divisible by comm.nranks.

  • op (ReduceOp, optional) – One of the values from ReduceOp enum. Specifies an operation used for element-wise reductions.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.

bagua.torch_api.scatter(send_tensor, recv_tensor, src, comm=None)

Scatters send_tensor to all processes associated with the communicator.

Parameters:
  • send_tensor (torch.Tensor) – Input of the collective, must have a size of comm.nranks * recv_tensor.size() elements.

  • recv_tensor (torch.Tensor) – Output of the collective.

  • src (int) – Source rank.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.
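
A minimal hedged sketch, reusing bagua and device from the allgather sketch above:

>>> from bagua.torch_api import scatter
>>> n = bagua.get_world_size()
>>> send_tensor = torch.arange(2 * n, dtype=torch.float32, device=device)  # only read on src
>>> recv_tensor = torch.zeros(2, device=device)
>>> scatter(send_tensor, recv_tensor, src=0)
>>> # Rank r now holds the r-th chunk of rank 0's send_tensor.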

bagua.torch_api.scatter_inplace(tensor, count, src, comm=None)

The in-place version of scatter.

Parameters:
  • tensor (torch.Tensor) – Input and output of the collective. On the src rank, it must have a size of comm.nranks * count elements; on non-src ranks, its size must be equal to count.

  • count (int) – The per-rank data count to scatter.

  • src (int) – Source rank.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.

bagua.torch_api.send(tensor, dst, comm=None)

Sends a tensor to dst synchronously.

Parameters:
  • tensor (torch.Tensor) – Tensor to send.

  • dst (int) – Destination rank.

  • comm (Optional[bagua_core.BaguaSingleCommunicatorPy]) – A handle of the Bagua communicator to work on. By default, the global communicator of the default process group will be used.
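
A hedged point-to-point sketch pairing send with recv, reusing bagua and device from the allgather sketch above and assuming at least two ranks:

>>> from bagua.torch_api import send, recv
>>> if bagua.get_rank() == 0:
...     send(torch.ones(4, device=device), dst=1)
... elif bagua.get_rank() == 1:
...     buf = torch.zeros(4, device=device)
...     recv(buf, src=0)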

bagua.torch_api.version