bagua.torch_api.algorithms.q_adam

Module Contents

class bagua.torch_api.algorithms.q_adam.QAdamAlgorithm(q_adam_optimizer, hierarchical=True)

Bases: bagua.torch_api.algorithms.Algorithm

This is the base class that all Bagua algorithms inherit.

Create an instance of the QAdam Algorithm.

Parameters
  • q_adam_optimizer (QAdamOptimizer) – A QAdamOptimizer initialized with model parameters.

  • hierarchical (bool) – Enable hierarchical communication.
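
Example (a minimal usage sketch, assuming the script is started through Bagua's distributed launcher, e.g. python -m bagua.distributed.launch, and a CUDA device is available; the toy model and hyperparameter values are placeholders):

    import torch
    import bagua.torch_api as bagua
    from bagua.torch_api.algorithms.q_adam import QAdamAlgorithm, QAdamOptimizer

    # Bind this process to its GPU and join the Bagua process group.
    torch.cuda.set_device(bagua.get_local_rank())
    bagua.init_process_group()

    # Placeholder model; replace with your own module.
    model = torch.nn.Linear(10, 1).cuda()

    # QAdam requires its dedicated optimizer (see QAdamOptimizer below).
    optimizer = QAdamOptimizer(model.parameters(), lr=1e-3, warmup_steps=100)
    algorithm = QAdamAlgorithm(optimizer, hierarchical=True)

    # Wrap the model so Bagua handles communication using the QAdam algorithm.
    model = model.with_bagua([optimizer], algorithm)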

class bagua.torch_api.algorithms.q_adam.QAdamAlgorithmImpl(process_group, q_adam_optimizer, hierarchical=True)

Bases: bagua.torch_api.algorithms.AlgorithmImpl

This is the base class that all Bagua algorithm implementations inherit.

It provides methods that can be overridden to implement different kinds of distributed algorithms.

Implementation of the QAdam Algorithm.

Parameters
  • process_group (BaguaProcessGroup) – The process group to work on.

  • q_adam_optimizer (QAdamOptimizer) – A QAdamOptimizer initialized with model parameters.

  • hierarchical (bool) – Enable hierarchical communication.
property optimizer_step_id(self)
class bagua.torch_api.algorithms.q_adam.QAdamOptimizer(params, lr=0.001, warmup_steps=100, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0)

Bases: torch.optim.optimizer.Optimizer

Create a dedicated optimizer used for the QAdam algorithm.

Parameters
  • params (iterable) – Iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float) – Learning rate.

  • warmup_steps (int) – Number of steps to warm up by doing gradient allreduce before switching to compressed communication of the momentum. Use 0 to disable.

  • betas (Tuple[float, float]) – Coefficients used for computing running averages of gradient and its square.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • weight_decay (float) – Weight decay (L2 penalty).

step(self, closure=None)

Performs a single optimization step (parameter update).

Parameters
  • closure (callable) – A closure that reevaluates the model and returns the loss. Optional.
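
After the model has been wrapped with with_bagua (see the example above), step() is called exactly as for any other torch.optim.Optimizer. A sketch of one training iteration, with model, optimizer, and data_loader carried over from the previous example as placeholders:

    for data, target in data_loader:
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(data.cuda()), target.cuda())
        loss.backward()
        # During the first warmup_steps calls, gradients are synchronized by
        # allreduce; afterwards the compressed momentum is communicated instead.
        optimizer.step()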