bagua.torch_api.contrib.cache_loader

Module Contents

class bagua.torch_api.contrib.cache_loader.CacheLoader(backend='redis', dataset_name='', writer_buffer_size=1, **kwargs)

Cache loader caches values computed by an expensive function, indexed by their keys and retrieved via get, so that later lookups of the same key are faster.

Internally, values are indexed by "{dataset_name}_{key}" and saved in a distributed key-value store, where dataset_name is set at initialization and key is the argument passed to get.
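The indexing scheme can be illustrated with a minimal sketch. The helper make_store_key and the plain dict standing in for the distributed store are hypothetical and only mirror the "{dataset_name}_{key}" pattern described above; they are not part of Bagua.

```python
def make_store_key(dataset_name: str, key: str) -> str:
    # Hypothetical helper: mirrors the "{dataset_name}_{key}" pattern.
    return "{}_{}".format(dataset_name, key)


# A plain dict stands in for the distributed key-value store (assumption).
store = {}
store[make_store_key("train", "0")] = b"sample-a"
store[make_store_key("val", "0")] = b"sample-b"

# The same key under different dataset names maps to distinct entries.
print(sorted(store))
```

Because the dataset name is part of the stored key, two loaders only share entries when both their dataset_name and key match.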

By default, cache loader uses RedisStore as its backend distributed key-value store implementation. RedisStore can either connect to a list of existing Redis servers or spawn new ones. Parameters for RedisStore can be passed through **kwargs.

Parameters
  • backend (str) – Backend distributed key-value store implementation. Can be "redis".

  • dataset_name (str) – Name of the dataset. Default "".

  • writer_buffer_size (int) – Number of samples to collect before writing to the backend key-value store. Larger values batch more samples per write, which can improve backend throughput. Default 1.
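The effect of writer_buffer_size can be sketched as a small write buffer that flushes in batches. BufferedWriter below is a hypothetical illustration backed by a plain dict, not Bagua's actual writer; it only shows why collecting several samples per write reduces the number of round trips to the store.

```python
from typing import Dict, List, Tuple


class BufferedWriter:
    """Hypothetical helper (not part of Bagua): batches writes to a store."""

    def __init__(self, store: Dict[str, bytes], buffer_size: int = 1):
        self.store = store
        self.buffer_size = buffer_size
        self.pending: List[Tuple[str, bytes]] = []
        self.flush_count = 0  # number of batched writes issued so far

    def write(self, key: str, value: bytes) -> None:
        self.pending.append((key, value))
        # Flush once enough samples have been collected.
        if len(self.pending) >= self.buffer_size:
            self.flush()

    def flush(self) -> None:
        if not self.pending:
            return
        self.store.update(self.pending)  # one batched write for all pending items
        self.pending.clear()
        self.flush_count += 1


store: Dict[str, bytes] = {}
writer = BufferedWriter(store, buffer_size=3)
for i in range(7):
    writer.write("test_{}".format(i), bytes([i]))

flushes_before_final = writer.flush_count  # two full batches of three
writer.flush()  # write the one remaining buffered sample
```

In this sketch, samples still sitting in the buffer are not visible in the store until the next flush, which is the trade-off a larger buffer makes for fewer writes.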

Example

To use a list of existing redis servers for the “redis” backend:

>>> from bagua.torch_api.contrib import CacheLoader
>>>
>>> hosts = [{"host": "192.168.1.0", "port": "7000"}, {"host": "192.168.1.1", "port": "7000"}]
>>> loader = CacheLoader(backend="redis", hosts=hosts, cluster_mode=True, dataset_name="test")
>>>
>>> # index is the key of a sample and items is the underlying data source, both defined elsewhere
>>> loader.get(index, lambda x: items[x])

To spawn new redis servers on training nodes for the “redis” backend, each node with a maximum memory limit of 100000000 bytes:

>>> loader = CacheLoader(backend="redis", hosts=None, cluster_mode=True, capacity_per_node=100000000)

Note

Cache loaders with the same dataset_name will reuse and overwrite each other’s cache. Use a different dataset_name if this is not desired.

get(key, load_fn)

Returns the value associated with key in the cache. If key does not exist in the cache, load_fn is called to create the entry. load_fn is a function that takes key as its argument and returns the corresponding value to be cached.

Parameters
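  • key (str) – Key of the entry to look up in the cache.

  • load_fn (Callable[[str], None]) – Function called with key to produce the value when the key is not in the cache.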
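The get-or-load behavior described above can be sketched with an in-memory stand-in. DictCacheLoader is hypothetical and uses a plain dict instead of a distributed store; serializing values with pickle is also an assumption made for the sketch, not a statement about Bagua's implementation.

```python
import pickle
from typing import Any, Callable, Dict, List


class DictCacheLoader:
    """Hypothetical in-memory stand-in for illustration only."""

    def __init__(self, dataset_name: str = ""):
        self.dataset_name = dataset_name
        self.store: Dict[str, bytes] = {}

    def get(self, key: str, load_fn: Callable[[str], Any]) -> Any:
        cache_key = "{}_{}".format(self.dataset_name, key)
        cached = self.store.get(cache_key)
        if cached is not None:
            # Cache hit: load_fn is not called.
            return pickle.loads(cached)
        # Cache miss: compute the value, store it, and return it.
        value = load_fn(key)
        self.store[cache_key] = pickle.dumps(value)
        return value

    def num_keys(self) -> int:
        return len(self.store)


calls: List[str] = []


def expensive_load(key: str) -> dict:
    calls.append(key)  # record each real load to show when load_fn runs
    return {"id": key, "pixels": [0, 1, 2]}


loader = DictCacheLoader(dataset_name="test")
first = loader.get("7", expensive_load)   # miss: expensive_load runs
second = loader.get("7", expensive_load)  # hit: served from the store
```

The second call returns an equal value without invoking expensive_load again, which is the saving the cache is meant to provide.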

num_keys()

Returns the number of keys in the cache.