bagua.torch_api.algorithms.bytegrad
Module Contents
- class bagua.torch_api.algorithms.bytegrad.ByteGradAlgorithm(average=True)
Bases:
bagua.torch_api.algorithms.Algorithm
Create an instance of the ByteGrad algorithm.
- Parameters
average (bool) – If True, the gradients are averaged across workers. Otherwise, they are summed.
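The effect of the average flag can be illustrated with a minimal sketch in plain Python. It does not use Bagua and is not Bagua's implementation: the function name reduce_gradients and the list-of-floats gradient representation are assumptions made only for this illustration.

```python
from typing import List, Sequence


def reduce_gradients(
    per_worker_grads: Sequence[Sequence[float]], average: bool = True
) -> List[float]:
    """Combine one gradient vector per worker, element by element.

    Hypothetical helper for illustration only; not part of Bagua.
    """
    if not per_worker_grads:
        raise ValueError("at least one worker gradient is required")
    num_workers = len(per_worker_grads)
    # Sum the i-th component of every worker's gradient.
    summed = [sum(values) for values in zip(*per_worker_grads)]
    if average:
        # average=True: divide the sum by the number of workers.
        return [s / num_workers for s in summed]
    # average=False: return the plain sum.
    return summed


# Two workers, each holding a two-element gradient.
worker_grads = [[1.0, 2.0], [3.0, 6.0]]
averaged = reduce_gradients(worker_grads, average=True)
summed = reduce_gradients(worker_grads, average=False)
```

With summation, the magnitude of the combined gradient grows with the number of workers, so the learning rate usually has to be chosen with that in mind.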
- init_operations(self, bagua_module, bucket)
- Parameters
bagua_module (bagua.torch_api.distributed.BaguaModule) –
bucket (bagua.torch_api.bucket.BaguaBucket) –
- tensors_to_buckets(self, tensors)
Given Bagua's bucketing suggestion, return the actual Bagua buckets. The default implementation follows the suggestion as given.
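As a rough sketch of what "following the suggestion" could look like, the snippet below turns each suggested group of tensors into one bucket. FakeTensor and FakeBucket are hypothetical stand-ins for BaguaTensor and BaguaBucket, and naming buckets by their index is an assumption. The real classes and their constructors are not shown here.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a Bagua tensor; only carries a name."""
    name: str


@dataclass
class FakeBucket:
    """Hypothetical stand-in for a Bagua bucket: a named group of tensors."""
    name: str
    tensors: List[FakeTensor]


def tensors_to_buckets(tensors: List[List[FakeTensor]]) -> List[FakeBucket]:
    # One bucket per suggested group, preserving the suggested order.
    return [
        FakeBucket(name=str(index), tensors=list(group))
        for index, group in enumerate(tensors)
    ]


# The suggestion groups three tensors into two buckets.
suggestion = [[FakeTensor("w1"), FakeTensor("b1")], [FakeTensor("w2")]]
buckets = tensors_to_buckets(suggestion)
```

An algorithm that needs a different layout, for example a single bucket holding every tensor, would override this method and regroup the tensors before building the buckets.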
- Parameters
tensors (List[List[bagua.torch_api.tensor.BaguaTensor]]) – Bagua tensors grouped into separate lists, representing Bagua’s suggestion on how to bucket the tensors.
- Returns
A list of Bagua buckets.
- Return type