bagua.torch_api.contrib.sync_batchnorm¶
Module Contents¶
- class bagua.torch_api.contrib.sync_batchnorm.SyncBatchNorm(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)¶
Bases: torch.nn.modules.batchnorm._BatchNorm
Applies synchronous batch normalization to a distributed module containing N-dimensional BatchNorm layer(s). See BatchNorm for more details.
- Parameters:
num_features – Number of channels \(C\) from the shape \((N, C, ...)\).
eps – A value added to the denominator for numerical stability. Default: 1e-5.
momentum – The value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1.
affine – A boolean value that when set to True, this module has learnable affine parameters. Default: True.
track_running_stats – A boolean value that when set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: True.
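As an illustration of how the momentum and track_running_stats parameters interact, the sketch below mirrors the standard BatchNorm running-statistics update rule (this is a plain-Python illustration, not Bagua's actual implementation; the helper name update_running_mean is ours):

```python
# Illustrative sketch of the running-mean update controlled by `momentum`
# (hypothetical helper; mirrors the documented BatchNorm semantics).
def update_running_mean(running, batch_mean, momentum, num_batches_tracked):
    """num_batches_tracked counts batches seen so far, including this one."""
    if momentum is None:
        # Cumulative moving average (simple average over all batches).
        return running + (batch_mean - running) / num_batches_tracked
    # Exponential moving average with the given momentum.
    return (1.0 - momentum) * running + momentum * batch_mean

# With the default momentum=0.1, a batch mean of 1.0 moves a running
# mean of 0.0 to 0.1; with momentum=None, it becomes the plain average.
```

When track_running_stats is False, no such update happens at all and batch statistics are used directly in both training and eval modes.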
Note
Only GPU input tensors are supported in training mode.
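The "synchronous" part refers to computing batch statistics over the global batch across all workers rather than each worker's local batch. A minimal pure-Python sketch of that aggregation, assuming equal per-worker batch sizes (the real layer all-reduces these sums across GPUs; sync_batch_stats is an illustrative name, not a Bagua API):

```python
# Sketch of the cross-worker statistics aggregation behind synchronous
# BatchNorm (hypothetical helper; real implementations use an all-reduce).
def sync_batch_stats(per_worker_batches):
    # Each worker contributes its local count, sum, and sum of squares;
    # summing them is the "all-reduce" step.
    total_n = sum(len(b) for b in per_worker_batches)
    total_sum = sum(sum(b) for b in per_worker_batches)
    total_sq = sum(sum(x * x for x in b) for b in per_worker_batches)
    mean = total_sum / total_n
    # Biased (population) variance via E[x^2] - E[x]^2.
    var = total_sq / total_n - mean * mean
    return mean, var
```

This yields the same statistics as if the whole global batch had been normalized on a single device, which is why SyncBatchNorm helps when per-GPU batch sizes are small.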
- classmethod convert_sync_batchnorm(module)¶
Helper function to convert all BatchNorm*D layers in the model to torch.nn.SyncBatchNorm layers.
- Parameters:
module (nn.Module) – Module containing one or more BatchNorm*D layers.
- Returns:
The original module with the converted torch.nn.SyncBatchNorm layers. If the original module is a BatchNorm*D layer, a new torch.nn.SyncBatchNorm layer object will be returned instead.
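The behavior described above (replace each BatchNorm leaf, recurse into containers, and return a new object when the root itself is a BatchNorm layer) follows a standard recursive-replacement pattern. A self-contained sketch with stand-in classes (Module, BatchNorm, SyncNorm, and convert are all hypothetical stand-ins, not torch or Bagua types):

```python
# Minimal stand-in module tree to illustrate the recursive conversion
# pattern (hypothetical classes; not torch.nn.Module).
class Module:
    def __init__(self, **children):
        self.children = children

class BatchNorm(Module):
    pass

class SyncNorm(Module):
    pass

def convert(module):
    """Recursively replace every BatchNorm leaf with a SyncNorm layer."""
    if isinstance(module, BatchNorm):
        # A BatchNorm root is replaced outright: a new object is returned.
        return SyncNorm()
    for name, child in module.children.items():
        module.children[name] = convert(child)
    return module
```

Note that non-BatchNorm containers are mutated in place while a BatchNorm root is swapped for a fresh object, matching the two cases in the Returns description.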
Note
This function must be called before the with_bagua method.
- Example::
>>> # Network with nn.BatchNorm layer
>>> model = torch.nn.Sequential(
...     torch.nn.Linear(D_in, H),
...     torch.nn.ReLU(),
...     torch.nn.Linear(H, D_out),
... )
>>> optimizer = torch.optim.SGD(
...     model.parameters(),
...     lr=0.01,
...     momentum=0.9
... )
>>> sync_bn_model = bagua.torch_api.contrib.sync_batchnorm.SyncBatchNorm.convert_sync_batchnorm(model)
>>> bagua_model = sync_bn_model.with_bagua([optimizer], GradientAllReduce())
- forward(input)¶