bagua.torch_api
The Bagua communication library PyTorch interface.
Package Contents
- bagua.torch_api.init_process_group()
Initializes the PyTorch built-in distributed process group, which also initializes the distributed package. It must be called before any other Bagua API.
- Raises
RepeatedInitializationError – If this function is called more than once.
- Examples
>>> import torch
>>> import bagua.torch_api as bagua
>>> bagua.init_process_group()
>>> model = torch.nn.Sequential(
...     torch.nn.Linear(D_in, H),
...     torch.nn.ReLU(),
...     torch.nn.Linear(H, D_out),
... )
>>> optimizer = torch.optim.SGD(
...     model.parameters(),
...     lr=0.01,
...     momentum=0.9
... )
>>> model = model.with_bagua([optimizer], GradientAllReduce())
- bagua.torch_api.allreduce(tensor, op=dist.ReduceOp.SUM, comm=None)
Reduces the tensor data across all processes so that every process receives the final result. After the call, tensor is bitwise identical in all processes.
- Parameters
tensor (torch.Tensor) – Input and output of the collective. The function operates in-place.
op (optional) – One of the values from the torch.distributed.ReduceOp enum, specifying the operation used for element-wise reductions. Defaults to dist.ReduceOp.SUM.
comm (B.BaguaSingleCommunicatorPy, optional) – The Bagua communicator to work on. If None, the global Bagua communicator is used. Defaults to None.
Examples
>>> from bagua.torch_api import allreduce, get_rank
>>> # All tensors below are of torch.int64 type.
>>> # One process group with 2 ranks.
>>> rank = get_rank()
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> allreduce(tensor)
>>> tensor
tensor([4, 6]) # Rank 0
tensor([4, 6]) # Rank 1
>>> # All tensors below are of torch.cfloat type.
>>> # One process group with 2 ranks.
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
>>> tensor
tensor([1.+1.j, 2.+2.j]) # Rank 0
tensor([3.+3.j, 4.+4.j]) # Rank 1
>>> allreduce(tensor)
>>> tensor
tensor([4.+4.j, 6.+6.j]) # Rank 0
tensor([4.+4.j, 6.+6.j]) # Rank 1
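The result of allreduce can be checked with a single-process simulation. The simulation covers only the result, not how Bagua actually communicates. Each inner list stands for one rank's tensor, `simulate_allreduce` is a hypothetical helper, and `op` is a plain Python binary function standing in for a torch.distributed.ReduceOp value.

```python
from typing import Callable, List


def simulate_allreduce(
    per_rank: List[List[float]],
    op: Callable[[float, float], float] = lambda a, b: a + b,
) -> None:
    """Reduce element-wise across simulated ranks, then write the result back to every rank in place."""
    reduced = list(per_rank[0])
    for buf in per_rank[1:]:
        if len(buf) != len(reduced):
            raise ValueError("all ranks must contribute the same number of elements")
        reduced = [op(a, b) for a, b in zip(reduced, buf)]
    for buf in per_rank:
        buf[:] = reduced  # in place, mirroring allreduce(tensor)


# Two simulated ranks, mirroring the torch.int64 example above (SUM).
buffers = [[1, 2], [3, 4]]
simulate_allreduce(buffers)
print(buffers)  # [[4, 6], [4, 6]]

# The same helper with a MAX-style reduction.
maxed = [[1, 5], [3, 4]]
simulate_allreduce(maxed, op=max)
print(maxed)  # [[3, 5], [3, 5]]
```

After the call, every simulated rank holds an equal copy of the reduced values. This matches the guarantee stated above that all processes end up with identical results.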
- bagua.torch_api.broadcast(tensor, root=0, comm=None)
Broadcasts the tensor from the root rank to every process in the communicator.
The tensor must have the same number of elements in all processes participating in the collective.
- Parameters
tensor (torch.Tensor) – Data to be sent if root is the rank of the current process; otherwise, the tensor that receives the broadcast data.
root (int, optional) – Source rank. Defaults to 0.
comm (B.BaguaSingleCommunicatorPy, optional) – The bagua communicator to work on. If None, the global bagua communicator will be used. Defaults to None.
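The semantics of broadcast, including the requirement that element counts match, can be modeled in a single process. In this sketch each inner list stands for one rank's tensor, and `simulate_broadcast` is a hypothetical helper, not part of Bagua.

```python
from typing import List


def simulate_broadcast(per_rank: List[list], root: int = 0) -> None:
    """Copy the root rank's buffer into every other rank's buffer in place."""
    src = per_rank[root]
    for rank, buf in enumerate(per_rank):
        # Every participant must hold the same number of elements.
        if len(buf) != len(src):
            raise ValueError(f"rank {rank} has {len(buf)} elements, root has {len(src)}")
        if rank != root:
            buf[:] = src  # receivers overwrite their buffer with the root's data


buffers = [[0, 0, 0], [7, 8, 9], [0, 0, 0]]
simulate_broadcast(buffers, root=1)
print(buffers)  # [[7, 8, 9], [7, 8, 9], [7, 8, 9]]
```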
- class bagua.torch_api.BaguaModule
This class patches torch.nn.Module with several methods that enable Bagua functionality.
- bagua_build_params(self)
Build a list of (parameter_name, parameter) tuples for all parameters that require gradients and are not listed in the _bagua_params_and_buffers_to_ignore attribute.
- Returns
List[(str, torch.nn.Parameter)]
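The filtering rule in bagua_build_params can be sketched in plain Python. This sketch assumes parameter names look like those from named_parameters(). `FakeParam` and `build_params_sketch` are hypothetical stand-ins, and only the requires_grad flag and the ignore list are modeled.

```python
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple


@dataclass
class FakeParam:
    """Hypothetical stand-in for torch.nn.Parameter: only requires_grad matters here."""
    requires_grad: bool


def build_params_sketch(
    named_params: Dict[str, FakeParam], to_ignore: Set[str]
) -> List[Tuple[str, FakeParam]]:
    """Keep (name, param) pairs that require gradients and are not in the ignore set."""
    return [
        (name, p)
        for name, p in named_params.items()
        if p.requires_grad and name not in to_ignore
    ]


params = {
    "0.weight": FakeParam(True),
    "0.bias": FakeParam(True),
    "2.weight": FakeParam(True),
    "2.bias": FakeParam(False),  # frozen parameter
}
kept = build_params_sketch(params, to_ignore={"0.weight", "0.bias"})
print([name for name, _ in kept])  # ['2.weight']
```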
- with_bagua(self, optimizers, algorithm)
with_bagua enables easy distributed data parallel training on a torch.nn.Module.
- Parameters
optimizers (List[torch.optim.Optimizer]) – Optimizer(s) used by the module. It can contain one or more PyTorch optimizers.
algorithm (bagua.torch_api.algorithm.Algorithm) – Distributed algorithm used to do the actual communication and update.
- Returns
The original module, with Bagua related environments initialized.
Note
To exclude some layers from communication, first look up the keys of their parameters and buffers in the module's state_dict (the keys have the format "{module_name}.{param_name}"), then assign the list of keys to your_module._bagua_params_and_buffers_to_ignore.
Examples:
>>> model = torch.nn.Sequential(
...     torch.nn.Linear(D_in, H),
...     torch.nn.ReLU(),
...     torch.nn.Linear(H, D_out),
... )
>>> optimizer = torch.optim.SGD(
...     model.parameters(),
...     lr=0.01,
...     momentum=0.9
... )
>>> model = model.with_bagua(
...     [optimizer],
...     GradientAllReduce()
... )
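The ignore list described in the note can be assembled from state_dict keys. This is a minimal sketch that works on key strings only, so it runs without torch. `keys_for_modules` is a hypothetical helper. The example keys are the ones nn.Sequential produces for the model above: index 0 is the first Linear, index 1 is the ReLU (which has no parameters), and index 2 is the second Linear.

```python
from typing import Iterable, List


def keys_for_modules(state_dict_keys: Iterable[str], module_names: Iterable[str]) -> List[str]:
    """Select keys of the form "{module_name}.{param_name}" whose module is listed."""
    # Appending "." keeps module "1" from also matching "10.weight".
    prefixes = tuple(f"{name}." for name in module_names)
    return [key for key in state_dict_keys if key.startswith(prefixes)]


# With a real model these would come from model.state_dict().keys().
keys = ["0.weight", "0.bias", "2.weight", "2.bias"]
ignore = keys_for_modules(keys, ["0"])
print(ignore)  # ['0.weight', '0.bias']
# Then, before calling with_bagua:
# model._bagua_params_and_buffers_to_ignore = ignore
```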
- class bagua.torch_api.BaguaTensor
This class patches torch.Tensor with additional methods.
- is_bagua_tensor(self)
- Return type
bool
- ensure_bagua_tensor(self, name=None)
Convert a PyTorch tensor or parameter to a Bagua tensor in place and return it. A Bagua tensor is required to use Bagua's communication algorithms.
- Parameters
name (Optional[str]) – the unique name of the tensor
- Returns
The original tensor with Bagua tensor attributes initialized.
- to_bagua_tensor(self, name=None)
Create a new Bagua tensor from a PyTorch tensor or parameter and return it. The original tensor is not changed. A Bagua tensor is required to use Bagua’s communication algorithms.
- Parameters
name (Optional[str]) – the unique name of the tensor
- Returns
The new Bagua tensor sharing the same storage with the original tensor.
- bagua_backend_tensor(self)
- Returns
The raw Bagua backend tensor.
- Return type
bagua_core.BaguaTensorPy
- bagua_ensure_grad(self)
Return the gradient of the current parameter, creating a zero gradient tensor first if none exists.
- Return type
torch.Tensor
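The "create a zero gradient if missing" behavior follows a lazy-initialization pattern, sketched here with plain lists. `SketchParam` and `ensure_grad_sketch` are hypothetical. With real tensors, the zero gradient would typically come from torch.zeros_like(param).

```python
from typing import List, Optional


class SketchParam:
    """Hypothetical parameter holding data and an optional gradient (plain lists)."""

    def __init__(self, data: List[float]) -> None:
        self.data = data
        self.grad: Optional[List[float]] = None


def ensure_grad_sketch(param: SketchParam) -> List[float]:
    """Return param.grad, first creating a zero gradient of matching size if it is missing."""
    if param.grad is None:
        param.grad = [0.0] * len(param.data)
    return param.grad


p = SketchParam([1.0, 2.0, 3.0])
g = ensure_grad_sketch(p)
print(g)                           # [0.0, 0.0, 0.0]
print(ensure_grad_sketch(p) is g)  # True: an existing gradient is reused
```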
- bagua_mark_communication_ready(self)
Mark a Bagua tensor as ready for execution of its scheduled operations.
- bagua_mark_communication_ready_without_synchronization(self)
Mark a Bagua tensor as ready immediately, without CUDA event synchronization.
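Readiness-driven scheduling can be pictured as a group of tensors whose communication starts only once every member has been marked ready. This is a simplified, hypothetical model of that idea, not Bagua's scheduler. `BucketSketch` and its `launch` callback are assumptions, and the CUDA-event synchronization that separates the two methods above is not modeled.

```python
from typing import Callable, List, Set


class BucketSketch:
    """Hypothetical model: launch a bucket's communication once all its tensors are ready."""

    def __init__(self, tensor_names: List[str], launch: Callable[[List[str]], None]) -> None:
        self.tensor_names = list(tensor_names)
        self.ready: Set[str] = set()
        self.launch = launch

    def mark_ready(self, name: str) -> None:
        """Record that one tensor is ready; call `launch` when the whole bucket is ready."""
        if name not in self.tensor_names:
            raise KeyError(name)
        self.ready.add(name)
        if len(self.ready) == len(self.tensor_names):
            self.launch(list(self.tensor_names))
            self.ready.clear()  # reset so the bucket can be reused next iteration


launched: List[List[str]] = []
bucket = BucketSketch(["fc2.weight", "fc2.bias"], launched.append)
bucket.mark_ready("fc2.bias")  # tensors may become ready in any order
print(launched)                # []
bucket.mark_ready("fc2.weight")
print(launched)                # [['fc2.weight', 'fc2.bias']]
```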
- bagua_set_storage(self, storage, storage_offset=0)
Sets the underlying storage using an existing torch.Storage.
- Parameters
storage (torch.Storage) – the storage to use
storage_offset (int) – the offset in the storage
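The storage/offset idea, where a tensor is a window into a larger buffer and writes through the window land in that buffer, can be shown with the standard library's array and memoryview. This sketch is only an analogy for torch.Storage. `tensor_view` is a hypothetical helper. One plausible use is packing several tensors into a single contiguous buffer, but that motivation is an assumption here.

```python
from array import array

# One flat, contiguous float32 buffer (stand-in for a torch.Storage).
flat = array("f", [0.0] * 5)
view = memoryview(flat)


def tensor_view(storage: memoryview, storage_offset: int, numel: int) -> memoryview:
    """Return a window of `numel` elements starting at `storage_offset`, sharing memory."""
    return storage[storage_offset:storage_offset + numel]


a = tensor_view(view, 0, 2)  # elements 0..1 of the buffer
b = tensor_view(view, 2, 3)  # elements 2..4 of the buffer
a[0] = 1.5
b[2] = 4.0
print(flat.tolist())  # [1.5, 0.0, 0.0, 0.0, 4.0]
```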
- bagua.torch_api.get_rank()
Get the rank of the current process in the process group.
Rank is a unique identifier assigned to each process within a distributed process group. Ranks are always consecutive integers ranging from 0 to world_size - 1.
- Returns
The rank of the current process.
- bagua.torch_api.get_world_size()
Get the number of processes in the current process group.
- Returns
The world size of the process group.
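A common use of rank and world size is to give each process a disjoint share of the data. This sketch uses a strided split over plain Python lists. `shard_for_rank` is a hypothetical helper. In real code, the rank and world size would come from get_rank() and get_world_size().

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_for_rank(items: Sequence[T], rank: int, world_size: int) -> List[T]:
    """Strided split: rank r takes items r, r + world_size, r + 2 * world_size, ..."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must lie in [0, world_size)")
    return list(items[rank::world_size])


samples = list(range(10))
world_size = 4
shards = [shard_for_rank(samples, r, world_size) for r in range(world_size)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```

Because ranks run from 0 to world_size - 1 without gaps, the shards together cover every item exactly once.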
- bagua.torch_api.get_local_rank()
Get the local rank of the current process within its node.
Local rank is a unique identifier assigned to each process within a node. Local ranks are always consecutive integers ranging from 0 to local_size - 1.
- Returns
The local rank of the current process.
- bagua.torch_api.get_local_size()
Get the number of processes in the node.
- Returns
The local size of the node.
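The global rank, local rank, and local size are related by simple arithmetic, provided ranks are assigned node by node and every node runs the same number of processes. That layout is an assumption of this sketch, not something the functions above guarantee. `split_rank` is a hypothetical helper.

```python
from typing import Tuple


def split_rank(rank: int, local_size: int) -> Tuple[int, int]:
    """Return (node_index, local_rank), assuming ranks are assigned node by node."""
    return divmod(rank, local_size)


# Two nodes with four processes each: world_size = 8, local_size = 4.
for rank in range(8):
    node, local_rank = split_rank(rank, local_size=4)
    print(f"rank {rank}: node {node}, local rank {local_rank}")
```

A common pattern is to pin each process to the GPU whose index equals its local rank, so the processes on one node do not share a device.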