bagua.torch_api

The Bagua communication library PyTorch interface.

Package Contents

bagua.torch_api.version
class bagua.torch_api.BaguaModule

This class patches torch.nn.Module with several methods to enable Bagua functionality.

Variables
  • bagua_optimizers (List[torch.optim.Optimizer]) – The optimizers passed in by with_bagua.

  • bagua_algorithm (bagua.torch_api.algorithms.Algorithm) – The algorithm passed in by with_bagua.

  • parameters_to_ignore (List[str]) – The parameter names in "{module_name}.{param_name}" format to ignore when calling self.bagua_build_params().

  • bagua_train_step_counter (int) – Number of iterations in training mode.

  • bagua_buckets (List[bagua.torch_api.bucket.BaguaBucket]) – All Bagua buckets in a list.

bagua_build_params(self)

Build a list of (parameter_name, parameter) tuples for all parameters that require gradients and are not in the _bagua_params_and_buffers_to_ignore attribute.

Return type

List[Tuple[str, torch.nn.Parameter]]
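The filtering rule described above can be modeled in plain Python. This is a simplified sketch, not Bagua's implementation: `FakeParam` and `build_params` are hypothetical stand-ins for torch.nn.Parameter and bagua_build_params, and only the requires_grad flag and the ignore list are modeled.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class FakeParam:
    """Hypothetical stand-in for torch.nn.Parameter; only requires_grad is modeled."""

    requires_grad: bool


def build_params(
    named_params: List[Tuple[str, FakeParam]], to_ignore: List[str]
) -> List[Tuple[str, FakeParam]]:
    """Keep the (name, param) pairs that require grads and are not listed in to_ignore."""
    ignored = set(to_ignore)
    return [
        (name, param)
        for name, param in named_params
        if param.requires_grad and name not in ignored
    ]


named_params = [
    ("0.weight", FakeParam(True)),
    ("0.bias", FakeParam(True)),
    ("2.weight", FakeParam(False)),  # frozen, so it is dropped
    ("2.bias", FakeParam(True)),
]
kept = build_params(named_params, to_ignore=["0.bias"])  # "0.bias" is explicitly ignored
print([name for name, _ in kept])  # ['0.weight', '2.bias']
```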

with_bagua(self, optimizers, algorithm)

with_bagua enables easy distributed data parallel training on a torch.nn.Module.

Parameters
  • optimizers (List[torch.optim.Optimizer]) – Optimizer(s) used by the module. It can contain one or more PyTorch optimizers.

  • algorithm (bagua.torch_api.algorithms.Algorithm) – Distributed algorithm used to do the actual communication and update.

Returns

The original module, with the Bagua-related environment initialized.

Return type

BaguaModule

Note

If we want to ignore some layers for communication, we can first look up these layers' keys in the module’s state_dict (they are in "{module_name}.{param_name}" format), then assign the list of keys to your_module._bagua_params_and_buffers_to_ignore.

Examples:

>>> import torch
>>> D_in, H, D_out = 64, 100, 10  # illustrative layer sizes
>>> model = torch.nn.Sequential(
...      torch.nn.Linear(D_in, H),
...      torch.nn.ReLU(),
...      torch.nn.Linear(H, D_out),
...    )
>>> optimizer = torch.optim.SGD(
...      model.parameters(),
...      lr=0.01,
...      momentum=0.9
...    )
>>> model = model.with_bagua(
...      [optimizer],
...      GradientAllReduce(),  # a bagua.torch_api.algorithms.Algorithm instance
...    )
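Following the note above, the keys to ignore can be picked out by layer name. This is a minimal sketch, assuming keys of the form "{module_name}.{param_name}"; `keys_for_layers` is a hypothetical helper, not part of Bagua, and the key list is hard-coded rather than read from a real model so the snippet runs without PyTorch.

```python
from typing import Iterable, List


def keys_for_layers(state_dict_keys: Iterable[str], layer_names: Iterable[str]) -> List[str]:
    """Return the keys that belong to any of the given (possibly nested) layer names."""
    # The trailing "." keeps layer "2" from matching keys of layer "20".
    prefixes = tuple(name + "." for name in layer_names)
    return [key for key in state_dict_keys if key.startswith(prefixes)]


# Keys that model.state_dict() lists for the Sequential(Linear, ReLU, Linear)
# model above; the ReLU at index 1 has no parameters, so it has no keys.
keys = ["0.weight", "0.bias", "2.weight", "2.bias"]
to_ignore = keys_for_layers(keys, ["2"])
print(to_ignore)  # ['2.weight', '2.bias']

# With the real model, the list would then be assigned as the note describes:
# model._bagua_params_and_buffers_to_ignore = to_ignore
```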
class bagua.torch_api.BaguaTensor

This class patches torch.Tensor with additional methods.

bagua_backend_tensor(self)

Returns

The raw Bagua backend tensor.

Return type

bagua_core.BaguaTensorPy

bagua_ensure_grad(self)

Return the gradient of the current parameter, creating a zero gradient tensor if it does not exist.

Return type

torch.Tensor

bagua_mark_communication_ready(self)

Mark a Bagua tensor as ready for the execution of scheduled operations.

bagua_mark_communication_ready_without_synchronization(self)

Mark a Bagua tensor as ready immediately, without CUDA event synchronization.

bagua_set_storage(self, storage, storage_offset=0)

Sets the underlying storage using an existing torch.Storage.

Parameters
  • storage (torch.Storage) – The storage to use.

  • storage_offset (int) – The offset in the storage.

ensure_bagua_tensor(self, name=None, module_name=None)

Convert a PyTorch tensor or parameter to a Bagua tensor in place and return it. A Bagua tensor is required to use Bagua’s communication algorithms.

Parameters
  • name (Optional[str]) – The unique name of the tensor.

  • module_name (Optional[str]) – The name of the model to which the tensor belongs. The model name can be acquired using model.bagua_module_name. This is required to call the bagua_mark_communication_ready related methods.

Returns

The original tensor with Bagua tensor attributes initialized.

is_bagua_tensor(self)

Return type

bool

to_bagua_tensor(self, name=None, module_name=None)

Create a new Bagua tensor from a PyTorch tensor or parameter and return it. The original tensor is not changed. A Bagua tensor is required to use Bagua’s communication algorithms.

Parameters
  • name (Optional[str]) – The unique name of the tensor.

  • module_name (Optional[str]) – The name of the model to which the tensor belongs. The model name can be acquired using model.bagua_module_name. This is required to call the bagua_mark_communication_ready related methods.

Returns

The new Bagua tensor sharing the same storage with the original tensor.

class bagua.torch_api.ReduceOp

Bases: enum.IntEnum

An enum-like class for available reduction operations: SUM, PRODUCT, MIN, MAX, BAND, BOR, BXOR and AVG.

AVG = 10
BAND = 8
BOR = 7
BXOR = 9
MAX = 3
MIN = 2
PRODUCT = 1
SUM = 0
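The element-wise meaning of each operation can be sketched with a local mirror of the enum and Python's operator module. This is a semantic model that does not call Bagua; it assumes AVG means the arithmetic mean over ranks, and in this sketch the bitwise operations only accept integers.

```python
import operator
from enum import IntEnum
from functools import reduce
from typing import List, Union


class ReduceOp(IntEnum):
    """Local mirror of the documented bagua.torch_api.ReduceOp values."""

    SUM = 0
    PRODUCT = 1
    MIN = 2
    MAX = 3
    BOR = 7
    BAND = 8
    BXOR = 9
    AVG = 10


# Binary function that combines two ranks' values, for every op except AVG.
_COMBINE = {
    ReduceOp.SUM: operator.add,
    ReduceOp.PRODUCT: operator.mul,
    ReduceOp.MIN: min,
    ReduceOp.MAX: max,
    ReduceOp.BOR: operator.or_,
    ReduceOp.BAND: operator.and_,
    ReduceOp.BXOR: operator.xor,
}


def reduce_values(values: List[int], op: ReduceOp) -> Union[int, float]:
    """Reduce one element position across ranks; values[r] is rank r's value."""
    if op == ReduceOp.AVG:
        # Assumed here to be the arithmetic mean over all ranks.
        return sum(values) / len(values)
    return reduce(_COMBINE[op], values)


per_rank = [12, 10, 8]  # the same element as held by ranks 0, 1 and 2
for op in ReduceOp:
    print(op.name, reduce_values(per_rank, op))
```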
bagua.torch_api.allgather(send_tensor, recv_tensor, comm=None)

Gathers send tensors from all processes associated with the communicator into recv_tensor.

Parameters
  • send_tensor (torch.Tensor) – Input of the collective.

  • recv_tensor (torch.Tensor) – Output of the collective, must have a size of comm.nranks * send_tensor.size() elements.

  • comm (bagua_core.BaguaSingleCommunicatorPy, optional) – The Bagua communicator to work on. If None, the global Bagua communicator will be used. Default: None.
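The size requirement on recv_tensor follows from what allgather produces. The sketch below models it with one Python list per rank instead of real tensors; `allgather_model` is a hypothetical name, and placing the chunks in rank order is an assumption based on the usual NCCL semantics.

```python
from typing import List


def allgather_model(send: List[List[int]]) -> List[List[int]]:
    """send[r] is rank r's send_tensor; returns the recv_tensor seen by every rank."""
    nranks = len(send)
    count = len(send[0])
    assert all(len(s) == count for s in send), "send tensors must have equal sizes"
    gathered = [x for s in send for x in s]  # rank 0's data first, then rank 1, ...
    assert len(gathered) == nranks * count  # the documented recv_tensor size
    return [list(gathered) for _ in range(nranks)]  # every rank gets the same result


recv = allgather_model([[1, 2], [3, 4], [5, 6]])
print(recv[0])  # [1, 2, 3, 4, 5, 6]
```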

bagua.torch_api.allgather_inplace(tensor, comm=None)

The in-place version of allgather.

Parameters

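comm (bagua_core.BaguaSingleCommunicatorPy, optional) – The Bagua communicator to work on. If None, the global Bagua communicator will be used. Default: None.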

bagua.torch_api.allreduce(send_tensor, recv_tensor, op=ReduceOp.SUM, comm=None)

Reduces the tensor data across all processes associated with the communicator in such a way that all of them get the final result. After the call, recv_tensor is bitwise identical in all processes.

Parameters
  • send_tensor (torch.Tensor) – Input of the collective.

  • recv_tensor (torch.Tensor) – Output of the collective, must have the same size as send_tensor.

  • op (ReduceOp, optional) – One of the values from ReduceOp enum. Specifies an operation used for element-wise reductions. Default: ReduceOp.SUM.

  • comm (bagua_core.BaguaSingleCommunicatorPy, optional) – The Bagua communicator to work on. If None, the global Bagua communicator will be used. Default: None.

Examples:

>>> import torch
>>> from bagua.torch_api import allreduce, get_rank
>>>
>>> rank = get_rank()
>>> # All tensors below are of torch.int64 type.
>>> # We have 1 process group with 2 ranks.
>>> send_tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> recv_tensor = torch.zeros(2, dtype=torch.int64)
>>> send_tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> allreduce(send_tensor, recv_tensor)
>>> recv_tensor
tensor([4, 6]) # Rank 0
tensor([4, 6]) # Rank 1

>>> # All tensors below are of torch.cfloat type.
>>> # We have 1 process group with 2 ranks.
>>> send_tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
>>> recv_tensor = torch.zeros(2, dtype=torch.cfloat)
>>> send_tensor
tensor([1.+1.j, 2.+2.j]) # Rank 0
tensor([3.+3.j, 4.+4.j]) # Rank 1
>>> allreduce(send_tensor, recv_tensor)
>>> recv_tensor
tensor([4.+4.j, 6.+6.j]) # Rank 0
tensor([4.+4.j, 6.+6.j]) # Rank 1
bagua.torch_api.allreduce_inplace(tensor, op=ReduceOp.SUM, comm=None)

The in-place version of allreduce.

Parameters
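  • tensor (torch.Tensor) – Input and output of the collective.

  • op (ReduceOp, optional) – One of the values from ReduceOp enum. Specifies an operation used for element-wise reductions. Default: ReduceOp.SUM.

  • comm (bagua_core.BaguaSingleCommunicatorPy, optional) – The Bagua communicator to work on. If None, the global Bagua communicator will be used. Default: None.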

bagua.torch_api.alltoall(send_tensor, recv_tensor, comm=None)

Each process scatters send_tensor to all processes associated with the communicator and returns the gathered data in recv_tensor.

Parameters
  • send_tensor (torch.Tensor) – Input of the collective, the size must be divisible by comm.nranks.

  • recv_tensor (torch.Tensor) – Output of the collective, must have the same size as send_tensor.

  • comm (bagua_core.BaguaSingleCommunicatorPy, optional) – The Bagua communicator to work on. If None, the global Bagua communicator will be used. Default: None.
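Why send_tensor must be divisible by comm.nranks, and why recv_tensor has the same size, can be seen in a list-based model of alltoall. This is a sketch that does not call Bagua; `alltoall_model` is a hypothetical name, and the chunk layout (rank j receives the j-th chunk of every rank, in rank order) is the usual NCCL-style convention, assumed here.

```python
from typing import List


def alltoall_model(send: List[List[int]]) -> List[List[int]]:
    """send[i] is rank i's send_tensor; returns recv[j], the recv_tensor of rank j."""
    nranks = len(send)
    size = len(send[0])
    assert size % nranks == 0, "send_tensor size must be divisible by nranks"
    chunk = size // nranks
    # Rank j collects chunk j from every rank i, concatenated in rank order,
    # so each recv_tensor has the same size as a send_tensor.
    return [
        [x for i in range(nranks) for x in send[i][j * chunk:(j + 1) * chunk]]
        for j in range(nranks)
    ]


recv = alltoall_model([[0, 1, 2, 3], [10, 11, 12, 13]])
print(recv)  # [[0, 1, 10, 11], [2, 3, 12, 13]]
```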

bagua.torch_api.alltoall_inplace(tensor, comm=None)

The in-place version of alltoall.

Parameters

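comm (bagua_core.BaguaSingleCommunicatorPy, optional) – The Bagua communicator to work on. If None, the global Bagua communicator will be used. Default: None.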

bagua.torch_api.broadcast(tensor, src=0, comm=None)

Broadcasts the tensor to all processes associated with the communicator.

tensor must have the same number of elements in all processes participating in the collective.

Parameters
  • tensor (torch.Tensor) – Data to be sent if src is the rank of the current process; otherwise, the tensor used to save the received data.

  • src (int, optional) – Source rank. Default: 0.

  • comm (bagua_core.BaguaSingleCommunicatorPy, optional) – The Bagua communicator to work on. If None, the global Bagua communicator will be used. Default: None.

bagua.torch_api.gather(send_tensor, recv_tensor, dst, comm=None)

Gathers send tensors from all processes associated with the communicator to recv_tensor in a single process.

Parameters
  • send_tensor (torch.Tensor) – Input of the collective.

  • recv_tensor (torch.Tensor) – Output of the collective, must have a size of comm.nranks * send_tensor.size() elements.

  • dst (int) – Destination rank.

  • comm (bagua_core.BaguaSingleCommunicatorPy, optional) – The Bagua communicator to work on. If None, the global Bagua communicator will be used. Default: None.
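Unlike allgather, gather delivers the concatenated data to a single rank. The list-based sketch below models that; `gather_model` is a hypothetical name, rank-ordered concatenation is assumed, and None simply marks ranks whose recv_tensor is not meaningful in this model.

```python
from typing import List, Optional


def gather_model(send: List[List[int]], dst: int) -> List[Optional[List[int]]]:
    """send[r] is rank r's send_tensor; only rank dst gets the gathered result."""
    nranks = len(send)
    count = len(send[0])
    gathered = [x for s in send for x in s]  # rank order: rank 0 first
    assert len(gathered) == nranks * count  # the documented recv_tensor size on dst
    return [gathered if r == dst else None for r in range(nranks)]


recv = gather_model([[1, 2], [3, 4], [5, 6]], dst=1)
print(recv)  # [None, [1, 2, 3, 4, 5, 6], None]
```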

bagua.torch_api.gather_inplace(tensor, count, dst, comm=None)

The in-place version of gather.

Parameters
  • tensor (torch.Tensor) – Input and output of the collective. On the dst rank, it must have a size of comm.nranks * count elements. On non-dst ranks, its size must be equal to count.

  • count (int) – The per-rank data count to gather.

  • dst (int) – Destination rank.

  • comm (bagua_core.BaguaSingleCommunicatorPy, optional) – The Bagua communicator to work on. If None, the global Bagua communicator will be used. Default: None.

bagua.torch_api.get_backend(model_name)

Parameters

model_name (str) –

bagua.torch_api.get_local_rank()

Get the local rank of the current process, i.e. its rank within the current node.

Local rank is a unique identifier assigned to each process within a node. Local ranks are always consecutive integers ranging from 0 to local_size - 1.

Returns

The local rank of the current process.

Return type

int

bagua.torch_api.get_local_size()

Get the number of processes in the node.

Returns

The local size of the node.

Return type

int

bagua.torch_api.get_rank()

Get the rank of the current process in the current process group.

Rank is a unique identifier assigned to each process within a distributed process group. Ranks are always consecutive integers ranging from 0 to world_size - 1.

Returns

The rank of the current process.

Return type

int

bagua.torch_api.get_world_size()

Get the number of processes in the current process group.

Returns

The world size of the process group.

Return type

int
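How rank, local rank, local size and world size fit together can be sketched for a simple layout. This assumes every node runs the same number of processes and that ranks are assigned node by node, as common launchers do; neither is stated in this document, and `rank_layout` is a hypothetical helper.

```python
from typing import List, Tuple


def rank_layout(nnodes: int, local_size: int) -> List[Tuple[int, int, int]]:
    """Return (node_index, local_rank, rank) for every process, assuming uniform nodes."""
    layout = []
    for node in range(nnodes):
        for local_rank in range(local_size):  # 0 .. local_size - 1
            rank = node * local_size + local_rank  # 0 .. world_size - 1
            layout.append((node, local_rank, rank))
    return layout


layout = rank_layout(nnodes=2, local_size=4)  # world_size = 2 * 4 = 8
print(layout[5])  # (1, 1, 5): second process on the second node
```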

bagua.torch_api.init_process_group()

Initializes the PyTorch built-in distributed process group, which also initializes the distributed package. It should be called before any other Bagua API.

Examples:
>>> import torch
>>> import bagua.torch_api as bagua
>>>
>>> torch.cuda.set_device(bagua.get_local_rank())
>>> bagua.init_process_group()
>>>
>>> model = torch.nn.Sequential(
...    torch.nn.Linear(D_in, H),
...    torch.nn.ReLU(),
...    torch.nn.Linear(H, D_out),
...    )
>>> optimizer = torch.optim.SGD(
...    model.parameters(),
...    lr=0.01,
...    momentum=0.9
...    )
>>> model = model.with_bagua([optimizer], ...)
bagua.torch_api.recv(tensor, src, comm=None)

Receives a tensor synchronously.

Parameters
  • tensor (torch.Tensor) – Tensor to fill with received data.

  • src (int) – Source rank.

  • comm (bagua_core.BaguaSingleCommunicatorPy, optional) – The Bagua communicator to work on. If None, the global Bagua communicator will be used. Default: None.

bagua.torch_api.reduce(send_tensor, recv_tensor, dst, op=ReduceOp.SUM, comm=None)

Reduces the tensor data across all processes.

Only the process with rank dst receives the final result.

Parameters
  • send_tensor (torch.Tensor) – Input of the collective.

  • recv_tensor (torch.Tensor) – Output of the collective, must have the same size as send_tensor.

  • dst (int) – Destination rank.

  • op (ReduceOp, optional) – One of the values from ReduceOp enum. Specifies an operation used for element-wise reductions. Default: ReduceOp.SUM.

  • comm (bagua_core.BaguaSingleCommunicatorPy, optional) – The Bagua communicator to work on. If None, the global Bagua communicator will be used. Default: None.

bagua.torch_api.reduce_inplace(tensor, dst, op=ReduceOp.SUM, comm=None)

The in-place version of reduce.

Parameters
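  • tensor (torch.Tensor) – Input and output of the collective. After the call, the reduced result is in tensor on the dst rank.

  • dst (int) – Destination rank.

  • op (ReduceOp, optional) – One of the values from ReduceOp enum. Specifies an operation used for element-wise reductions. Default: ReduceOp.SUM.

  • comm (bagua_core.BaguaSingleCommunicatorPy, optional) – The Bagua communicator to work on. If None, the global Bagua communicator will be used. Default: None.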

bagua.torch_api.reduce_scatter(send_tensor, recv_tensor, op=ReduceOp.SUM, comm=None)

Reduces, then scatters send_tensor to all processes associated with the communicator.

Parameters
  • send_tensor (torch.Tensor) – Input of the collective, must have a size of comm.nranks * recv_tensor.size() elements.

  • recv_tensor (torch.Tensor) – Output of the collective.

  • op (ReduceOp, optional) – One of the values from ReduceOp enum. Specifies an operation used for element-wise reductions. Default: ReduceOp.SUM.

  • comm (bagua_core.BaguaSingleCommunicatorPy, optional) – The Bagua communicator to work on. If None, the global Bagua communicator will be used. Default: None.
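The two phases of reduce_scatter, and the size relation between send_tensor and recv_tensor, can be modeled with lists. This sketch fixes op to ReduceOp.SUM and assumes rank r keeps the r-th chunk of the reduced result, following the usual NCCL convention; `reduce_scatter_sum_model` is a hypothetical name.

```python
from typing import List


def reduce_scatter_sum_model(send: List[List[int]]) -> List[List[int]]:
    """send[r] is rank r's send_tensor; returns recv[r], the recv_tensor of rank r."""
    nranks = len(send)
    size = len(send[0])
    assert size % nranks == 0, "send_tensor must hold nranks * recv_tensor.size() elements"
    count = size // nranks
    # Reduce: element-wise sum of the same position across all ranks.
    reduced = [sum(column) for column in zip(*send)]
    # Scatter: rank r keeps only the r-th chunk of the reduced data.
    return [reduced[r * count:(r + 1) * count] for r in range(nranks)]


recv = reduce_scatter_sum_model([[1, 2, 3, 4], [10, 20, 30, 40]])
print(recv)  # [[11, 22], [33, 44]]
```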

bagua.torch_api.reduce_scatter_inplace(tensor, op=ReduceOp.SUM, comm=None)

The in-place version of reduce_scatter.

Parameters
  • tensor (torch.Tensor) – Input and output of the collective, the size must be divisible by comm.nranks.

  • op (ReduceOp, optional) – One of the values from ReduceOp enum. Specifies an operation used for element-wise reductions. Default: ReduceOp.SUM.

  • comm (bagua_core.BaguaSingleCommunicatorPy, optional) – The Bagua communicator to work on. If None, the global Bagua communicator will be used. Default: None.

bagua.torch_api.scatter(send_tensor, recv_tensor, src, comm=None)

Scatters send_tensor from the src rank to all processes associated with the communicator.

Parameters
  • send_tensor (torch.Tensor) – Input of the collective, must have a size of comm.nranks * recv_tensor.size() elements.

  • recv_tensor (torch.Tensor) – Output of the collective.

  • src (int) – Source rank.

  • comm (bagua_core.BaguaSingleCommunicatorPy, optional) – The Bagua communicator to work on. If None, the global Bagua communicator will be used. Default: None.
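Scatter is the inverse of gather: the src rank's send_tensor is split into comm.nranks equal chunks, one per rank. The list-based sketch below models only the data that ends up in each recv_tensor; `scatter_model` is a hypothetical name, and giving chunk r to rank r is an assumption matching the usual convention.

```python
from typing import List


def scatter_model(send_on_src: List[int], nranks: int) -> List[List[int]]:
    """Split src's send_tensor; element r of the result is rank r's recv_tensor."""
    assert len(send_on_src) % nranks == 0, "send_tensor must hold nranks * count elements"
    count = len(send_on_src) // nranks  # size of each recv_tensor
    return [send_on_src[r * count:(r + 1) * count] for r in range(nranks)]


recv = scatter_model([1, 2, 3, 4, 5, 6], nranks=3)
print(recv)  # [[1, 2], [3, 4], [5, 6]]
```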

bagua.torch_api.scatter_inplace(tensor, count, src, comm=None)

The in-place version of scatter.

Parameters
  • tensor (torch.Tensor) – Input and output of the collective. On the src rank, it must have a size of comm.nranks * count elements. On non-src ranks, its size must be equal to count.

  • count (int) – The per-rank data count to scatter.

  • src (int) – Source rank.

  • comm (bagua_core.BaguaSingleCommunicatorPy, optional) – The Bagua communicator to work on. If None, the global Bagua communicator will be used. Default: None.

bagua.torch_api.send(tensor, dst, comm=None)

Sends a tensor to dst synchronously.

Parameters
  • tensor (torch.Tensor) – Tensor to send.

  • dst (int) – Destination rank.

  • comm (bagua_core.BaguaSingleCommunicatorPy, optional) – The Bagua communicator to work on. If None, the global Bagua communicator will be used. Default: None.