bagua.torch_api.algorithms.q_adam

Module Contents

class bagua.torch_api.algorithms.q_adam.QAdamAlgorithm(q_adam_optimizer, hierarchical=True)

Bases: bagua.torch_api.algorithms.Algorithm

Create an instance of the QAdam algorithm.

Parameters
  • q_adam_optimizer – A QAdamOptimizer initialized with model parameters.

  • hierarchical – Enable hierarchical communication.
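
A minimal usage sketch (assuming the Bagua process group has already been initialized and the model lives on the current CUDA device; the linear model below is a placeholder):

   import torch
   import bagua.torch_api as bagua
   from bagua.torch_api.algorithms.q_adam import QAdamAlgorithm, QAdamOptimizer

   # Assumes e.g. torch.cuda.set_device(bagua.get_local_rank()) and
   # bagua.init_process_group() have already been called.
   model = torch.nn.Linear(10, 1).cuda()  # placeholder model
   optimizer = QAdamOptimizer(model.parameters(), lr=1e-3, warmup_steps=100)
   algorithm = QAdamAlgorithm(optimizer, hierarchical=True)

   # Let Bagua take over gradient/momentum communication for this model.
   model = model.with_bagua([optimizer], algorithm)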

init_backward_hook(self, bagua_module)
Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) –

init_operations(self, bagua_module, bucket)
Parameters

  • bagua_module (bagua.torch_api.distributed.BaguaModule) –

  • bucket (bagua.torch_api.bucket.BaguaBucket) –

init_tensors(self, bagua_module)
Parameters

bagua_module (bagua.torch_api.distributed.BaguaModule) –

need_reset(self)
tensors_to_buckets(self, tensors)
Parameters

tensors (List[List[bagua.torch_api.tensor.BaguaTensor]]) –

Return type

List[bagua.torch_api.bucket.BaguaBucket]

class bagua.torch_api.algorithms.q_adam.QAdamOptimizer(params, lr=0.001, warmup_steps=100, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0)

Bases: torch.optim.optimizer.Optimizer

Create a dedicated optimizer used for the QAdam algorithm.

Parameters
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

  • lr – learning rate (default: 1e-3)

  • warmup_steps – number of warm-up steps at the beginning of training.

  • betas – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))

  • eps – term added to the denominator to improve numerical stability (default: 1e-8)

  • weight_decay – weight decay (L2 penalty) (default: 0.)

step(self, closure=None)
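
A sketch of a training loop once the model has been wrapped with with_bagua (see the sketch above); train_loader and loss_fn are assumed placeholders:

   loss_fn = torch.nn.MSELoss()  # placeholder criterion

   for inputs, targets in train_loader:  # assumed iterable of CUDA tensors
       optimizer.zero_grad()
       loss = loss_fn(model(inputs), targets)
       loss.backward()
       # During the first `warmup_steps` iterations the update follows plain
       # Adam with full-precision gradient averaging; after warm-up, QAdam
       # switches to communicating quantized momentum instead.
       optimizer.step()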