- class bagua.torch_api.algorithms.async_model_average.AsyncModelAverageAlgorithm(peer_selection_mode='all', sync_interval_ms=500, warmup_steps=0)
Create an instance of the AsyncModelAverage algorithm.
The asynchronous implementation is experimental and imposes some restrictions. With an asynchronous algorithm, the number of iterations completed can differ from worker to worker. The current implementation therefore assumes that the dataset is an endless stream and that all workers continuously synchronize with each other.
Users should call abort to manually stop the algorithm’s continuous synchronization process.
peer_selection_mode (str) – How workers communicate with each other. Currently, "all" means that all workers’ weights are synchronized during each communication.
sync_interval_ms (int) – Number of milliseconds between model synchronizations.
warmup_steps (int) – Number of steps to warm up by doing gradient allreduce before doing asynchronous model averaging. Use 0 to disable.
- abort(self, bagua_module)
Stop background asynchronous communications. Should be called after training.
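Because synchronization keeps running in the background until abort is called, one way to make sure it is always stopped is to bound the training loop and call abort in a finally clause. The following is a minimal sketch of that pattern, not Bagua's own code: run_training, train_step, and max_steps are hypothetical names introduced for illustration, and the helper only assumes that the algorithm object exposes abort(bagua_module) as documented above.

```python
import itertools


def run_training(algorithm, bagua_module, train_step, batches, max_steps):
    """Run at most max_steps training steps, then stop background sync.

    algorithm    -- object exposing abort(bagua_module), e.g. an
                    AsyncModelAverageAlgorithm instance
    bagua_module -- the module the algorithm was attached to
    train_step   -- hypothetical callable(module, batch) that performs one
                    forward/backward/optimizer step
    batches      -- iterable of batches; may be an endless stream
    max_steps    -- hypothetical per-worker step budget
    """
    steps_done = 0
    try:
        # islice bounds the loop even when `batches` never ends.
        for batch in itertools.islice(batches, max_steps):
            train_step(bagua_module, batch)
            steps_done += 1
    finally:
        # Stop background communication even if a step raised.
        algorithm.abort(bagua_module)
    return steps_done


# Possible wiring with Bagua (an assumption, not run here; it needs a
# distributed launch, and with_bagua is not described in this section):
#   algorithm = AsyncModelAverageAlgorithm(sync_interval_ms=500)
#   model = model.with_bagua([optimizer], algorithm)
#   run_training(algorithm, model, train_step, endless_loader, 10_000)
```

A fixed step budget per worker is only one possible stopping rule; since workers progress at different speeds, a wall-clock deadline may suit some jobs better.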