bagua.torch_api.algorithms.bytegrad

Module Contents

class bagua.torch_api.algorithms.bytegrad.ByteGradAlgorithm(average=True)

Bases: bagua.torch_api.algorithms.Algorithm

Create an instance of the ByteGrad algorithm.

Parameters

average (bool) – If True, the gradients are averaged across workers. Otherwise, they are summed.
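
The effect of the average flag can be illustrated with a minimal, library-free sketch. The helper reduce_gradients below is hypothetical and is not part of Bagua. It only mimics the reduction semantics described above, using plain Python lists in place of gradient tensors, and says nothing about how Bagua actually communicates between workers.

```python
from typing import List


def reduce_gradients(per_worker_grads: List[List[float]], average: bool = True) -> List[float]:
    """Combine one gradient vector per worker, element by element.

    Hypothetical helper for illustration only; not a Bagua API.
    """
    num_workers = len(per_worker_grads)
    # Sum the i-th gradient entry over all workers.
    summed = [sum(values) for values in zip(*per_worker_grads)]
    if average:
        # average=True: divide the sum by the number of workers.
        return [s / num_workers for s in summed]
    # average=False: keep the plain sum.
    return summed


# Two simulated workers, each holding a gradient with two entries.
grads = [[1.0, 2.0], [3.0, 6.0]]
averaged = reduce_gradients(grads, average=True)
summed = reduce_gradients(grads, average=False)
print(averaged, summed)
```

With average=False the combined gradient grows with the number of workers, so the learning rate usually has to be chosen with that scaling in mind.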

init_operations(self, bagua_module, bucket)

tensors_to_buckets(self, tensors)

Given the bucketing suggestion from Bagua, return the actual Bagua buckets. The default implementation follows the suggestion to do the bucketing.

Parameters

tensors (List[List[bagua.torch_api.tensor.BaguaTensor]]) – Bagua tensors grouped in different lists, representing Bagua’s suggestion on how to bucket the tensors.

Returns

A list of Bagua buckets.

Return type

List[bagua.torch_api.bucket.BaguaBucket]
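
The "follow the suggestion" behavior described for tensors_to_buckets can be sketched roughly as follows. StubTensor, StubBucket, and follow_suggestion are hypothetical stand-ins introduced only for this illustration. They are not the real BaguaTensor or BaguaBucket classes, and the bucket naming scheme is an assumption.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class StubTensor:
    """Hypothetical stand-in for a Bagua tensor; only carries a name."""
    name: str


@dataclass
class StubBucket:
    """Hypothetical stand-in for a Bagua bucket; holds a group of tensors."""
    name: str
    tensors: List[StubTensor]


def follow_suggestion(tensors: List[List[StubTensor]]) -> List[StubBucket]:
    """Turn each suggested group into exactly one bucket, keeping the order."""
    return [
        StubBucket(name=str(index), tensors=list(group))  # naming by index is an assumption
        for index, group in enumerate(tensors)
    ]


# A suggestion with two groups yields two buckets.
suggestion = [
    [StubTensor("layer1.weight"), StubTensor("layer1.bias")],
    [StubTensor("layer2.weight")],
]
buckets = follow_suggestion(suggestion)
print([[t.name for t in b.tensors] for b in buckets])
```

An algorithm that needs a different layout could regroup the tensors before constructing the buckets instead of mapping each suggested group one-to-one.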