Bagua

This website contains the Bagua API documentation. See the tutorials if you need step-by-step instructions on how to use Bagua.

bagua

Bagua is a communication library developed by Kuaishou Technology and DS3 Lab for deep learning.

See the tutorials for Bagua’s rationale and benchmarks.

Subpackages

bagua.torch_api

The Bagua communication library PyTorch interface.

Subpackages
bagua.torch_api.algorithms
Submodules
bagua.torch_api.algorithms.base
Module Contents
class bagua.torch_api.algorithms.base.Algorithm

This is the base class that all Bagua algorithms inherit.

It provides methods that can be overridden to implement different kinds of distributed algorithms.

init_backward_hook(self, bagua_module)

Given a BaguaModule, return a hook function that will be executed when each parameter’s gradient computation completes.

Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) – A PyTorch module initialized by the with_bagua(...) method.

Returns

A function that takes the name of a parameter (as in torch.nn.Module.named_parameters()) and the parameter itself.

init_forward_pre_hook(self, bagua_module)

Given a BaguaModule, return a hook function that will be executed before the forward process.

Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) – A PyTorch module initialized by the with_bagua(...) method.

Returns

A function that takes the model’s input.

init_operations(self, bagua_module, bucket)

Given a BaguaModule and a Bagua bucket, register operations to be executed on the bucket.

Parameters
  • bagua_module (bagua.torch_api.distributed.BaguaModule) – A PyTorch module initialized by the with_bagua(...) method.

  • bucket (bagua.torch_api.bucket.BaguaBucket) – The bucket to register operations on.

init_post_backward_hook(self, bagua_module)

Given a BaguaModule, return a hook function that will be executed when the backward pass is done.

Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) – A PyTorch module initialized by the with_bagua(...) method.

Returns

A function that takes no argument.

init_post_optimizer_step_hook(self, bagua_module)

Given a BaguaModule, return a hook function that will be executed when the optimizer.step() is done.

Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) – A PyTorch module initialized by the with_bagua(...) method.

Returns

A function that takes the optimizer whose step() method has just been called.

init_tensors(self, bagua_module)

Given a BaguaModule, return Bagua tensors to be used in Bagua for later operations.

Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) – A PyTorch module initialized by the with_bagua(...) method.

Returns

A list of Bagua tensors for communication.

Return type

List[bagua.torch_api.tensor.BaguaTensor]

need_reset(self)
Returns

True if all initialization methods of the current algorithm should be called again. This is useful for algorithms that have multiple stages, where each stage needs different initializations.

Return type

bool

tensors_to_buckets(self, tensors)

Given the bucketing suggestion from Bagua, return the actual Bagua buckets. The default implementation follows the suggestion to do the bucketing.

Parameters

tensors (List[List[bagua.torch_api.tensor.BaguaTensor]]) – Bagua tensors grouped in different lists, representing Bagua’s suggestion on how to bucket the tensors.

Returns

A list of Bagua buckets.

Return type

List[bagua.torch_api.bucket.BaguaBucket]
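
As a rough illustration of how these methods fit together, the sketch below subclasses Algorithm and overrides only init_operations, relying on the base class defaults for the other hooks. The class name SummedAllReduceAlgorithm is hypothetical, a real algorithm may need to override more methods (for example init_tensors or need_reset), and model and optimizer are assumed to be defined as in the with_bagua example later in this document:

>>> from bagua.torch_api.algorithms import Algorithm
>>>
>>> class SummedAllReduceAlgorithm(Algorithm):
...     """Hypothetical algorithm that sums (rather than averages) gradients."""
...     def init_operations(self, bagua_module, bucket):
...         # Drop whatever was registered before, then append a plain
...         # centralized synchronous allreduce that sums the bucket.
...         bucket.clear_ops()
...         bucket.append_centralized_synchronous_op(
...             hierarchical=False, average=False
...         )
>>>
>>> model = model.with_bagua([optimizer], SummedAllReduceAlgorithm())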

bagua.torch_api.algorithms.bytegrad
Module Contents
class bagua.torch_api.algorithms.bytegrad.ByteGradAlgorithm(average=True)

Bases: bagua.torch_api.algorithms.Algorithm

Create an instance of the ByteGrad algorithm.

Parameters

average (bool) – If True, the gradients on each worker are averaged. Otherwise, they are summed.

init_operations(self, bagua_module, bucket)
Parameters
  • bagua_module (bagua.torch_api.distributed.BaguaModule) –

  • bucket (bagua.torch_api.bucket.BaguaBucket) –

tensors_to_buckets(self, tensors)

Given the bucketing suggestion from Bagua, return the actual Bagua buckets. The default implementation follows the suggestion to do the bucketing.

Parameters

tensors (List[List[bagua.torch_api.tensor.BaguaTensor]]) – Bagua tensors grouped in different lists, representing Bagua’s suggestion on how to bucket the tensors.

Returns

A list of Bagua buckets.

Return type

List[bagua.torch_api.bucket.BaguaBucket]
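
A short usage sketch (assuming bagua.init_process_group() has been called and model and optimizer are already defined, as in the with_bagua example later in this document):

>>> from bagua.torch_api.algorithms.bytegrad import ByteGradAlgorithm
>>> model = model.with_bagua([optimizer], ByteGradAlgorithm(average=True))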

bagua.torch_api.algorithms.decentralized
Module Contents
class bagua.torch_api.algorithms.decentralized.DecentralizedAlgorithm(peer_selection_mode='all', compression=None, communication_interval=1)

Bases: bagua.torch_api.algorithms.Algorithm

Create an instance of the Decentralized algorithm.

Parameters
  • peer_selection_mode (str) – Can be “all” or “shift_one”. “all” means all workers’ weights are averaged in each communication step. “shift_one” means each worker selects a different peer to average weights with in each communication step.

  • compression (str) – Not supported yet.

  • communication_interval (int) – Number of iterations between two communication steps.

init_backward_hook(self, bagua_module)
Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) –

init_forward_pre_hook(self, bagua_module)
Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) –

init_operations(self, bagua_module, bucket)
Parameters
  • bagua_module (bagua.torch_api.distributed.BaguaModule) –

  • bucket (bagua.torch_api.bucket.BaguaBucket) –

init_post_backward_hook(self, bagua_module)
Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) –

init_tensors(self, bagua_module)
Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) –

Return type

List[bagua.torch_api.tensor.BaguaTensor]
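
A short usage sketch under the same assumptions as the ByteGrad example above; here each worker averages weights with a shifting peer every two iterations:

>>> from bagua.torch_api.algorithms.decentralized import DecentralizedAlgorithm
>>> algorithm = DecentralizedAlgorithm(
...     peer_selection_mode="shift_one",
...     communication_interval=2,
... )
>>> model = model.with_bagua([optimizer], algorithm)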

bagua.torch_api.algorithms.gradient_allreduce
Module Contents
class bagua.torch_api.algorithms.gradient_allreduce.GradientAllReduceAlgorithm(hierarchical=False, average=True)

Bases: bagua.torch_api.algorithms.Algorithm

Create an instance of the GradientAllReduce algorithm.

Parameters
  • hierarchical (bool) – Enable hierarchical communication.

  • average (bool) – If True, the gradients on each worker are averaged. Otherwise, they are summed.

init_operations(self, bagua_module, bucket)
Parameters
  • bagua_module (bagua.torch_api.distributed.BaguaModule) –

  • bucket (bagua.torch_api.bucket.BaguaBucket) –

bagua.torch_api.algorithms.q_adam
Module Contents
class bagua.torch_api.algorithms.q_adam.QAdamAlgorithm(q_adam_optimizer, hierarchical=True)

Bases: bagua.torch_api.algorithms.Algorithm

Create an instance of the QAdam algorithm.

Parameters
  • q_adam_optimizer – A QAdamOptimizer initialized with model parameters.

  • hierarchical – Enable hierarchical communication.

init_backward_hook(self, bagua_module)
Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) –

init_operations(self, bagua_module, bucket)
Parameters
  • bagua_module (bagua.torch_api.distributed.BaguaModule) –

  • bucket (bagua.torch_api.bucket.BaguaBucket) –

init_tensors(self, bagua_module)
Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) –

need_reset(self)
tensors_to_buckets(self, tensors)
Parameters

tensors (List[List[bagua.torch_api.tensor.BaguaTensor]]) –

Return type

List[bagua.torch_api.bucket.BaguaBucket]

class bagua.torch_api.algorithms.q_adam.QAdamOptimizer(params, lr=0.001, warmup_steps=100, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0)

Bases: torch.optim.optimizer.Optimizer

Create a dedicated optimizer used for the QAdam algorithm.

Parameters
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr – learning rate

  • warmup_steps – number of steps to warm up at the beginning of training.

  • betas – coefficients used for computing running averages of gradient and its square

  • eps – term added to the denominator to improve numerical stability

  • weight_decay – weight decay (L2 penalty)

step(self, closure=None)
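
A usage sketch, assuming a model is defined and bagua.init_process_group() has been called; QAdamOptimizer replaces the regular Adam optimizer and is then passed to QAdamAlgorithm:

>>> from bagua.torch_api.algorithms.q_adam import QAdamAlgorithm, QAdamOptimizer
>>> optimizer = QAdamOptimizer(model.parameters(), lr=0.001, warmup_steps=100)
>>> algorithm = QAdamAlgorithm(optimizer, hierarchical=True)
>>> model = model.with_bagua([optimizer], algorithm)
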
Package Contents
class bagua.torch_api.algorithms.Algorithm

This is the base class that all Bagua algorithms inherit.

It provides methods that can be overridden to implement different kinds of distributed algorithms.

init_backward_hook(self, bagua_module)

Given a BaguaModule, return a hook function that will be executed when each parameter’s gradient computation completes.

Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) – A PyTorch module initialized by the with_bagua(...) method.

Returns

A function that takes the name of a parameter (as in torch.nn.Module.named_parameters()) and the parameter itself.

init_forward_pre_hook(self, bagua_module)

Given a BaguaModule, return a hook function that will be executed before the forward process.

Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) – A PyTorch module initialized by the with_bagua(...) method.

Returns

A function that takes the model’s input.

init_operations(self, bagua_module, bucket)

Given a BaguaModule and a Bagua bucket, register operations to be executed on the bucket.

Parameters
  • bagua_module (bagua.torch_api.distributed.BaguaModule) – A PyTorch module initialized by the with_bagua(...) method.

  • bucket (bagua.torch_api.bucket.BaguaBucket) – The bucket to register operations on.

init_post_backward_hook(self, bagua_module)

Given a BaguaModule, return a hook function that will be executed when the backward pass is done.

Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) – A PyTorch module initialized by the with_bagua(...) method.

Returns

A function that takes no argument.

init_post_optimizer_step_hook(self, bagua_module)

Given a BaguaModule, return a hook function that will be executed when the optimizer.step() is done.

Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) – A PyTorch module initialized by the with_bagua(...) method.

Returns

A function that takes the optimizer whose step() method has just been called.

init_tensors(self, bagua_module)

Given a BaguaModule, return Bagua tensors to be used in Bagua for later operations.

Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) – A PyTorch module initialized by the with_bagua(...) method.

Returns

A list of Bagua tensors for communication.

Return type

List[bagua.torch_api.tensor.BaguaTensor]

need_reset(self)
Returns

True if all initialization methods of the current algorithm should be called again. This is useful for algorithms that have multiple stages, where each stage needs different initializations.

Return type

bool

tensors_to_buckets(self, tensors)

Given the bucketing suggestion from Bagua, return the actual Bagua buckets. The default implementation follows the suggestion to do the bucketing.

Parameters

tensors (List[List[bagua.torch_api.tensor.BaguaTensor]]) – Bagua tensors grouped in different lists, representing Bagua’s suggestion on how to bucket the tensors.

Returns

A list of Bagua buckets.

Return type

List[bagua.torch_api.bucket.BaguaBucket]

bagua.torch_api.contrib
Submodules
bagua.torch_api.contrib.fused_optimizer
Module Contents
class bagua.torch_api.contrib.fused_optimizer.FusedOptimizer(optimizer, do_flatten=False)

Bases: torch.optim.Optimizer

Convert any optimizer into a fused optimizer.

This fused optimizer fuses multiple module parameter update kernel launches into one or a few, by flattening parameter tensors into one or more contiguous buckets.

It can be used in conjunction with bagua.torch_api.bagua_init, in which case Bagua will do the fusion automatically. Otherwise, you need to explicitly pass do_flatten=True.

Parameters
  • optimizer (torch.optim.Optimizer) – Any PyTorch optimizer.

  • do_flatten (bool) – Whether to flatten the parameters. Default: False.

Returns

Fused optimizer.

Example:

To use in conjunction with bagua.torch_api.bagua_init:

>>> optimizer = torch.optim.Adadelta(model.parameters(), ....)
>>> optimizer = bagua.torch_api.contrib.FusedOptimizer(optimizer)
>>> model = model.with_bagua([optimizer], GradientAllReduceAlgorithm())

To use alone or with torch.nn.parallel.DistributedDataParallel, set do_flatten to True:

>>> optimizer = torch.optim.Adadelta(model.parameters(), ....)
>>> optimizer = bagua.torch_api.contrib.FusedOptimizer(optimizer, do_flatten=True)
step(self, closure=None)

Performs a single optimization step (parameter update).

Parameters

closure (callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.

Note

Unless otherwise specified, this function should not modify the .grad field of the parameters.

bagua.torch_api.contrib.load_balancing_data_loader
Module Contents
class bagua.torch_api.contrib.load_balancing_data_loader.LoadBalancingDistributedBatchSampler(sampler, batch_fn, drop_last=False)

Bases: torch.utils.data.sampler.Sampler

Wraps a load-balancing sampler to yield variable-sized mini-batches.

Parameters
  • sampler (LoadBalancingDistributedSampler) – Load balance sampler.

  • batch_fn (Callable) – Callable to yield mini-batch indices.

  • drop_last (bool) – If True, the sampler will drop the last few batches that exceed the smallest number of batches among replicas; otherwise, the number of batches on each replica will be padded to be the same.

batch_fn has the signature def batch_fn(indices: List[int]) -> List[List[int]].
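
For instance, a batch_fn that simply chunks the yielded indices into fixed-size mini-batches could look like the following (the batch size of 32 is arbitrary):

>>> def batch_fn(indices):
...     batch_size = 32
...     return [
...         indices[i : i + batch_size]
...         for i in range(0, len(indices), batch_size)
...     ]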

Example:

>>> from bagua.torch_api.contrib import LoadBalancingDistributedSampler, \
...     LoadBalancingDistributedBatchSampler
>>>
>>> sampler = LoadBalancingDistributedSampler(dataset, complexity_fn=complexity_fn)
>>> batch_sampler = LoadBalancingDistributedBatchSampler(sampler, batch_fn=batch_fn)
>>> loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)
>>>
>>> for epoch in range(start_epoch, n_epochs):
...     batch_sampler.set_epoch(epoch)
...     train(loader)
set_epoch(self, epoch)

Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters

epoch (int) – Epoch number.

Return type

None

class bagua.torch_api.contrib.load_balancing_data_loader.LoadBalancingDistributedSampler(dataset, complexity_fn, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False, random_level=0)

Bases: torch.utils.data.sampler.Sampler

Sampler that restricts data loading to a subset of the dataset.

This sampler uses a complexity_fn to calculate each sample’s computational complexity and makes each batch have similar computational complexity.

This is useful in scenarios like speech and NLP, where samples have variable lengths and distributed training suffers from the straggler problem.

The usage is similar to torch.utils.data.DistributedSampler, where each process loads a subset of the original dataset that is exclusive to it.

Note

Dataset is assumed to be of constant size.

Parameters
  • dataset – Dataset used for sampling.

  • complexity_fn (Callable) – A function whose input is a sample and output is an integer as a measure of the computational complexity of the sample.

  • num_replicas (int, optional) – Number of processes participating in distributed training. By default, world_size is retrieved from the current distributed group.

  • rank (int, optional) – Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.

  • shuffle (bool, optional) – If True (default), sampler will shuffle the indices.

  • seed (int, optional) – random seed used to shuffle the sampler if shuffle=True. This number should be identical across all processes in the distributed group. Default: 0.

  • drop_last (bool, optional) – if True, then the sampler will drop the tail of the data to make it evenly divisible across the number of replicas. If False, the sampler will add extra indices to make the data evenly divisible across the replicas. Default: False.

  • random_level (float, optional) – A float between 0 and 1 that controls the extent of load balancing. 0 means the best load balance, while 1 means the opposite.

Warning

In distributed mode, calling the set_epoch method at the beginning of each epoch before creating the DataLoader iterator is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will always be used.

Example:

Define your complexity_fn, which accepts a dataset sample as its input and produces an integer as the sample’s computational complexity.

>>> dataset = torch.utils.data.TensorDataset(torch.randn(n, 2), torch.randperm(n))
>>> complexity_fn = lambda x: x[1]

Below is the usage of LoadBalancingDistributedSampler and DataLoader:

>>> sampler = bagua.torch_api.contrib.LoadBalancingDistributedSampler(
...     dataset,
...     complexity_fn=complexity_fn) if is_distributed else None
>>> loader = torch.utils.data.DataLoader(dataset,
...     shuffle=(sampler is None),
...     sampler=sampler)
>>>
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)
set_epoch(self, epoch)

Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters

epoch (int) – Epoch number.

Return type

None

Package Contents
class bagua.torch_api.contrib.FusedOptimizer(optimizer, do_flatten=False)

Bases: torch.optim.Optimizer

Convert any optimizer into a fused optimizer.

This fused optimizer fuses multiple module parameter update kernel launches into one or a few, by flattening parameter tensors into one or more contiguous buckets.

It can be used in conjunction with bagua.torch_api.bagua_init, in which case Bagua will do the fusion automatically. Otherwise, you need to explicitly pass do_flatten=True.

Parameters
  • optimizer (torch.optim.Optimizer) – Any PyTorch optimizer.

  • do_flatten (bool) – Whether to flatten the parameters. Default: False.

Returns

Fused optimizer.

Example:

To use in conjunction with bagua.torch_api.bagua_init:

>>> optimizer = torch.optim.Adadelta(model.parameters(), ....)
>>> optimizer = bagua.torch_api.contrib.FusedOptimizer(optimizer)
>>> model = model.with_bagua([optimizer], GradientAllReduceAlgorithm())

To use alone or with torch.nn.parallel.DistributedDataParallel, set do_flatten to True:

>>> optimizer = torch.optim.Adadelta(model.parameters(), ....)
>>> optimizer = bagua.torch_api.contrib.FusedOptimizer(optimizer, do_flatten=True)
step(self, closure=None)

Performs a single optimization step (parameter update).

Parameters

closure (callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.

Note

Unless otherwise specified, this function should not modify the .grad field of the parameters.

class bagua.torch_api.contrib.LoadBalancingDistributedBatchSampler(sampler, batch_fn, drop_last=False)

Bases: torch.utils.data.sampler.Sampler

Wraps a load-balancing sampler to yield variable-sized mini-batches.

Parameters
  • sampler (LoadBalancingDistributedSampler) – Load balance sampler.

  • batch_fn (Callable) – Callable to yield mini-batch indices.

  • drop_last (bool) – If True, the sampler will drop the last few batches that exceed the smallest number of batches among replicas; otherwise, the number of batches on each replica will be padded to be the same.

batch_fn has the signature def batch_fn(indices: List[int]) -> List[List[int]].

Example:

>>> from bagua.torch_api.contrib import LoadBalancingDistributedSampler, \
...     LoadBalancingDistributedBatchSampler
>>>
>>> sampler = LoadBalancingDistributedSampler(dataset, complexity_fn=complexity_fn)
>>> batch_sampler = LoadBalancingDistributedBatchSampler(sampler, batch_fn=batch_fn)
>>> loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)
>>>
>>> for epoch in range(start_epoch, n_epochs):
...     batch_sampler.set_epoch(epoch)
...     train(loader)
set_epoch(self, epoch)

Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters

epoch (int) – Epoch number.

Return type

None

class bagua.torch_api.contrib.LoadBalancingDistributedSampler(dataset, complexity_fn, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False, random_level=0)

Bases: torch.utils.data.sampler.Sampler

Sampler that restricts data loading to a subset of the dataset.

This sampler uses a complexity_fn to calculate each sample’s computational complexity and makes each batch have similar computational complexity.

This is useful in scenarios like speech and NLP, where samples have variable lengths and distributed training suffers from the straggler problem.

The usage is similar to torch.utils.data.DistributedSampler, where each process loads a subset of the original dataset that is exclusive to it.

Note

Dataset is assumed to be of constant size.

Parameters
  • dataset – Dataset used for sampling.

  • complexity_fn (Callable) – A function whose input is a sample and output is an integer as a measure of the computational complexity of the sample.

  • num_replicas (int, optional) – Number of processes participating in distributed training. By default, world_size is retrieved from the current distributed group.

  • rank (int, optional) – Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.

  • shuffle (bool, optional) – If True (default), sampler will shuffle the indices.

  • seed (int, optional) – random seed used to shuffle the sampler if shuffle=True. This number should be identical across all processes in the distributed group. Default: 0.

  • drop_last (bool, optional) – if True, then the sampler will drop the tail of the data to make it evenly divisible across the number of replicas. If False, the sampler will add extra indices to make the data evenly divisible across the replicas. Default: False.

  • random_level (float, optional) – A float between 0 and 1 that controls the extent of load balancing. 0 means the best load balance, while 1 means the opposite.

Warning

In distributed mode, calling the set_epoch method at the beginning of each epoch before creating the DataLoader iterator is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will always be used.

Example:

Define your complexity_fn, which accepts a dataset sample as its input and produces an integer as the sample’s computational complexity.

>>> dataset = torch.utils.data.TensorDataset(torch.randn(n, 2), torch.randperm(n))
>>> complexity_fn = lambda x: x[1]

Below is the usage of LoadBalancingDistributedSampler and DataLoader:

>>> sampler = bagua.torch_api.contrib.LoadBalancingDistributedSampler(
...     dataset,
...     complexity_fn=complexity_fn) if is_distributed else None
>>> loader = torch.utils.data.DataLoader(dataset,
...     shuffle=(sampler is None),
...     sampler=sampler)
>>>
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)
set_epoch(self, epoch)

Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters

epoch (int) – Epoch number.

Return type

None

Submodules
bagua.torch_api.bucket
Module Contents
class bagua.torch_api.bucket.BaguaBucket(tensors, name, flatten, alignment=1)

Create a Bagua bucket with a list of Bagua tensors.

Parameters
  • tensors – A list of Bagua tensors to be put in the bucket.

  • name – The unique name of the bucket.

  • flatten – If True, flatten the input tensors so that they are contiguous in memory.

  • alignment – If alignment > 1, Bagua will add a padding tensor to the bucket so that the total number of elements in the bucket is a multiple of the given alignment.

name

The bucket’s name.

tensors

The tensors contained within the bucket.

append_centralized_synchronous_op(self, hierarchical=False, average=True, scattergather=False, compression=None)

Append a centralized synchronous operation to a bucket. It will sum or average the tensors in the bucket for all workers.

The operations will be executed by the Bagua backend in the order they are appended when all the tensors within the bucket are marked ready.

Parameters
  • hierarchical (bool) – Enable hierarchical communication, which means the GPUs on the same machine will communicate with each other first; after that, machines do inter-node communication. This can boost performance when the inter-node communication cost is high.

  • average (bool) – If True, the gradients on each worker are averaged. Otherwise, they are summed.

  • scattergather (bool) – If True, the communication between workers is done with scatter-gather instead of allreduce. This is required for using compression.

  • compression (Optional[str]) – If not None, the tensors will be compressed for communication. Currently “MinMaxUInt8” is supported.

Returns

The bucket itself.

Return type

BaguaBucket

append_decentralized_synchronous_op(self, hierarchical=True, peer_selection_mode='all', communication_interval=1)

Append a decentralized synchronous operation to a bucket. It will do gossip-style model averaging among workers.

The operations will be executed by the Bagua backend in the order they are appended when all the tensors within the bucket are marked ready.

Parameters
  • hierarchical (bool) – Enable hierarchical communication, which means the GPUs on the same machine will communicate with each other first; after that, machines do inter-node communication. This can boost performance when the inter-node communication cost is high.

  • peer_selection_mode (str) – Can be “all” or “shift_one”. “all” means all workers’ weights are averaged in each communication step. “shift_one” means each worker selects a different peer to average weights with in each communication step.

  • communication_interval (int) – Number of iterations between two communication steps.

Returns

The bucket itself.

Return type

BaguaBucket

append_python_op(self, python_function)

Append a Python operation to a bucket. A Python operation is a Python function that takes the bucket’s name and returns None. It can do arbitrary things within the function body.

The operations will be executed by the Bagua backend in the order they are appended when all the tensors within the bucket are marked ready.

Parameters

python_function (Callable[[str], None]) – The Python operation function.

Returns

The bucket itself.

Return type

BaguaBucket

bytes(self)

Returns the total number of bytes occupied by the bucket.

Returns

The number of bytes occupied by the bucket.

Return type

int

check_flatten(self)
Returns

True if the bucket’s tensors are contiguous in memory.

Return type

bool

clear_ops(self)

Clear the previously appended operations.

Return type

BaguaBucket
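
Buckets are normally created for you by an Algorithm’s tensors_to_buckets method, but the sketch below shows the low-level flow with this class directly. It assumes a model that has already been initialized with with_bagua(...); the tensor and bucket names are arbitrary:

>>> from bagua.torch_api.bucket import BaguaBucket
>>>
>>> # Register each parameter as a Bagua tensor (in place), then group them.
>>> tensors = [
...     param.ensure_bagua_tensor(name, module_name=model.bagua_module_name)
...     for name, param in model.named_parameters()
... ]
>>> bucket = BaguaBucket(tensors, name="my_bucket", flatten=True, alignment=1)
>>> # Average the bucket across all workers once its tensors are marked ready.
>>> bucket.append_centralized_synchronous_op(hierarchical=False, average=True)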

bagua.torch_api.communication
Module Contents
bagua.torch_api.communication.allreduce(tensor, op=dist.ReduceOp.SUM, comm=None)

Reduces the tensor data across all machines in such a way that all get the final result. After the call, tensor is going to be bitwise identical in all processes.

Parameters
  • tensor (torch.Tensor) – Input and output of the collective. The function operates in-place.

  • op (optional) – one of the values from torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions.

  • comm (B.BaguaSingleCommunicatorPy, optional) – The bagua communicator to work on. If None the global bagua communicator will be used. Defaults to None.

Examples

>>> from bagua.torch_api import allreduce
>>> # All tensors below are of torch.int64 type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> allreduce(tensor)
>>> tensor
tensor([4, 6]) # Rank 0
tensor([4, 6]) # Rank 1
>>> # All tensors below are of torch.cfloat type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
>>> tensor
tensor([1.+1.j, 2.+2.j]) # Rank 0
tensor([3.+3.j, 4.+4.j]) # Rank 1
>>> allreduce(tensor)
>>> tensor
tensor([4.+4.j, 6.+6.j]) # Rank 0
tensor([4.+4.j, 6.+6.j]) # Rank 1
bagua.torch_api.communication.broadcast(tensor, root=0, comm=None)

Broadcasts the tensor to the whole communicator.

tensor must have the same number of elements in all processes participating in the collective.

Parameters
  • tensor (torch.Tensor) – Data to be sent if root is the rank of current process, and tensor to be used to save received data otherwise.

  • root (int, optional) – Source rank. Defaults to 0.

  • comm (B.BaguaSingleCommunicatorPy, optional) – The bagua communicator to work on. If None, the global bagua communicator will be used. Defaults to None.
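
A minimal sketch, assuming init_process_group() has been called, rank holds the current rank, and tensors live on the device expected by the backend:

>>> from bagua.torch_api import broadcast
>>> tensor = torch.zeros(2)
>>> if rank == 0:
...     tensor += 1
>>> broadcast(tensor, root=0)
>>> tensor   # identical on every rank after the call
tensor([1., 1.])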

bagua.torch_api.communication.get_backend(model_name)
Parameters

model_name (str) –

bagua.torch_api.communication.init_process_group()

Initializes the PyTorch builtin distributed process group. This will also initialize the distributed package, and it should be executed before all other Bagua APIs.

Raises

RepeatedInitializationError – If you run this function repeatedly

Examples:
>>> import bagua.torch_api as bagua
>>> bagua.init_process_group()
>>> model = torch.nn.Sequential(
...    torch.nn.Linear(D_in, H),
...    torch.nn.ReLU(),
...    torch.nn.Linear(H, D_out),
...    )
>>> optimizer = torch.optim.SGD(
...    model.parameters(),
...    lr=0.01,
...    momentum=0.9
...    )
>>> model, optimizer = bagua_init(model, optimizer)
bagua.torch_api.communication.reduce(tensor, dst, op=dist.ReduceOp.SUM, comm=None)

Reduces the tensor across all processes.

Only the process with rank dst is going to receive the final result.

Parameters
  • tensor (torch.Tensor) – Input and output of the collective. The function operates in-place.

  • dst (int) – Destination rank

  • op (optional) – one of the values from torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions.

  • comm (B.BaguaSingleCommunicatorPy, optional) – The bagua communicator to work on. If None the global bagua communicator will be used. Defaults to None.
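
A minimal sketch under the same assumptions as the allreduce example above; only rank dst is guaranteed to hold the reduced result afterwards:

>>> from bagua.torch_api.communication import reduce
>>> tensor = torch.ones(2) * (rank + 1)
>>> reduce(tensor, dst=0)
>>> # With 2 ranks, rank 0 now holds tensor([3., 3.]); the buffers on other
>>> # ranks should not be relied upon after the call.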

bagua.torch_api.distributed
Module Contents
class bagua.torch_api.distributed.BaguaModule

This class patches torch.nn.Module with several methods to enable Bagua functionalities.

Variables
  • bagua_optimizers (List[torch.optim.Optimizer]) – The optimizers passed in by with_bagua(...).

  • bagua_algorithm (bagua.torch_api.algorithms.Algorithm) – The algorithm passed in by with_bagua(...).

  • parameters_to_ignore (List[str]) – The parameter names in "{module_name}.{param_name}" format to ignore when calling self.bagua_build_params().

  • bagua_train_step_counter (int) – Number of iterations in training mode

  • bagua_buckets (List[bagua.torch_api.bucket.BaguaBucket]) – All Bagua buckets in a list.

bagua_build_params(self)

Build a list of (parameter_name, parameter) tuples for all parameters that require gradients and are not in the _bagua_params_and_buffers_to_ignore attribute.

Return type

List[Tuple[str, torch.nn.Parameter]]

with_bagua(self, optimizers, algorithm)

with_bagua enables easy distributed data parallel training on a torch.nn.Module.

Parameters
  • optimizers (List[torch.optim.Optimizer]) – Optimizer(s) used by the module. It can contain one or more PyTorch optimizers.

  • algorithm (bagua.torch_api.algorithms.Algorithm) – Distributed algorithm used to do the actual communication and update.

Returns

The original module, with Bagua related environments initialized.

Return type

BaguaModule

Note

If we want to ignore some layers for communication, we can first check these layers’ corresponding keys in the module’s state_dict (they are in "{module_name}.{param_name}" format), then assign the list of keys to your_module._bagua_params_and_buffers_to_ignore.

Examples:

>>> model = torch.nn.Sequential(
...      torch.nn.Linear(D_in, H),
...      torch.nn.ReLU(),
...      torch.nn.Linear(H, D_out),
...    )
>>> optimizer = torch.optim.SGD(
...      model.parameters(),
...      lr=0.01,
...      momentum=0.9
...    )
>>> model = model.with_bagua(
...      [optimizer],
...      GradientAllReduceAlgorithm()
...    )
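
For the Sequential model above, ignoring the first Linear layer before calling with_bagua could look like the following sketch (the keys come from model.state_dict() and depend on your model):

>>> model._bagua_params_and_buffers_to_ignore = ["0.weight", "0.bias"]
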
bagua.torch_api.env
Module Contents
bagua.torch_api.env.get_default_bucket_size()

Get default communication bucket byte size.

Returns

The default bucket size in bytes.

Return type

int

bagua.torch_api.env.get_local_rank()

Get the local rank of the current process.

Local rank is a unique identifier assigned to each process within a node. Local ranks are always consecutive integers ranging from 0 to local_size - 1.

Returns

The local rank of the current process.

bagua.torch_api.env.get_local_size()

Get the number of processes in the node.

Returns

The local size of the node.

bagua.torch_api.env.get_rank()

Get the rank of the current process within the process group.

Rank is a unique identifier assigned to each process within a distributed process group. Ranks are always consecutive integers ranging from 0 to world_size - 1.

Returns

The rank of the current process.

bagua.torch_api.env.get_world_size()

Get the number of processes in the current process group.

Returns

The world size of the process group.
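
A small sketch combining these helpers, for example to restrict some work to one process per node (assuming init_process_group() has already been called):

>>> import bagua.torch_api as bagua
>>> print(
...     f"rank {bagua.get_rank()}/{bagua.get_world_size()}, "
...     f"local rank {bagua.get_local_rank()}/{bagua.get_local_size()}"
... )
>>> if bagua.get_local_rank() == 0:
...     prepare_node_local_cache()   # hypothetical per-node work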

bagua.torch_api.tensor
Module Contents
class bagua.torch_api.tensor.BaguaTensor

This class patches torch.Tensor with additional methods.

bagua_backend_tensor(self)
Returns

The raw Bagua backend tensor.

Return type

bagua_core.BaguaTensorPy

bagua_ensure_grad(self)

Return the gradient of the current parameter, creating a zero gradient tensor if it does not exist.

Return type

torch.Tensor

bagua_mark_communication_ready(self)

Mark a Bagua tensor ready for scheduled operations execution.

bagua_mark_communication_ready_without_synchronization(self)

Mark a Bagua tensor ready immediately, without CUDA event synchronization.

bagua_set_storage(self, storage, storage_offset=0)

Sets the underlying storage using an existing torch.Storage.

Parameters
  • storage (torch.Storage) – the storage to use

  • storage_offset (int) – the offset in the storage

ensure_bagua_tensor(self, name=None, module_name=None)

Convert a PyTorch tensor or parameter to a Bagua tensor in place and return it. A Bagua tensor is required to use Bagua’s communication algorithms.

Parameters
  • name (Optional[str]) – The unique name of the tensor.

  • module_name (Optional[str]) – The name of the model the tensor belongs to. The model name can be acquired using model.bagua_module_name. This is required to call the bagua_mark_communication_ready related methods.

Returns

The original tensor with Bagua tensor attributes initialized.

is_bagua_tensor(self)
Return type

bool

to_bagua_tensor(self, name=None, module_name=None)

Create a new Bagua tensor from a PyTorch tensor or parameter and return it. The original tensor is not changed. A Bagua tensor is required to use Bagua’s communication algorithms.

Parameters
  • name (Optional[str]) – The unique name of the tensor.

  • module_name (Optional[str]) – The name of the model the tensor belongs to. The model name can be acquired using model.bagua_module_name. This is required to call the bagua_mark_communication_ready related methods.

Returns

The new Bagua tensor sharing the same storage with the original tensor.
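
A minimal sketch of registering a tensor with Bagua, assuming a model initialized by with_bagua(...) so that model.bagua_module_name is available; the tensor name is arbitrary:

>>> t = torch.zeros(100).cuda()
>>> # Create a new Bagua tensor sharing storage with t; t itself is unchanged.
>>> bagua_t = t.to_bagua_tensor(name="my_tensor", module_name=model.bagua_module_name)
>>> bagua_t.is_bagua_tensor()
True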

Package Contents
class bagua.torch_api.BaguaModule

This class patches torch.nn.Module with several methods to enable Bagua functionalities.

Variables
  • bagua_optimizers (List[torch.optim.Optimizer]) – The optimizers passed in by with_bagua(...).

  • bagua_algorithm (bagua.torch_api.algorithms.Algorithm) – The algorithm passed in by with_bagua(...).

  • parameters_to_ignore (List[str]) – The parameter names in "{module_name}.{param_name}" format to ignore when calling self.bagua_build_params().

  • bagua_train_step_counter (int) – Number of iterations in training mode

  • bagua_buckets (List[bagua.torch_api.bucket.BaguaBucket]) – All Bagua buckets in a list.

bagua_build_params(self)

Build a list of (parameter_name, parameter) tuples for all parameters that require gradients and are not in the _bagua_params_and_buffers_to_ignore attribute.

Return type

List[Tuple[str, torch.nn.Parameter]]

with_bagua(self, optimizers, algorithm)

with_bagua enables easy distributed data parallel training on a torch.nn.Module.

Parameters
  • optimizers (List[torch.optim.Optimizer]) – Optimizer(s) used by the module. It can contain one or more PyTorch optimizers.

  • algorithm (bagua.torch_api.algorithms.Algorithm) – Distributed algorithm used to do the actual communication and update.

Returns

The original module, with Bagua related environments initialized.

Return type

BaguaModule

Note

If we want to ignore some layers for communication, we can first check these layers’ corresponding keys in the module’s state_dict (they are in "{module_name}.{param_name}" format), then assign the list of keys to your_module._bagua_params_and_buffers_to_ignore.

Examples:

>>> model = torch.nn.Sequential(
...      torch.nn.Linear(D_in, H),
...      torch.nn.ReLU(),
...      torch.nn.Linear(H, D_out),
...    )
>>> optimizer = torch.optim.SGD(
...      model.parameters(),
...      lr=0.01,
...      momentum=0.9
...    )
>>> model = model.with_bagua(
...      [optimizer],
...      GradientAllReduceAlgorithm()
...    )
class bagua.torch_api.BaguaTensor

This class patches torch.Tensor with additional methods.

bagua_backend_tensor(self)
Returns

The raw Bagua backend tensor.

Return type

bagua_core.BaguaTensorPy

bagua_ensure_grad(self)

Return the gradient of the current parameter, creating a zero gradient tensor if it does not exist.

Return type

torch.Tensor

bagua_mark_communication_ready(self)

Mark a Bagua tensor ready for scheduled operations execution.

bagua_mark_communication_ready_without_synchronization(self)

Mark a Bagua tensor ready immediately, without CUDA event synchronization.

bagua_set_storage(self, storage, storage_offset=0)

Sets the underlying storage using an existing torch.Storage.

Parameters
  • storage (torch.Storage) – the storage to use

  • storage_offset (int) – the offset in the storage

ensure_bagua_tensor(self, name=None, module_name=None)

Convert a PyTorch tensor or parameter to a Bagua tensor in place and return it. A Bagua tensor is required to use Bagua’s communication algorithms.

Parameters
  • name (Optional[str]) – The unique name of the tensor.

  • module_name (Optional[str]) – The name of the model the tensor belongs to. The model name can be acquired using model.bagua_module_name. This is required to call the bagua_mark_communication_ready related methods.

Returns

The original tensor with Bagua tensor attributes initialized.

is_bagua_tensor(self)
Return type

bool

to_bagua_tensor(self, name=None, module_name=None)

Create a new Bagua tensor from a PyTorch tensor or parameter and return it. The original tensor is not changed. A Bagua tensor is required to use Bagua’s communication algorithms.

Parameters
  • name (Optional[str]) – The unique name of the tensor.

  • module_name (Optional[str]) – The name of the model the tensor belongs to. The model name can be acquired using model.bagua_module_name. This is required to call the bagua_mark_communication_ready related methods.

Returns

The new Bagua tensor sharing the same storage with the original tensor.

bagua.torch_api.allreduce(tensor, op=dist.ReduceOp.SUM, comm=None)

Reduces the tensor data across all machines in such a way that all get the final result. After the call, tensor is going to be bitwise identical in all processes.

Parameters
  • tensor (torch.Tensor) – Input and output of the collective. The function operates in-place.

  • op (optional) – one of the values from torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions.

  • comm (B.BaguaSingleCommunicatorPy, optional) – The bagua communicator to work on. If None the global bagua communicator will be used. Defaults to None.

Examples

>>> from bagua.torch_api import allreduce
>>> # All tensors below are of torch.int64 type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> allreduce(tensor)
>>> tensor
tensor([4, 6]) # Rank 0
tensor([4, 6]) # Rank 1
>>> # All tensors below are of torch.cfloat type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
>>> tensor
tensor([1.+1.j, 2.+2.j]) # Rank 0
tensor([3.+3.j, 4.+4.j]) # Rank 1
>>> allreduce(tensor)
>>> tensor
tensor([4.+4.j, 6.+6.j]) # Rank 0
tensor([4.+4.j, 6.+6.j]) # Rank 1
bagua.torch_api.broadcast(tensor, root=0, comm=None)

Broadcasts the tensor to the whole communicator.

tensor must have the same number of elements in all processes participating in the collective.

Parameters
  • tensor (torch.Tensor) – Data to be sent if root is the rank of current process, and tensor to be used to save received data otherwise.

  • root (int, optional) – Source rank. Defaults to 0.

  • comm (B.BaguaSingleCommunicatorPy, optional) – The bagua communicator to work on. If None, the global bagua communicator will be used. Defaults to None.

bagua.torch_api.get_local_rank()

Get the local rank of the current process.

Local rank is a unique identifier assigned to each process within a node. Local ranks are always consecutive integers ranging from 0 to local_size - 1.

Returns

The local rank of the current process.

bagua.torch_api.get_local_size()

Get the number of processes in the node.

Returns

The local size of the node.

bagua.torch_api.get_rank()

Get the rank of the current process within the process group.

Rank is a unique identifier assigned to each process within a distributed process group. Ranks are always consecutive integers ranging from 0 to world_size - 1.

Returns

The rank of the current process.

bagua.torch_api.get_world_size()

Get the number of processes in the current process group.

Returns

The world size of the process group.

bagua.torch_api.init_process_group()

Initializes the PyTorch builtin distributed process group. This will also initialize the distributed package, and it should be executed before all other Bagua APIs.

Raises

RepeatedInitializationError – If you run this function repeatedly

Examples:
>>> import bagua.torch_api as bagua
>>> bagua.init_process_group()
>>> model = torch.nn.Sequential(
...    torch.nn.Linear(D_in, H),
...    torch.nn.ReLU(),
...    torch.nn.Linear(H, D_out),
...    )
>>> optimizer = torch.optim.SGD(
...    model.parameters(),
...    lr=0.01,
...    momentum=0.9
...    )
>>> model, optimizer = bagua_init(model, optimizer)