bagua.torch_api.contrib
Package Contents
- class bagua.torch_api.contrib.CacheLoader(backend='redis', dataset_name='', writer_buffer_size=1, **kwargs)
Cache loader caches values computed by an expensive function, indexed by their keys, via get, so that the values can be retrieved faster the next time they are requested.
Internally, values are indexed by "{dataset_name}_{key}" and saved in a distributed key-value store, where dataset_name is specified at initialization and key is the argument passed to get.
By default, cache loader uses RedisStore as its backend distributed key-value store implementation. It supports either using a list of existing redis servers or spawning new redis servers. Parameters for RedisStore can be provided here in **kwargs.
- Parameters:
backend (str) – Backend distributed key-value store implementation. Can be "redis".
dataset_name (str) – Name of the dataset. Default: "".
writer_buffer_size (int) – Number of samples to collect before writing to the backend key-value store. Useful for improving the backend throughput.
- Example::
To use a list of existing redis servers for the “redis” backend:
>>> from bagua.torch_api.contrib import CacheLoader
>>>
>>> hosts = [{"host": "192.168.1.0", "port": "7000"}, {"host": "192.168.1.1", "port": "7000"}]
>>> loader = CacheLoader(backend="redis", hosts=hosts, cluster_mode=True, dataset_name="test")
>>>
>>> loader.get(index, lambda x: items[x])
To spawn new redis servers on training nodes for the “redis” backend, each node with a maximum memory limit of 100000000 bytes:
>>> loader = CacheLoader(backend="redis", hosts=None, cluster_mode=True, capacity_per_node=100000000)
Note
Cache loaders with the same dataset_name will reuse and overwrite each other's cache. Use a different dataset_name if this is not desired.
- get(key, load_fn)
Returns the value associated with key in the cache. If the key does not exist in the cache, load_fn is used to create the entry. load_fn is a function that takes key as its argument and returns the corresponding value to be cached.
- Parameters:
key (str) – Key of the entry to look up.
load_fn (Callable[[str], None]) – Function called with key to produce the value when the key is not in the cache.
- num_keys()
Returns the number of keys in the cache.
- class bagua.torch_api.contrib.CachedDataset(dataset, backend='redis', dataset_name='', writer_buffer_size=20, **kwargs)
Bases: torch.utils.data.dataset.Dataset
Cached dataset wraps a PyTorch dataset to cache its samples in memory, so that accessing these samples after the first time can be much faster. This is useful when samples require tedious preprocessing, or when reading the dataset itself is slow, either of which could slow down the whole training process.
Internally, the samples are indexed by a string key "{dataset_name}_{index}" and saved in a distributed key-value store, where dataset_name is specified when initializing the cached dataset, and index is the index of a specific sample (the argument of the __getitem__ method in a PyTorch dataset).
- Parameters:
dataset (torch.utils.data.dataset.Dataset) – PyTorch dataset to be wrapped.
backend (str) – Backend distributed key-value store implementation. Can be "redis".
dataset_name (str) – Name of the dataset. Default: "".
writer_buffer_size (int) – Number of samples to collect before writing to the backend key-value store. Useful for improving the backend throughput.
Example:
>>> import torch
>>> from bagua.torch_api.contrib import CachedDataset
>>> cached_dataset = CachedDataset(dataset, backend="redis", dataset_name="ds")
>>> dataloader = torch.utils.data.DataLoader(cached_dataset)
Note
Cached dataset is a special case of cache loader. The parameters backend and writer_buffer_size have the same meanings when initializing a cached dataset as when initializing a cache loader. You can provide the arguments for the cache loader here in **kwargs. See also CacheLoader.
- cache_loader
The backend cache loader instance.
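To make the effect of wrapping a dataset concrete, here is a minimal sketch of a caching wrapper whose __getitem__ reads the underlying dataset only on the first access to each index. ToyCachedDataset and SlowSquares are hypothetical names, the in-process dictionary is an assumed stand-in for the key-value backend, and no PyTorch is involved. It is not Bagua's CachedDataset.

```python
from typing import Any, Dict, Sequence


class ToyCachedDataset:
    """Hypothetical sketch of a caching dataset wrapper; not Bagua's CachedDataset."""

    def __init__(self, dataset: Sequence[Any], dataset_name: str = ""):
        self.dataset = dataset
        self.dataset_name = dataset_name
        self._store: Dict[str, Any] = {}  # stands in for the key-value backend

    def __getitem__(self, index: int) -> Any:
        key = f"{self.dataset_name}_{index}"
        if key not in self._store:
            # Slow path, taken once per index.
            self._store[key] = self.dataset[index]
        return self._store[key]

    def __len__(self) -> int:
        return len(self.dataset)


class SlowSquares:
    """Toy dataset that counts how often a sample is actually produced."""

    def __init__(self, n: int):
        self.n = n
        self.reads = 0

    def __getitem__(self, index: int) -> int:
        self.reads += 1
        return index * index

    def __len__(self) -> int:
        return self.n


raw = SlowSquares(4)
cached = ToyCachedDataset(raw, dataset_name="ds")
epoch_1 = [cached[i] for i in range(len(cached))]
epoch_2 = [cached[i] for i in range(len(cached))]
```

After two passes over the wrapper, the underlying dataset has been read only four times, once per sample.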
- class bagua.torch_api.contrib.LoadBalancingDistributedBatchSampler(sampler, batch_fn, drop_last=False)
Bases: torch.utils.data.sampler.Sampler
Wraps a load balancing sampler to yield variable-sized mini-batches.
- Parameters:
sampler (LoadBalancingDistributedSampler) – Load balancing sampler.
batch_fn (Callable) – Callable to yield mini-batch indices.
drop_last (bool) – If True, the sampler will drop the last few batches exceeding the least number of batches among replicas; otherwise, the number of batches on each replica will be padded to the same value.
batch_fn has the following signature:
def batch_fn(indices: List[int]) -> List[List[int]]
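As one possible batch_fn with the signature above, the sketch below groups consecutive indices into batches under a total-length budget, which is a common way to obtain variable-sized mini-batches for sequence data. The helper make_batch_fn, the lengths table, and the max_tokens budget are assumptions introduced for illustration; Bagua does not prescribe how batch_fn forms its batches.

```python
from typing import Callable, List, Sequence


def make_batch_fn(
    lengths: Sequence[int], max_tokens: int
) -> Callable[[List[int]], List[List[int]]]:
    """Build a batch_fn that packs indices until a length budget is reached."""

    def batch_fn(indices: List[int]) -> List[List[int]]:
        batches: List[List[int]] = []
        current: List[int] = []
        current_tokens = 0
        for idx in indices:
            # Start a new batch when adding this sample would exceed the budget.
            if current and current_tokens + lengths[idx] > max_tokens:
                batches.append(current)
                current, current_tokens = [], 0
            current.append(idx)
            current_tokens += lengths[idx]
        if current:
            batches.append(current)
        return batches

    return batch_fn


# Hypothetical per-sample lengths, indexed by dataset index.
lengths = [4, 4, 9, 2, 2, 2, 7]
batch_fn = make_batch_fn(lengths, max_tokens=10)
batches = batch_fn([0, 1, 2, 3, 4, 5, 6])
# batches == [[0, 1], [2], [3, 4, 5], [6]]
```

A sample longer than the budget still gets a batch of its own, so no index is ever dropped by this function.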
- Example::
>>> from bagua.torch_api.contrib import LoadBalancingDistributedSampler, \
...     LoadBalancingDistributedBatchSampler
>>>
>>> sampler = LoadBalancingDistributedSampler(dataset, complexity_fn=complexity_fn)
>>> batch_sampler = LoadBalancingDistributedBatchSampler(sampler, batch_fn=batch_fn)
>>> loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)
>>>
>>> for epoch in range(start_epoch, n_epochs):
...     batch_sampler.set_epoch(epoch)
...     train(loader)
- set_epoch(epoch)
Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.
- Parameters:
epoch (int) – Epoch number.
- Return type:
None
- class bagua.torch_api.contrib.LoadBalancingDistributedSampler(dataset, complexity_fn, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False, random_level=0)
Bases: torch.utils.data.sampler.Sampler
Sampler that restricts data loading to a subset of the dataset.
This sampler uses a complexity_fn to calculate each sample's computational complexity and makes each batch have similar computational complexity.
This is useful in scenarios like speech and NLP, where each batch has variable length and distributed training suffers from the straggler problem.
The usage is similar to torch.utils.data.DistributedSampler, where each process loads a subset of the original dataset that is exclusive to it.
Note
Dataset is assumed to be of constant size.
- Parameters:
dataset (torch.utils.data.dataset.Dataset) – Dataset used for sampling.
complexity_fn (Callable) – A function whose input is a sample and output is an integer as a measure of the computational complexity of the sample.
num_replicas (int, optional) – Number of processes participating in distributed training. By default, world_size is retrieved from the current distributed group.
rank (int, optional) – Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.
shuffle (bool, optional) – If True (default), sampler will shuffle the indices.
seed (int, optional) – Random seed used to shuffle the sampler if shuffle=True. This number should be identical across all processes in the distributed group. Default: 0.
drop_last (bool, optional) – If True, then the sampler will drop the tail of the data to make it evenly divisible across the number of replicas. If False, the sampler will add extra indices to make the data evenly divisible across the replicas. Default: False.
random_level (float, optional) – A float between 0 and 1 that controls the extent of load balance. 0 means the best load balance, while 1 means the opposite.
Warning
In distributed mode, calling the set_epoch method at the beginning of each epoch, before creating the DataLoader iterator, is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will always be used.
- Example::
Define your complexity_fn, which accepts a dataset sample as its input and produces an integer as the sample's computational complexity:
>>> dataset = torch.utils.data.TensorDataset(torch.randn(n, 2), torch.randperm(n))
>>> complexity_fn = lambda x: x[1]
Below is the usage of LoadBalancingDistributedSampler and DataLoader:
>>> sampler = bagua.torch_api.contrib.LoadBalancingDistributedSampler(
...     dataset,
...     complexity_fn=complexity_fn) if is_distributed else None
>>> loader = torch.utils.data.DataLoader(dataset,
...     shuffle=(sampler is None),
...     sampler=sampler)
>>>
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)
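The straggler problem that load balancing addresses can be illustrated without any distributed setup. The sketch below compares a plain round-robin split with one simple balancing scheme: sort samples by complexity and deal neighbours out to different replicas. This scheme is an assumption chosen for clarity, not necessarily the algorithm Bagua uses, and the function names are hypothetical. Shuffling and random_level are ignored here.

```python
from typing import Any, Callable, List, Sequence


def naive_partition(n: int, num_replicas: int) -> List[List[int]]:
    # Replica r takes indices r, r + num_replicas, ... regardless of cost.
    return [list(range(r, n, num_replicas)) for r in range(num_replicas)]


def balanced_partition(
    samples: Sequence[Any], complexity_fn: Callable[[Any], int], num_replicas: int
) -> List[List[int]]:
    # Sort by complexity so neighbours in the order cost about the same,
    # then deal neighbours out to different replicas.
    order = sorted(range(len(samples)), key=lambda i: complexity_fn(samples[i]))
    return [order[r::num_replicas] for r in range(num_replicas)]


def worst_step_gap(
    parts: List[List[int]], samples: Sequence[Any], complexity_fn: Callable[[Any], int]
) -> int:
    # At step t, replica r processes parts[r][t]; the gap is how much longer
    # the slowest replica works than the fastest one at that step.
    gaps = []
    for step in zip(*parts):
        costs = [complexity_fn(samples[i]) for i in step]
        gaps.append(max(costs) - min(costs))
    return max(gaps)


# Hypothetical samples: each sample is just its sequence length.
samples = [5, 100, 7, 98, 6, 99, 3, 101]
complexity_fn = lambda length: length

naive = naive_partition(len(samples), 2)
balanced = balanced_partition(samples, complexity_fn, 2)
naive_gap = worst_step_gap(naive, samples, complexity_fn)        # 98
balanced_gap = worst_step_gap(balanced, samples, complexity_fn)  # 2
```

With the round-robin split, one replica always gets the short samples and the other the long ones, so the fast replica idles at every step. After balancing, the replicas process samples of nearly equal cost at each step.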
- set_epoch(epoch)
Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.
- Parameters:
epoch (int) – Epoch number.
- Return type:
None
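Why set_epoch matters can be seen in a small sketch. It assumes the shuffle is seeded with seed + epoch, which is the approach used by torch.utils.data.DistributedSampler; whether Bagua's samplers derive their seed in exactly this way is an assumption here. ToyEpochSampler is a hypothetical class that uses only the standard library.

```python
import random
from typing import Iterator, List


class ToyEpochSampler:
    """Hypothetical sampler illustrating set_epoch; not Bagua's implementation."""

    def __init__(self, n: int, seed: int = 0, shuffle: bool = True):
        self.n = n
        self.seed = seed
        self.shuffle = shuffle
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def __iter__(self) -> Iterator[int]:
        indices: List[int] = list(range(self.n))
        if self.shuffle:
            # Assumption: the ordering depends only on seed + epoch, so every
            # replica with the same seed and epoch derives the same permutation.
            random.Random(self.seed + self.epoch).shuffle(indices)
        return iter(indices)


rank0 = ToyEpochSampler(100)
rank1 = ToyEpochSampler(100)

# Without set_epoch, two consecutive passes yield the same ordering.
without_set_epoch = [list(rank0), list(rank0)]

# After set_epoch(1) on every replica, the ordering changes but stays in sync.
rank0.set_epoch(1)
rank1.set_epoch(1)
epoch_1_rank0 = list(rank0)
epoch_1_rank1 = list(rank1)
```

Because the permutation is a pure function of the seed and the epoch, the replicas agree on the ordering without communicating, provided they all use the same seed and call set_epoch with the same value.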
- bagua.torch_api.contrib.fuse_optimizer(optimizer, do_flatten=True, check_flatten=True)
Convert any optimizer into a fused optimizer.
A fused optimizer can fuse multiple parameter updates into one or a few updates. To achieve this, users need to:
1) flatten multiple parameters in the same group into a fused parameter by setting do_flatten=True, which is also the default behavior of a fused optimizer;
2) perform a fused parameter update by calling fuse_step.
This fused optimizer is implemented for general use. It can be used in conjunction with a BaguaModule as well as a torch.nn.parallel.DistributedDataParallel wrapped module, or some other cases (not listed here).
- Parameters:
optimizer (torch.optim.Optimizer) – Any PyTorch optimizer.
do_flatten (bool) – Whether to flatten the parameters. The flatten operation will reset data pointers of parameter tensors so that they can be fused together. Default: True.
check_flatten (bool) – When set to True, the fused optimizer automatically checks whether parameter tensors are still contiguous in the way they were flattened. Can only work with do_flatten=True. Default: True.
- Returns:
A fused optimizer.
- Example::
>>> optimizer = torch.optim.Adadelta(model.parameters(), ....)
>>> optimizer = bagua.torch_api.contrib.fuse_optimizer(optimizer, do_flatten=True)
>>>
>>> optimizer.fuse_step()
When used in conjunction with a BaguaModule, set do_flatten=False in with_bagua explicitly:
>>> optimizer = bagua.torch_api.contrib.fuse_optimizer(optimizer, do_flatten=True)
>>> model = model.with_bagua([optimizer], GradientAllReduceAlgorithm(), do_flatten=False)
>>>
>>> optimizer.fuse_step()
Note
This function and the with_bagua method both reset data pointers of module parameters by default. To perform a more effective fused parameter update, users need to disable bucket flattening in with_bagua by setting its do_flatten to False.
Note
A fused optimizer does not change the original behavior of optimizer; it only enables it to perform a fused parameter update through fuse_step. Users can still perform a normal parameter update through step.
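The idea behind flattening, where parameters become views into one contiguous buffer so that a single update touches all of them, can be sketched with the standard library alone. This is a conceptual illustration under stated assumptions, not how fuse_optimizer is implemented: the array and memoryview objects stand in for tensors, and fused_sgd_step is a hypothetical plain SGD rule chosen for simplicity.

```python
from array import array

# Two "parameters" stored as views into one contiguous flat buffer.
flat = array("d", [1.0, 2.0, 3.0, 4.0, 5.0])
view = memoryview(flat)
weight = view[0:3]  # shares memory with flat[0:3]
bias = view[3:5]    # shares memory with flat[3:5]

# Gradients laid out in the same flattened order (hypothetical values).
flat_grad = array("d", [0.5, 0.5, 0.5, 1.0, 1.0])
lr = 2.0


def fused_sgd_step(params: array, grads: array, lr: float) -> None:
    # One pass over the fused buffer instead of one update per parameter.
    for i in range(len(params)):
        params[i] -= lr * grads[i]


fused_sgd_step(flat, flat_grad, lr)

# The per-parameter views observe the update without being touched directly.
weight_after = weight.tolist()  # [0.0, 1.0, 2.0]
bias_after = bias.tolist()      # [2.0, 3.0]
```

The sketch also suggests why a contiguity check is useful: if a parameter were later reassigned to storage outside the flat buffer, a fused update over the buffer would no longer reach it.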