bagua.torch_api.contrib.cached_dataset¶
Module Contents¶
- class bagua.torch_api.contrib.cached_dataset.CachedDataset(dataset, backend='redis', dataset_name='', writer_buffer_size=20, **kwargs)¶
Bases:
torch.utils.data.dataset.Dataset
Cached dataset wraps a PyTorch dataset to cache its samples in memory, so that accessing these samples after the first time can be much faster. This is useful when samples require expensive preprocessing, or when reading the dataset itself is slow, either of which can slow down the whole training process.
Internally, the samples are indexed by a string key "{dataset_name}_{index}" and saved in a distributed key-value store, where dataset_name is specified when initializing the cached dataset, and index is the index of a specific sample (the argument of the __getitem__ method in a PyTorch dataset).
- Parameters
dataset (torch.utils.data.dataset.Dataset) – PyTorch dataset to be wrapped.
backend (str) – Backend distributed key-value store implementation. Can be "redis".
dataset_name (str) – Name of the dataset. Default: "".
writer_buffer_size (int) – Number of samples to collect before writing to the backend key-value store. Useful for improving the backend throughput.
Example:
>>> from bagua.torch_api.contrib import CachedDataset
>>> cached_dataset = CachedDataset(dataset, backend="redis", dataset_name="ds")
>>> dataloader = torch.utils.data.DataLoader(cached_dataset)
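To make the caching behavior concrete, here is a minimal sketch of the documented key scheme, with a plain dict standing in for the distributed key-value store. The ToyCachedDataset class and its misses counter are illustrative inventions, not part of bagua's API; the real CachedDataset delegates storage to a backend such as Redis.

```python
class ToyCachedDataset:
    """Illustrative stand-in for CachedDataset: caches samples under
    the key "{dataset_name}_{index}" in an in-memory dict."""

    def __init__(self, dataset, dataset_name=""):
        self.dataset = dataset
        self.dataset_name = dataset_name
        self.store = {}   # stands in for the backend key-value store
        self.misses = 0   # counts how often the wrapped dataset is actually read

    def __getitem__(self, index):
        key = f"{self.dataset_name}_{index}"   # the documented key scheme
        if key not in self.store:
            self.misses += 1
            self.store[key] = self.dataset[index]  # first access: read and cache
        return self.store[key]                     # later accesses: served from cache

    def __len__(self):
        return len(self.dataset)

ds = ToyCachedDataset([10, 20, 30], dataset_name="ds")
_ = ds[1]  # first access populates the store under key "ds_1"
_ = ds[1]  # second access is a cache hit; the wrapped dataset is not touched
```

After these two accesses, the underlying dataset has been read only once for index 1, which is the whole point of wrapping a slow dataset this way.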
Note
Cached dataset is a special case of cache loader. The parameters backend and writer_buffer_size of a cached dataset have the same meanings as those of a cache loader; additional cache loader arguments can be passed through **kwargs. See also CacheLoader.
- cache_loader¶
The backend cache loader instance.
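The writer_buffer_size parameter trades latency for throughput by batching writes to the backend store. A minimal sketch of that idea, assuming nothing about bagua's internals (the BufferedWriter class and its method names here are hypothetical):

```python
class BufferedWriter:
    """Illustrative sketch of writer_buffer_size batching: collect
    key-value pairs and flush them to the store in bulk."""

    def __init__(self, store, buffer_size=20):
        self.store = store
        self.buffer_size = buffer_size
        self.buffer = []
        self.flushes = 0   # counts bulk writes, for illustration

    def put(self, key, value):
        self.buffer.append((key, value))
        if len(self.buffer) >= self.buffer_size:
            self.flush()

    def flush(self):
        if self.buffer:
            # one bulk write instead of many single writes
            self.store.update(dict(self.buffer))
            self.buffer.clear()
            self.flushes += 1

store = {}
writer = BufferedWriter(store, buffer_size=3)
for i in range(7):
    writer.put(f"ds_{i}", i)
writer.flush()  # push the final partial buffer
```

With a buffer size of 3, the seven samples above reach the store in three bulk writes rather than seven individual ones, which is why a larger writer_buffer_size can improve backend throughput.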