bagua.torch_api.algorithms.q_adam
Module Contents
- class bagua.torch_api.algorithms.q_adam.QAdamAlgorithm(q_adam_optimizer, hierarchical=True)
Bases: bagua.torch_api.algorithms.Algorithm
Create an instance of the QAdam algorithm.
- Parameters
q_adam_optimizer (QAdamOptimizer) – A QAdamOptimizer initialized with model parameters.
hierarchical (bool) – Enable hierarchical communication.
- class bagua.torch_api.algorithms.q_adam.QAdamOptimizer(params, lr=0.001, warmup_steps=100, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0)
Bases: torch.optim.optimizer.Optimizer
Create a dedicated optimizer used for the QAdam algorithm.
- Parameters
params (iterable) – Iterable of parameters to optimize or dicts defining parameter groups.
lr (float) – Learning rate.
warmup_steps (int) – Number of steps to warm up by doing gradient allreduce before doing asynchronous model averaging. Use 0 to disable.
betas (Tuple[float, float]) – Coefficients used for computing running averages of gradient and its square.
eps (float) – Term added to the denominator to improve numerical stability.
weight_decay (float) – Weight decay (L2 penalty).
- step(self, closure=None)
Performs a single optimization step.
- Parameters
closure (Callable, optional) – A closure that reevaluates the model and returns the loss.
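The reference above documents the parameters but not the update rule itself. The following is a minimal, framework-free sketch of the QAdam idea (following the 1-bit Adam scheme QAdam is based on): for the first `warmup_steps` steps the update is ordinary Adam on full-precision gradients, after which the second moment is frozen and the first moment becomes the quantity that would be quantized before being averaged across workers. All names here (`quantize_signs`, `qadam_step`) are illustrative stand-ins, not part of the Bagua API, and bias correction and distributed communication are omitted for brevity.

```python
import math

def quantize_signs(vec):
    # Toy 1-bit quantizer: keep only the signs, rescaled by the mean
    # magnitude. A stand-in for QAdam's compressed communication step.
    scale = sum(abs(x) for x in vec) / len(vec)
    return [scale if x >= 0 else -scale for x in vec]

def qadam_step(params, grads, m, v, step, *,
               lr=1e-3, warmup_steps=100, betas=(0.9, 0.999),
               eps=1e-8, weight_decay=0.0):
    # One QAdam-style update on plain Python lists (in place).
    # step <= warmup_steps: ordinary Adam, both moments updated.
    # step >  warmup_steps: second moment v is frozen; the first moment m
    # is (toy-)quantized, which is where QAdam's compression applies.
    beta1, beta2 = betas
    for i in range(len(params)):
        g = grads[i] + weight_decay * params[i]
        m[i] = beta1 * m[i] + (1 - beta1) * g
        if step <= warmup_steps:
            v[i] = beta2 * v[i] + (1 - beta2) * g * g
    m_used = m if step <= warmup_steps else quantize_signs(m)
    for i in range(len(params)):
        params[i] -= lr * m_used[i] / (math.sqrt(v[i]) + eps)
    return params
```

In the real algorithm, `quantize_signs` would be applied to each worker's momentum before allreduce, so after warmup only quantized tensors cross the network while the frozen `v` keeps the Adam scaling.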