bagua.torch_api.contrib¶
Submodules¶
Package Contents¶
- class bagua.torch_api.contrib.FusedOptimizer(optimizer, do_flatten=False)¶
Bases: torch.optim.Optimizer
Convert any optimizer into a fused optimizer.
This fused optimizer fuses multiple module parameter update kernel launches into one or a few, by flattening parameter tensors into one or more contiguous buckets.
It can be used in conjunction with bagua.torch_api.bagua_init. In that case, Bagua will do the fusion automatically; otherwise, you need to explicitly pass do_flatten=True.
- Parameters
optimizer (torch.optim.Optimizer) – Any PyTorch optimizer.
do_flatten (bool) – Whether to flatten the parameters. Default: False.
- Returns
Fused optimizer.
- Example::
To use in conjunction with bagua.torch_api.bagua_init:
>>> optimizer = torch.optim.Adadelta(model.parameters(), ...)
>>> optimizer = bagua.torch_api.contrib.FusedOptimizer(optimizer)
>>> model = model.with_bagua([optimizer], GradientAllReduceAlgorithm())
To use alone or with torch.nn.parallel.DistributedDataParallel, set do_flatten to True:
>>> optimizer = torch.optim.Adadelta(model.parameters(), ...)
>>> optimizer = bagua.torch_api.contrib.FusedOptimizer(optimizer, do_flatten=True)
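A minimal training-loop sketch of how the fused optimizer is then driven; model, loss_fn, and loader are placeholders for your own objects and are not part of this API:
>>> optimizer = bagua.torch_api.contrib.FusedOptimizer(
...     torch.optim.Adadelta(model.parameters()), do_flatten=True)
>>> for inputs, targets in loader:
...     optimizer.zero_grad()
...     loss = loss_fn(model(inputs), targets)
...     loss.backward()
...     optimizer.step()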
- step(self, closure=None)¶
Performs a single optimization step (parameter update).
- Parameters
closure (callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.
Note
Unless otherwise specified, this function should not modify the .grad field of the parameters.
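When an optimizer requires re-evaluating the model (for example torch.optim.LBFGS), a closure can be passed; the sketch below is the standard PyTorch pattern, with inputs, targets, model, and loss_fn assumed to already exist:
>>> def closure():
...     optimizer.zero_grad()
...     loss = loss_fn(model(inputs), targets)
...     loss.backward()
...     return loss
>>> optimizer.step(closure)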
- class bagua.torch_api.contrib.LoadBalancingDistributedSampler(dataset, complexity_fn, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False, random_level=0)¶
Bases: torch.utils.data.sampler.Sampler
Sampler that restricts data loading to a subset of the dataset.
This sampler uses a complexity_fn to calculate each sample’s computational complexity and makes each batch have a similar total complexity.
This is useful in scenarios like speech and NLP, where samples have variable length and distributed training suffers from the straggler problem.
The usage is similar to torch.utils.data.DistributedSampler, where each process loads a subset of the original dataset that is exclusive to it.
Note
Dataset is assumed to be of constant size.
- Parameters
dataset – Dataset used for sampling.
complexity_fn (Callable) – A function whose input is a sample and whose output is an integer measuring the computational complexity of the sample.
num_replicas (int, optional) – Number of processes participating in distributed training. By default, world_size is retrieved from the current distributed group.
rank (int, optional) – Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.
shuffle (bool, optional) – If True (default), the sampler will shuffle the indices.
seed (int, optional) – Random seed used to shuffle the sampler if shuffle=True. This number should be identical across all processes in the distributed group. Default: 0.
drop_last (bool, optional) – If True, the sampler will drop the tail of the data to make it evenly divisible across the number of replicas. If False, the sampler will add extra indices to make the data evenly divisible across the replicas. Default: False.
random_level (float, optional) – A float between 0 and 1 that controls the extent of load balancing. 0 means the best load balance, while 1 means the opposite.
Warning
In distributed mode, calling the set_epoch method at the beginning of each epoch before creating the DataLoader iterator is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will always be used.
- Example::
Define your complexity_fn, which accepts a dataset sample as its input and produces an integer as the sample’s computational complexity.
>>> dataset = torch.utils.data.TensorDataset(torch.randn(n, 2), torch.randperm(n))
>>> complexity_fn = lambda x: x[1]
Below is the usage of LoadBalancingDistributedSampler and DataLoader:
>>> sampler = bagua.torch_api.contrib.data.LoadBalancingDistributedSampler(
...     dataset,
...     complexity_fn=complexity_fn) if is_distributed else None
>>> loader = torch.utils.data.DataLoader(dataset,
...     shuffle=(sampler is None),
...     sampler=sampler)
>>>
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)
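For variable-length workloads such as tokenized text or audio, a natural complexity_fn is simply the sample’s length; this is an illustrative assumption about the dataset, not a library requirement:
>>> # Assume each dataset sample is a variable-length sequence (e.g. a 1-D token tensor).
>>> complexity_fn = lambda sample: len(sample)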
- set_epoch(self, epoch)¶
Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.
- Parameters
epoch (int) – Epoch number.
- Return type
None
- class bagua.torch_api.contrib.LoadBalancingDistributedBatchSampler(sampler, batch_fn, drop_last=False)¶
Bases: torch.utils.data.sampler.Sampler
Wraps another load-balancing sampler to yield variable-sized mini-batches.
- Parameters
sampler (LoadBalancingDistributedSampler) – Load balance sampler.
batch_fn (Callable) – Callable to yield mini-batch indices.
drop_last (bool) – If True, the sampler will drop the last few batches exceeding the least number of batches among replicas; otherwise, the number of batches on each replica will be padded to the same. Default: False.
batch_fn will have the signature of def batch_fn(indices: List[int]) -> List[List[int]].
Example:
>>> from bagua.torch_api.contrib.data import LoadBalancingDistributedSampler, \
...     LoadBalancingDistributedBatchSampler
>>>
>>> sampler = LoadBalancingDistributedSampler(dataset, complexity_fn=complexity_fn)
>>> batch_sampler = LoadBalancingDistributedBatchSampler(sampler, batch_fn=batch_fn)
>>> loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)
>>>
>>> for epoch in range(start_epoch, n_epochs):
...     batch_sampler.set_epoch(epoch)
...     train(loader)
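An illustrative batch_fn (not part of the library) that greedily packs indices into mini-batches whose total complexity stays under a fixed budget, reusing dataset and complexity_fn from the example above; the budget value is an arbitrary assumption:
>>> def batch_fn(indices):
...     # Greedily pack indices so each mini-batch's total complexity stays under `budget`.
...     budget = 4096  # arbitrary per-batch complexity budget (assumption)
...     batches, current, current_cost = [], [], 0
...     for i in indices:
...         cost = complexity_fn(dataset[i])
...         if current and current_cost + cost > budget:
...             batches.append(current)
...             current, current_cost = [], 0
...         current.append(i)
...         current_cost += cost
...     if current:
...         batches.append(current)
...     return batches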
- set_epoch(self, epoch)¶
Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.
- Parameters
epoch (int) – Epoch number.
- Return type
None