bagua.torch_api.algorithms.async_model_average
Module Contents
- class bagua.torch_api.algorithms.async_model_average.AsyncModelAverageAlgorithm(peer_selection_mode='all', sync_interval_ms=500, warmup_steps=0)
Bases: bagua.torch_api.algorithms.Algorithm
Create an instance of the AsyncModelAverage algorithm.
The asynchronous implementation is experimental and imposes some restrictions. With this asynchronous algorithm, the number of iterations on each worker can differ. The current implementation therefore assumes that the dataset is an endless stream, and all workers continuously synchronize with each other.
Users should call abort to manually stop the algorithm's continuous synchronization process (a usage sketch follows the parameter list).
- Parameters
  - peer_selection_mode (str) – How workers communicate with each other. Currently only "all" is supported, meaning all workers' weights are synchronized during each communication.
  - sync_interval_ms (int) – Number of milliseconds between model synchronizations.
  - warmup_steps (int) – Number of steps to warm up by doing gradient allreduce before doing asynchronous model averaging. Use 0 to disable.
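A minimal usage sketch follows. It assumes the process was started with Bagua's distributed launcher (e.g. bagua.distributed.launch) so that one process drives one GPU; the model, loss, and hyperparameter values are placeholders, not prescribed settings:

    import torch
    import bagua.torch_api as bagua
    from bagua.torch_api.algorithms.async_model_average import (
        AsyncModelAverageAlgorithm,
    )

    # One process per GPU; rank information comes from the launcher.
    torch.cuda.set_device(bagua.get_local_rank())
    bagua.init_process_group()

    model = torch.nn.Linear(10, 2).cuda()  # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Warm up with 100 gradient-allreduce steps, then switch to
    # asynchronous model averaging roughly every 500 milliseconds.
    algorithm = AsyncModelAverageAlgorithm(
        peer_selection_mode="all",
        sync_interval_ms=500,
        warmup_steps=100,
    )
    model = model.with_bagua([optimizer], algorithm)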
- abort(self, bagua_module)
Stop background asynchronous communications. Should be called after training (see the sketch below).
- Parameters
  - bagua_module (bagua.torch_api.distributed.BaguaModule) – A PyTorch module initialized by the with_bagua method.
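For example, continuing the sketch above (endless_stream and should_stop are hypothetical placeholders standing in for a real data stream and stopping condition), the background synchronization keeps running after the loop exits and must be stopped explicitly:

    for batch, target in endless_stream:   # hypothetical endless data stream
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(batch), target)
        loss.backward()
        optimizer.step()
        if should_stop():                  # hypothetical stopping condition
            break

    # Stop the background model-averaging communications.
    algorithm.abort(model)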
- resume(self, bagua_module)
Resume aborted background asynchronous communications (see abort). Should be called before training resumes, as in the sketch below.
- Parameters
  - bagua_module (bagua.torch_api.distributed.BaguaModule) – A PyTorch module initialized by the with_bagua method.
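Together, abort and resume can bracket phases where the model parameters must stay fixed, such as evaluation between training phases. A sketch under the same assumptions as above (train_one_phase and evaluate are hypothetical helpers):

    for phase in range(3):
        if phase > 0:
            # Restart the background synchronization stopped below.
            algorithm.resume(model)
        train_one_phase(model, optimizer)  # hypothetical training loop
        algorithm.abort(model)             # pause synchronization so that
        evaluate(model)                    # evaluation sees stable weights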