bagua.torch_api.contrib

Package Contents

class bagua.torch_api.contrib.CacheLoader(backend='redis', dataset_name='', writer_buffer_size=1, **kwargs)

Cache loader caches values calculated by an expensive function by their keys via get, so that the values can be retrieved faster next time.

Internally, values are indexed by "{dataset_name}_{key}" and saved in a distributed key-value store, where dataset_name is specified on initializing, and key is the argument in get.

By default, cache loader uses RedisStore as its backend distributed key-value store implementation. It supports using a list of existing redis servers or spawning new redis servers. Parameters for RedisStore can be provided here in **kwargs.

Parameters
  • backend (str) – Backend distributed key-value store implementation. Can be "redis".

  • dataset_name (str) – Name of the dataset. Default "".

  • writer_buffer_size (int) – Number of samples to collect before writing to the backend key-value store. Useful for improving the backend throughput.
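The effect of writer_buffer_size can be pictured with a minimal sketch. It assumes an in-memory dict in place of the real key-value store; ToyBufferedWriter, write, and flush are hypothetical names used for illustration only and are not part of Bagua's API:

```python
from typing import Any, Dict


class ToyBufferedWriter:
    """Hypothetical illustration only; not Bagua's implementation."""

    def __init__(self, store: Dict[str, Any], buffer_size: int = 1) -> None:
        self.store = store              # stands in for the backend key-value store
        self.buffer_size = buffer_size  # plays the role of writer_buffer_size
        self.pending: Dict[str, Any] = {}
        self.num_flushes = 0            # counts batched writes to the store

    def write(self, key: str, value: Any) -> None:
        # Collect samples locally until the buffer is full.
        self.pending[key] = value
        if len(self.pending) >= self.buffer_size:
            self.flush()

    def flush(self) -> None:
        # Send everything collected so far in one batched write.
        if not self.pending:
            return
        self.store.update(self.pending)
        self.num_flushes += 1
        self.pending.clear()


store: Dict[str, Any] = {}
writer = ToyBufferedWriter(store, buffer_size=3)
for i in range(7):
    writer.write("test_{}".format(i), i * i)
writer.flush()  # push the remaining partial buffer
```

In this sketch, seven samples reach the store in three batched writes rather than seven individual ones, which is the kind of throughput gain the parameter is aimed at.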

Example:

To use a list of existing redis servers for the “redis” backend:

>>> from bagua.torch_api.contrib import CacheLoader
>>>
>>> hosts = [{"host": "192.168.1.0", "port": "7000"}, {"host": "192.168.1.1", "port": "7000"}]
>>> loader = CacheLoader(backend="redis", hosts=hosts, cluster_mode=True, dataset_name="test")
>>>
>>> loader.get(index, lambda x: items[x])

To spawn new redis servers on training nodes for the “redis” backend, each node with a maximum memory limit of 100000000 bytes:

>>> loader = CacheLoader(backend="redis", hosts=None, cluster_mode=True, capacity_per_node=100000000)

Note

Cache loaders with the same dataset_name will reuse and overwrite each other’s cache. Use a different dataset_name if this is not desired.

get(key, load_fn)

Returns the value associated with key in the cache. If the key does not exist in the cache, load_fn is used to create the entry. load_fn is a function that takes key as its argument and returns the corresponding value to be cached.

Parameters
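  • key (str) – Key to look up in the cache.

  • load_fn (Callable[[str], None]) – Function that takes key as its argument and returns the value to be cached.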
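The get-or-load behavior described above can be sketched as follows. This is a minimal illustration that assumes a plain dict instead of a distributed store; ToyCacheLoader is a hypothetical stand-in, not the library's implementation:

```python
from typing import Any, Callable, Dict, List


class ToyCacheLoader:
    """Hypothetical in-memory stand-in; not Bagua's implementation."""

    def __init__(self, store: Dict[str, Any], dataset_name: str = "") -> None:
        self.store = store
        self.dataset_name = dataset_name

    def get(self, key: str, load_fn: Callable[[str], Any]) -> Any:
        # Values are indexed by "{dataset_name}_{key}", as described above.
        cache_key = "{}_{}".format(self.dataset_name, key)
        if cache_key not in self.store:
            # Cache miss: compute the value once and remember it.
            self.store[cache_key] = load_fn(key)
        return self.store[cache_key]

    def num_keys(self) -> int:
        return len(self.store)


calls: List[str] = []


def expensive(key: str) -> str:
    calls.append(key)  # record how often the expensive function runs
    return key.upper()


shared_store: Dict[str, Any] = {}
loader = ToyCacheLoader(shared_store, dataset_name="test")
first = loader.get("a", expensive)   # miss: calls expensive("a")
second = loader.get("a", expensive)  # hit: served from the store
```

Because the stored key includes dataset_name, two loaders created with the same dataset_name and the same store would read and overwrite the same entries.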

num_keys()

Returns the number of keys in the cache.

class bagua.torch_api.contrib.CachedDataset(dataset, backend='redis', dataset_name='', writer_buffer_size=20, **kwargs)

Bases: torch.utils.data.dataset.Dataset

Cached dataset wraps a PyTorch dataset to cache its samples in memory, so that accessing these samples after the first time can be much faster. This is useful when samples need tedious preprocessing to produce, or reading the dataset itself is slow, which could slow down the whole training process.

Internally, the samples are indexed by a string key "{dataset_name}_{index}" and saved in a distributed key-value store, where dataset_name is specified when initializing the cached dataset, and index is the index of a specific sample (the argument of __getitem__ method in a PyTorch dataset).

Parameters
  • dataset (torch.utils.data.dataset.Dataset) – PyTorch dataset to be wrapped.

  • backend (str) – Backend distributed key-value store implementation. Can be "redis".

  • dataset_name (str) – Name of the dataset. Default "".

  • writer_buffer_size (int) – Number of samples to collect before writing to the backend key-value store. Useful for improving the backend throughput.

Example:

>>> import torch
>>> from bagua.torch_api.contrib import CachedDataset
>>> cached_dataset = CachedDataset(dataset, backend="redis", dataset_name="ds")
>>> dataloader = torch.utils.data.DataLoader(cached_dataset)

Note

Cached dataset is a special case of cache loader. The backend and writer_buffer_size parameters of a cached dataset have the same meanings as those of a cache loader. Additional arguments for the cache loader can be passed here in **kwargs. See also CacheLoader.
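The relationship between a cached dataset and the underlying dataset can be sketched as below. This is an assumption-laden illustration: ToyCachedDataset and SlowSquares are hypothetical classes, the cache is a local dict rather than a distributed store, and no PyTorch base class is used so that the snippet runs on its own:

```python
from typing import Any, Dict


class ToyCachedDataset:
    """Hypothetical wrapper; not Bagua's implementation."""

    def __init__(self, dataset: Any, dataset_name: str = "") -> None:
        self.dataset = dataset
        self.dataset_name = dataset_name
        self._cache: Dict[str, Any] = {}

    def __getitem__(self, index: int) -> Any:
        # Samples are indexed by the string key "{dataset_name}_{index}".
        key = "{}_{}".format(self.dataset_name, index)
        if key not in self._cache:
            # First access: read from the wrapped (slow) dataset.
            self._cache[key] = self.dataset[index]
        return self._cache[key]

    def __len__(self) -> int:
        return len(self.dataset)


class SlowSquares:
    """Hypothetical dataset that counts how often it is actually read."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.reads = 0

    def __getitem__(self, index: int) -> int:
        self.reads += 1
        return index * index

    def __len__(self) -> int:
        return self.n


raw = SlowSquares(4)
cached = ToyCachedDataset(raw, dataset_name="ds")
# Read every sample once, then read sample 2 a second time.
values = [cached[i] for i in range(len(cached))] + [cached[2]]
```

The second access to index 2 is answered from the cache, so the wrapped dataset is read only four times for five lookups.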

cache_loader

The backend cache loader instance.

class bagua.torch_api.contrib.LoadBalancingDistributedBatchSampler(sampler, batch_fn, drop_last=False)

Bases: torch.utils.data.sampler.Sampler

Wraps a load-balancing sampler to yield variable-sized mini-batches.

Parameters
  • sampler (LoadBalancingDistributedSampler) – Load balance sampler.

  • batch_fn (Callable) – Callable to yield mini-batch indices.

  • drop_last (bool) – If True, the sampler drops the last few batches that exceed the smallest number of batches among replicas; otherwise, the number of batches on each replica is padded to the same value.
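As an illustration of what a batch_fn might look like, the sketch below groups indices into variable-sized batches under a per-batch budget. It takes a list of indices and returns a list of index lists; make_budget_batch_fn, lengths, and max_tokens are hypothetical names introduced here and are not part of the library:

```python
from typing import Callable, Dict, List


def make_budget_batch_fn(
    lengths: Dict[int, int], max_tokens: int
) -> Callable[[List[int]], List[List[int]]]:
    """Build a batch_fn that caps the summed sample length of each batch."""

    def batch_fn(indices: List[int]) -> List[List[int]]:
        batches: List[List[int]] = []
        current: List[int] = []
        current_tokens = 0
        for idx in indices:
            cost = lengths[idx]
            # Start a new batch when adding this sample would exceed the budget.
            if current and current_tokens + cost > max_tokens:
                batches.append(current)
                current, current_tokens = [], 0
            current.append(idx)
            current_tokens += cost
        if current:
            batches.append(current)
        return batches

    return batch_fn


# Hypothetical per-sample lengths keyed by dataset index.
lengths = {0: 5, 1: 3, 2: 4, 3: 8, 4: 1}
batch_fn = make_budget_batch_fn(lengths, max_tokens=8)
batches = batch_fn([0, 1, 2, 3, 4])
```

Here the five samples end up in four batches, `[[0, 1], [2], [3], [4]]`, because each batch is closed as soon as the next sample would push it past the budget.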

batch_fn must have the following signature:

def batch_fn(indices: List[int]) -> List[List[int]]

Example:

>>> from bagua.torch_api.contrib import LoadBalancingDistributedSampler, \
...     LoadBalancingDistributedBatchSampler
>>>
>>> sampler = LoadBalancingDistributedSampler(dataset, complexity_fn=complexity_fn)
>>> batch_sampler = LoadBalancingDistributedBatchSampler(sampler, batch_fn=batch_fn)
>>> loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)
>>>
>>> for epoch in range(start_epoch, n_epochs):
...     batch_sampler.set_epoch(epoch)
...     train(loader)

set_epoch(epoch)

Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters

epoch (int) – Epoch number.

Return type

None

class bagua.torch_api.contrib.LoadBalancingDistributedSampler(dataset, complexity_fn, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False, random_level=0)

Bases: torch.utils.data.sampler.Sampler

Sampler that restricts data loading to a subset of the dataset.

This sampler uses a complexity_fn to calculate each sample’s computational complexity and makes each batch have similar computational complexity.

This is useful in scenarios like speech and NLP, where each batch has variable length and distributed training suffers from the straggler problem.

The usage is similar to torch.utils.data.DistributedSampler, where each process loads a subset of the original dataset that is exclusive to it.
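The idea of balancing complexity across processes can be illustrated with a minimal sketch. It is one simple strategy chosen for illustration, not necessarily the algorithm the sampler uses: sort samples by complexity, then deal neighboring samples to different ranks. balanced_indices is a hypothetical helper, and the tail is dropped as with drop_last=True:

```python
from typing import Any, Callable, List, Sequence


def balanced_indices(
    samples: Sequence[Any],
    complexity_fn: Callable[[Any], int],
    num_replicas: int,
    rank: int,
) -> List[int]:
    """Return the sample indices assigned to one rank (illustrative only)."""
    # Order all indices by the complexity of their samples.
    order = sorted(range(len(samples)), key=lambda i: complexity_fn(samples[i]))
    # Drop the tail so every rank receives the same number of samples.
    usable = len(order) - len(order) % num_replicas
    # Neighbors in the sorted order go to different ranks, so at each step
    # all ranks process samples of similar complexity.
    return [order[pos] for pos in range(rank, usable, num_replicas)]


# Hypothetical dataset: strings whose length is their complexity.
samples = ["a" * n for n in [7, 1, 5, 2, 8, 3, 6, 4]]
per_rank = [balanced_indices(samples, len, num_replicas=2, rank=r) for r in range(2)]
step_costs = [
    [len(samples[i]) for i in indices] for indices in per_rank
]
```

In this toy run, the two ranks see complexities `[1, 3, 5, 7]` and `[2, 4, 6, 8]`, so at every step their workloads differ by only one unit and neither rank waits long for the other.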

Note

Dataset is assumed to be of constant size.

Parameters
  • dataset (torch.utils.data.dataset.Dataset) – Dataset used for sampling.

  • complexity_fn (Callable) – A function whose input is a sample and output is an integer as a measure of the computational complexity of the sample.

  • num_replicas (int, optional) – Number of processes participating in distributed training. By default, world_size is retrieved from the current distributed group.

  • rank (int, optional) – Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.

  • shuffle (bool, optional) – If True (default), sampler will shuffle the indices.

  • seed (int, optional) – Random seed used to shuffle the sampler if shuffle=True. This number should be identical across all processes in the distributed group. Default: 0.

  • drop_last (bool, optional) – If True, the sampler will drop the tail of the data to make it evenly divisible across the number of replicas. If False, the sampler will add extra indices to make the data evenly divisible across the replicas. Default: False.

  • random_level (float, optional) – A float between 0 and 1 that controls the extent of load balancing. 0 means the best load balance, while 1 means the opposite.

Warning

In distributed mode, calling the set_epoch method at the beginning of each epoch, before creating the DataLoader iterator, is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will always be used.

Example:

Define your complexity_fn, which accepts a dataset sample as its input and produces an integer as the sample’s computational complexity:

>>> dataset = torch.utils.data.TensorDataset(torch.randn(n, 2), torch.randperm(n))
>>> complexity_fn = lambda x: x[1]

Below is the usage of LoadBalancingDistributedSampler and DataLoader:

>>> sampler = bagua.torch_api.contrib.LoadBalancingDistributedSampler(
...     dataset,
...     complexity_fn=complexity_fn) if is_distributed else None
>>> loader = torch.utils.data.DataLoader(dataset,
...     shuffle=(sampler is None),
...     sampler=sampler)
>>>
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)

set_epoch(epoch)

Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters

epoch (int) – Epoch number.

Return type

None

bagua.torch_api.contrib.fuse_optimizer(optimizer, do_flatten=True, check_flatten=True)

Convert any optimizer into a fused optimizer.

A fused optimizer can fuse multiple parameter updates into one or a few updates. To achieve this, users need to:

1) flatten multiple parameters in the same group into a fused parameter by setting do_flatten=True, which is also the default behavior of a fused optimizer;
2) perform a fused parameter update by calling fuse_step.
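The two steps above can be pictured with a small, library-free sketch. It is only an analogy for the idea of flattening and is not how the fused optimizer is implemented: several parameters are copied into one contiguous buffer, each parameter becomes a view into that buffer, and a single pass over the buffer updates all of them. fused_sgd_step and the plain SGD rule are assumptions made for illustration:

```python
from array import array
from typing import List

# Three hypothetical parameters and their gradients.
params = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
grads = [[0.5, 0.5], [1.0], [2.0, 2.0, 2.0]]

# Step 1: flatten all parameters (and gradients) into contiguous buffers.
flat_params = array("d", [v for p in params for v in p])
flat_grads = array("d", [v for g in grads for v in g])

# Each original parameter is now a view into the shared buffer,
# analogous to resetting its data pointer.
buffer = memoryview(flat_params)
views: List[memoryview] = []
offset = 0
for p in params:
    views.append(buffer[offset:offset + len(p)])
    offset += len(p)


def fused_sgd_step(flat_p: array, flat_g: array, lr: float) -> None:
    """Step 2: one pass over the flat buffer updates every parameter."""
    for i in range(len(flat_p)):
        flat_p[i] -= lr * flat_g[i]


fused_sgd_step(flat_params, flat_grads, lr=0.1)

# The per-parameter views observe the update without any copying.
updated = [list(v) for v in views]
```

The point of the sketch is that one update over the flat buffer replaces three separate per-parameter updates, while the individual parameters remain accessible through their views.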

This fused optimizer is implemented for general use. It can be used in conjunction with a BaguaModule, with a module wrapped in torch.nn.parallel.DistributedDataParallel, or in other cases not listed here.

Parameters
  • optimizer (torch.optim.Optimizer) – Any PyTorch optimizer.

  • do_flatten (bool) – Whether to flatten the parameters. The flatten operation will reset data pointers of parameter tensors so that they can be fused together. Default: True.

  • check_flatten (bool) – If True, the fused optimizer automatically checks whether the parameter tensors are still contiguous in memory, as laid out by flattening. Only takes effect with do_flatten=True. Default: True.

Returns

A fused optimizer.

Example:

>>> optimizer = torch.optim.Adadelta(model.parameters(), ....)
>>> optimizer = bagua.torch_api.contrib.fuse_optimizer(optimizer, do_flatten=True)
>>>
>>> optimizer.fuse_step()

When used in conjunction with a BaguaModule, set do_flatten=False in with_bagua explicitly:

>>> optimizer = bagua.torch_api.contrib.fuse_optimizer(optimizer, do_flatten=True)
>>> model = model.with_bagua([optimizer], GradientAllReduceAlgorithm(), do_flatten=False)
>>>
>>> optimizer.fuse_step()

Note

Both this function and the with_bagua method reset the data pointers of module parameters by default. To perform a more effective fused parameter update, disable bucket flattening in with_bagua by setting its do_flatten to False.

Note

A fused optimizer does not change the original behavior of the optimizer; it only enables a fused parameter update through fuse_step. Users can still perform a normal parameter update through step.