bagua.torch_api.algorithms.async_model_average

Module Contents

class bagua.torch_api.algorithms.async_model_average.AsyncModelAverageAlgorithm(peer_selection_mode='all', sync_interval_ms=500, warmup_steps=0)

Bases: bagua.torch_api.algorithms.Algorithm

Create an instance of the AsyncModelAverage algorithm.

The asynchronous implementation is experimental and imposes some restrictions. With an asynchronous algorithm like this one, each worker may complete a different number of iterations. The current implementation therefore assumes that the dataset is an endless stream and that all workers continuously synchronize with each other.
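To see why iteration counts diverge, and why an endless data stream is assumed, here is a toy single-process simulation on a virtual clock. It is not Bagua's implementation. The per-worker step times, the quadratic objective and the helper names `endless_stream` and `simulate` are all made up for illustration. Workers take local SGD steps at different speeds, and every `sync_interval_ms` all weights are replaced by their mean. If the dataset were finite, the fastest worker would run out of samples while the slower ones were still training. Cycling the data sidesteps that.

```python
import itertools


def endless_stream(samples):
    """Cycle over a finite dataset forever, mimicking an endless data stream."""
    return itertools.cycle(samples)


def simulate(step_time_ms, sync_interval_ms, total_ms, lr=0.1):
    """Toy model: worker i does one local step every step_time_ms[i] ms."""
    n = len(step_time_ms)
    weights = [0.0] * n
    iterations = [0] * n
    streams = [endless_stream([1.0, 2.0, 3.0]) for _ in range(n)]
    for t in range(1, total_ms + 1):  # virtual clock, 1 ms per tick
        for i in range(n):
            if t % step_time_ms[i] == 0:  # worker i finishes one local step
                x = next(streams[i])
                grad = weights[i] - x  # gradient of (w - x)^2 / 2
                weights[i] -= lr * grad
                iterations[i] += 1
        if t % sync_interval_ms == 0:  # periodic averaging over all workers
            mean = sum(weights) / n
            weights = [mean] * n
    return iterations, weights


iterations, weights = simulate([10, 25, 40], sync_interval_ms=500, total_ms=2000)
print(iterations)  # [200, 80, 50]
```

After the same 2000 ms, the three workers have done 200, 80 and 50 steps. Because 2000 is a multiple of the sync interval, they all hold identical weights at the end.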

Users should call abort to manually stop the algorithm’s continuous synchronization process. For example, for a model wrapped with .with_bagua(…), you can abort with model.bagua_algorithm.abort(model), and resume with model.bagua_algorithm.resume(model).
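The abort/resume contract can be pictured as a background loop that keeps synchronizing until it is told to stop. The sketch below uses `threading.Event` to model this. The `BackgroundSync` class is hypothetical: it only illustrates the control flow and is not how Bagua implements its communication. `sync_fn` stands in for whatever a real synchronization step would do.

```python
import threading
import time


class BackgroundSync:
    """Hypothetical stand-in for a continuous background synchronization loop."""

    def __init__(self, sync_fn, sync_interval_ms=500):
        self._sync_fn = sync_fn
        self._interval_s = sync_interval_ms / 1000.0
        self._stop = threading.Event()
        self._thread = None

    def _loop(self):
        # wait() returns False on timeout (time to sync) and True once
        # abort() sets the event, which ends the loop promptly.
        while not self._stop.wait(self._interval_s):
            self._sync_fn()

    def resume(self):
        if self.running:
            return  # already synchronizing; nothing to do
        self._stop.clear()
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def abort(self):
        self._stop.set()
        if self._thread is not None:
            self._thread.join()  # wait until any in-flight sync finishes
            self._thread = None

    @property
    def running(self):
        return self._thread is not None and self._thread.is_alive()


# Usage: count how many background "syncs" happen between resume and abort.
calls = []
sync = BackgroundSync(lambda: calls.append(time.monotonic()), sync_interval_ms=10)
sync.resume()
time.sleep(0.2)
sync.abort()
calls_at_abort = len(calls)
time.sleep(0.05)  # no further syncs should happen after abort
```

Joining the thread in `abort` guarantees that no synchronization is still running when control returns to the training script. This matters if the model is evaluated or saved right afterwards.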

Parameters
  • peer_selection_mode (str) – How workers communicate with each other. Currently only "all" is supported, meaning that all workers’ weights are synchronized during each communication.

  • sync_interval_ms (int) – Number of milliseconds between model synchronizations.

  • warmup_steps (int) – Number of initial steps during which gradient allreduce is performed before switching to asynchronous model averaging. Use 0 to disable warmup.
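The interplay of `warmup_steps` and `peer_selection_mode="all"` can be sketched as follows. This is an illustration under stated assumptions, not Bagua code. The helper names `communication_for_step` and `average_all` are hypothetical, steps are counted from 0, and weights are plain lists of floats rather than tensors.

```python
def communication_for_step(step, warmup_steps):
    """Which communication a 0-based training step relies on (illustrative only)."""
    if step < warmup_steps:
        # Warmup: synchronous gradient allreduce, like plain data-parallel training.
        return "gradient_allreduce"
    # After warmup: no per-step gradient sync; the model is averaged in the
    # background every sync_interval_ms.
    return "async_model_average"


def average_all(worker_weights):
    """peer_selection_mode="all": every worker gets the element-wise mean of all workers' weights."""
    n = len(worker_weights)
    mean = [sum(column) / n for column in zip(*worker_weights)]
    return [list(mean) for _ in range(n)]


schedule = [communication_for_step(s, warmup_steps=3) for s in range(5)]
averaged = average_all([[1.0, 2.0], [3.0, 6.0]])
```

With `warmup_steps=3`, steps 0 to 2 use gradient allreduce and later steps rely on background averaging. With `warmup_steps=0`, every step does. Averaging two workers holding `[1.0, 2.0]` and `[3.0, 6.0]` gives both workers `[2.0, 4.0]`.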

class bagua.torch_api.algorithms.async_model_average.AsyncModelAverageAlgorithmImpl(process_group, peer_selection_mode='all', sync_interval_ms=500, warmup_steps=0)

Bases: bagua.torch_api.algorithms.AlgorithmImpl

Implementation of the AsyncModelAverage algorithm.

The asynchronous implementation is experimental and imposes some restrictions. With an asynchronous algorithm like this one, each worker may complete a different number of iterations. The current implementation therefore assumes that the dataset is an endless stream and that all workers continuously synchronize with each other.

Users should call abort to manually stop the algorithm’s continuous synchronization process. For example, for a model wrapped with .with_bagua(…), you can abort with model.bagua_algorithm.abort(model), and resume with model.bagua_algorithm.resume(model).

Parameters
  • process_group (BaguaProcessGroup) – The process group to work on.

  • peer_selection_mode (str) – How workers communicate with each other. Currently only "all" is supported, meaning that all workers’ weights are synchronized during each communication.

  • sync_interval_ms (int) – Number of milliseconds between model synchronizations.

  • warmup_steps (int) – Number of initial steps during which gradient allreduce is performed before switching to asynchronous model averaging. Use 0 to disable warmup.

abort(self, bagua_module)

Stop background asynchronous communications. Should be called after training.

Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) – A PyTorch module initialized by the with_bagua method.

resume(self, bagua_module)

Resume aborted background asynchronous communications (see abort). Should be called before training.

Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) – A PyTorch module initialized by the with_bagua method.
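Because resume should be called before training and abort after it, a small context manager can pair the two calls. It also stops background communication when training raises an exception. The helper `async_training_phase` is hypothetical and not part of Bagua. It only assumes an object with `resume(module)` and `abort(module)` methods, as documented above. With a real model it would be used as `with async_training_phase(model.bagua_algorithm, model): ...`. Below, a fake algorithm stands in so that the sketch runs without a distributed setup.

```python
from contextlib import contextmanager


@contextmanager
def async_training_phase(algorithm, bagua_module):
    """Hypothetical helper: resume background sync for one training phase, always abort afterwards."""
    algorithm.resume(bagua_module)
    try:
        yield
    finally:
        # Runs on normal exit and on exceptions alike.
        algorithm.abort(bagua_module)


class FakeAlgorithm:
    """Records calls instead of communicating; a stand-in for model.bagua_algorithm."""

    def __init__(self):
        self.events = []

    def resume(self, module):
        self.events.append(("resume", module))

    def abort(self, module):
        self.events.append(("abort", module))


algo = FakeAlgorithm()
with async_training_phase(algo, "model"):
    pass  # a training epoch would run here

try:
    with async_training_phase(algo, "model"):
        raise RuntimeError("training failed")
except RuntimeError:
    pass  # abort was still called before the exception propagated
```

Whether resume must also be called before the very first training phase, or only after an abort, is not specified here. The sketch simply calls it at the start of every phase.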