bagua.torch_api.contrib.sync_batchnorm

Module Contents

class bagua.torch_api.contrib.sync_batchnorm.SyncBatchNorm(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

Bases: torch.nn.modules.batchnorm._BatchNorm

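Applies synchronized BatchNorm to a distributed module containing N-dimensional BatchNorm layer(s). During training, the mean and variance are computed over the inputs of all processes rather than over each process's local mini-batch. See BatchNorm for more details.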
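Conceptually, each worker contributes per-channel partial sums, and the normalization statistics are derived from their global totals. The plain-Python sketch below shows that reduction. It assumes each worker holds a shard of shape (n_i, C) and that the per-channel sums, sums of squares, and sample counts are combined by summation, as an all-reduce would combine them. The helper names are hypothetical, and this is not Bagua's implementation, which operates on GPU tensors and communicates across processes.

```python
import math
from typing import List, Sequence, Tuple

Shard = Sequence[Sequence[float]]  # rows are samples, columns are channels
Partial = Tuple[List[float], List[float], int]


def local_partial_stats(shard: Shard) -> Partial:
    """Per-worker partial statistics: per-channel sum, sum of squares, and count."""
    num_channels = len(shard[0])
    sums = [sum(row[c] for row in shard) for c in range(num_channels)]
    sq_sums = [sum(row[c] ** 2 for row in shard) for c in range(num_channels)]
    return sums, sq_sums, len(shard)


def combine_partial_stats(parts: Sequence[Partial]) -> Tuple[List[float], List[float]]:
    """Sum the partials across workers (what an all-reduce would do), then derive
    the global per-channel mean and biased variance used for normalization."""
    num_channels = len(parts[0][0])
    total = sum(count for _, _, count in parts)
    mean = [sum(p[0][c] for p in parts) / total for c in range(num_channels)]
    var = [sum(p[1][c] for p in parts) / total - mean[c] ** 2 for c in range(num_channels)]
    return mean, var


def normalize(shard: Shard, mean: List[float], var: List[float], eps: float = 1e-5) -> List[List[float]]:
    """Normalize a local shard with the global (not local) statistics."""
    return [[(x - m) / math.sqrt(v + eps) for x, m, v in zip(row, mean, var)] for row in shard]


# Two simulated workers, each holding part of a batch with C = 2 channels.
worker_shards = [
    [[1.0, 10.0], [2.0, 20.0]],
    [[3.0, 30.0], [4.0, 40.0], [5.0, 50.0]],
]
partials = [local_partial_stats(s) for s in worker_shards]
global_mean, global_var = combine_partial_stats(partials)
normalized = [normalize(s, global_mean, global_var) for s in worker_shards]
print(global_mean, global_var)  # [3.0, 30.0] [2.0, 200.0]
```

Only the sums and counts need to be exchanged, so the communication volume per layer is proportional to C and independent of the batch size.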

Parameters
  • num_features – Number of channels \(C\) from the shape \((N, C, ...)\).

  • eps – A value added to the denominator for numerical stability. Default: 1e-5.

  • momentum – The value used for the running_mean and running_var computation. Can be set to None to use a cumulative moving average (i.e. a simple average). Default: 0.1.

  • affine – A boolean value that, when set to True, gives this module learnable affine parameters. Default: True.

  • track_running_stats – A boolean value that, when set to True, makes this module track the running mean and variance; when set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: True.
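The momentum parameter follows the semantics of PyTorch's BatchNorm, which this class extends: the running statistics are an exponential moving average with weight momentum on the newest batch, and momentum=None switches to a cumulative average with weight 1/num_batches_tracked. The sketch below shows that update for a single channel in plain Python. The function name is hypothetical, and it assumes, as in PyTorch, that the running variance is updated with the unbiased batch variance.

```python
from typing import Optional, Tuple


def update_running_stats(
    running_mean: float,
    running_var: float,
    num_batches_tracked: int,
    batch_mean: float,
    batch_var_unbiased: float,
    momentum: Optional[float] = 0.1,
) -> Tuple[float, float, int]:
    """One training-step update of the running statistics for a single channel."""
    num_batches_tracked += 1
    if momentum is None:
        # Cumulative moving average: every batch seen so far gets equal weight.
        factor = 1.0 / num_batches_tracked
    else:
        # Exponential moving average with weight `momentum` on the new batch.
        factor = momentum
    running_mean = (1.0 - factor) * running_mean + factor * batch_mean
    running_var = (1.0 - factor) * running_var + factor * batch_var_unbiased
    return running_mean, running_var, num_batches_tracked


# Running statistics start at mean 0 and variance 1, as in PyTorch.
state = (0.0, 1.0, 0)
for batch_mean in [2.0, 4.0, 6.0]:
    state = update_running_stats(*state, batch_mean=batch_mean, batch_var_unbiased=1.0, momentum=None)
print(state[0])  # cumulative average of 2.0, 4.0 and 6.0, i.e. about 4.0
```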

Note

Only GPU input tensors are supported in training mode.

classmethod convert_sync_batchnorm(module)

Helper function to convert all BatchNorm*D layers in the model to torch.nn.SyncBatchNorm layers.

Parameters

module (nn.Module) – Module containing one or more BatchNorm*D layers

Returns

The original module with the converted torch.nn.SyncBatchNorm layers. If the original module is a BatchNorm*D layer, a new torch.nn.SyncBatchNorm layer object will be returned instead.
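The conversion is a recursive walk over the module tree. The sketch below follows the pattern used by PyTorch's own torch.nn.SyncBatchNorm.convert_sync_batchnorm; it is not Bagua's code. The function convert_batchnorm and its target_cls argument are hypothetical, and target_cls is assumed to accept the same constructor arguments as torch.nn.BatchNorm*D. The affine parameters are shared rather than copied, so an optimizer built before the conversion still refers to the right tensors.

```python
import torch
from torch import nn


def convert_batchnorm(module: nn.Module, target_cls=nn.SyncBatchNorm) -> nn.Module:
    """Recursively replace every BatchNorm*D layer with an instance of `target_cls`."""
    output = module
    if isinstance(module, nn.modules.batchnorm._BatchNorm):
        # A BatchNorm layer itself: build a new layer with the same configuration.
        output = target_cls(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            # Reuse the same Parameter objects so existing optimizers still see them.
            with torch.no_grad():
                output.weight = module.weight
                output.bias = module.bias
        # Carry over the running-statistics buffers (None if not tracked).
        output.running_mean = module.running_mean
        output.running_var = module.running_var
        output.num_batches_tracked = module.num_batches_tracked
    # Recurse into children, replacing them in place on the (possibly new) output.
    for name, child in module.named_children():
        output.add_module(name, convert_batchnorm(child, target_cls))
    return output


model = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4), nn.ReLU())
original_bn = model[1]
converted = convert_batchnorm(model)
print(type(converted[1]).__name__)  # SyncBatchNorm
```

As the Returns description states, a container comes back as the same object with its BatchNorm children swapped, while a bare BatchNorm layer comes back as a new object.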

Note

This function must be called before the with_bagua method.

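Example::
>>> import torch
>>> import bagua.torch_api as bagua
>>> from bagua.torch_api.algorithms import gradient_allreduce
>>> from bagua.torch_api.contrib.sync_batchnorm import SyncBatchNorm
>>> torch.cuda.set_device(bagua.get_local_rank())
>>> bagua.init_process_group()
>>> D_in, H, D_out = 64, 100, 10  # example layer sizes
>>> # Network with nn.BatchNorm layer
>>> model = torch.nn.Sequential(
...      torch.nn.Linear(D_in, H),
...      torch.nn.BatchNorm1d(H),
...      torch.nn.ReLU(),
...      torch.nn.Linear(H, D_out),
...    ).cuda()
>>> optimizer = torch.optim.SGD(
...      model.parameters(),
...      lr=0.01,
...      momentum=0.9
...    )
>>> # Convert BatchNorm layers first, then call with_bagua
>>> sync_bn_model = SyncBatchNorm.convert_sync_batchnorm(model)
>>> bagua_model = sync_bn_model.with_bagua(
...      [optimizer], gradient_allreduce.GradientAllReduceAlgorithm()
...    )

forward(input)

bagua.torch_api.contrib.sync_batchnorm.unused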