bagua.torch_api.algorithms.q_adam¶
Module Contents¶
- class bagua.torch_api.algorithms.q_adam.QAdamAlgorithm(q_adam_optimizer, hierarchical=True)¶
Bases:
bagua.torch_api.algorithms.Algorithm
Create an instance of the QAdam algorithm.
- Parameters
q_adam_optimizer – A QAdamOptimizer initialized with model parameters.
hierarchical – Enable hierarchical communication.
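Example (a minimal sketch, assuming the script is started with Bagua's distributed launcher so the process group can be initialized; the model here is a placeholder):

    import torch
    import bagua.torch_api as bagua
    from bagua.torch_api.algorithms.q_adam import QAdamAlgorithm, QAdamOptimizer

    # Bind this process to its GPU and initialize Bagua's process group.
    torch.cuda.set_device(bagua.get_local_rank())
    bagua.init_process_group()

    model = torch.nn.Linear(128, 10).cuda()  # placeholder model
    optimizer = QAdamOptimizer(model.parameters(), lr=1e-3, warmup_steps=100)
    algorithm = QAdamAlgorithm(optimizer, hierarchical=True)

    # with_bagua wires the algorithm's buckets and hooks into the module.
    model = model.with_bagua([optimizer], algorithm)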
- init_backward_hook(self, bagua_module)¶
- Parameters
bagua_module (bagua.torch_api.distributed.BaguaModule) –
- init_operations(self, bagua_module, bucket)¶
- Parameters
bagua_module (bagua.torch_api.distributed.BaguaModule) –
bucket (bagua.torch_api.bucket.BaguaBucket) –
- init_tensors(self, bagua_module)¶
- Parameters
bagua_module (bagua.torch_api.distributed.BaguaModule) –
- need_reset(self)¶
- tensors_to_buckets(self, tensors)¶
- Parameters
tensors (List[List[bagua.torch_api.tensor.BaguaTensor]]) –
- Return type
List[bagua.torch_api.bucket.BaguaBucket]
- class bagua.torch_api.algorithms.q_adam.QAdamOptimizer(params, lr=0.001, warmup_steps=100, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0)¶
Bases:
torch.optim.optimizer.Optimizer
Create a dedicated optimizer used for the QAdam algorithm.
- Parameters
params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
lr – learning rate (default: 1e-3)
warmup_steps – number of warm-up steps at the beginning of training (default: 100)
betas – coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
eps – term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay – weight decay (L2 penalty) (default: 0.)
- step(self, closure=None)¶
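A sketch of the optimization step inside a training loop, continuing the setup example above (the data loader and loss function are placeholders). Roughly speaking, during the first warmup_steps iterations the update follows standard Adam with full-precision communication, after which QAdam switches to communicating compressed momentum:

    loss_fn = torch.nn.CrossEntropyLoss()

    for inputs, targets in train_loader:        # placeholder data loader
        inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()                         # Bagua hooks overlap communication with the backward pass
        optimizer.step()                        # Adam-style update; compressed communication after warm-up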