bagua.torch_api.algorithms.q_adam

Module Contents

class bagua.torch_api.algorithms.q_adam.QAdamAlgorithm(q_adam_optimizer, hierarchical=True)

Bases: bagua.torch_api.algorithms.Algorithm

This is the base class that all Bagua algorithms inherit.

Create an instance of the QAdam Algorithm.

Parameters:
  • q_adam_optimizer (QAdamOptimizer) – A QAdamOptimizer initialized with model parameters.

  • hierarchical (bool) – Enable hierarchical communication.
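
A minimal usage sketch (not part of the generated reference): it assumes the script is started with Bagua's distributed launcher and that model is an already constructed torch.nn.Module on the local GPU.

    import torch
    import bagua.torch_api as bagua
    from bagua.torch_api.algorithms.q_adam import QAdamAlgorithm, QAdamOptimizer

    # Assumed: launched via Bagua's distributed launcher; `model` already exists.
    torch.cuda.set_device(bagua.get_local_rank())
    bagua.init_process_group()

    optimizer = QAdamOptimizer(model.parameters(), lr=1e-3, warmup_steps=100)
    algorithm = QAdamAlgorithm(optimizer, hierarchical=True)
    model = model.with_bagua([optimizer], algorithm)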

class bagua.torch_api.algorithms.q_adam.QAdamAlgorithmImpl(process_group, q_adam_optimizer, hierarchical=True)

Bases: bagua.torch_api.algorithms.AlgorithmImpl

This is the base class that all Bagua algorithm implementations inherit.

It provides methods that can be overridden to implement different kinds of distributed algorithms.

Implementation of the QAdam Algorithm.

Parameters:
  • process_group (BaguaProcessGroup) – The process group to work on.

  • q_adam_optimizer (QAdamOptimizer) – A QAdamOptimizer initialized with model parameters.

  • hierarchical (bool) – Enable hierarchical communication.
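
For illustration only, a sketch of constructing the implementation directly. In normal use QAdamAlgorithm creates the QAdamAlgorithmImpl internally; process_group below stands for a BaguaProcessGroup assumed to have been obtained from Bagua's communication setup, and model is an existing torch.nn.Module.

    from bagua.torch_api.algorithms.q_adam import QAdamAlgorithmImpl, QAdamOptimizer

    # Assumed to exist already: `model` (torch.nn.Module) and `process_group`
    # (the BaguaProcessGroup the algorithm should communicate on).
    optimizer = QAdamOptimizer(model.parameters(), lr=1e-3, warmup_steps=100)
    impl = QAdamAlgorithmImpl(
        process_group,
        q_adam_optimizer=optimizer,
        hierarchical=True,
    )
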
class bagua.torch_api.algorithms.q_adam.QAdamOptimizer(params, lr=0.001, warmup_steps=100, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0)

Bases: torch.optim.optimizer.Optimizer

Create a dedicated optimizer used for the QAdam algorithm.

Parameters:
  • params (iterable) – Iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float) – Learning rate.

  • warmup_steps (int) – Number of steps to warm up by doing full-precision gradient allreduce before switching to quantized communication of the momentum. Use 0 to disable.

  • betas (Tuple[float, float]) – Coefficients used for computing running averages of gradient and its square.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • weight_decay (float) – Weight decay (L2 penalty).

step(closure=None)

Performs a single optimization step.

Parameters:
  • closure (callable) – A closure that reevaluates the model and returns the loss. Optional.
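
A minimal training-loop sketch, assuming the model has been wrapped with with_bagua as in the earlier example; data_loader and loss_fn are placeholder names, not part of this module.

    # Placeholders assumed to exist: `model`, `optimizer`, `data_loader`, `loss_fn`.
    for inputs, targets in data_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        # For the first `warmup_steps` iterations QAdam synchronizes full-precision
        # gradients with allreduce; afterwards it communicates quantized momentum.
        optimizer.step()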