bagua.torch_api.checkpoint¶
Package Contents¶
- bagua.torch_api.checkpoint.load_checkpoint(checkpoints_path, model, optimizer=None, lr_scheduler=None, strict=True)¶
Load a model checkpoint and return the iteration.
- Parameters:
  - checkpoints_path (str) – Path of the checkpoint directory. Default: None.
  - model (BaguaModule) – The model to load the checkpoint into.
  - optimizer (torch.optim.Optimizer, optional) – The optimizer to load the checkpoint into. Default: None.
  - lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional) – The LR scheduler to load the checkpoint into. Default: None.
  - strict (bool, optional) – Whether to strictly enforce that the keys in the checkpoint's state_dict match the keys returned by the model's state_dict() function. Default: True.
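To clarify what the strict flag controls, here is a minimal stdlib-only sketch of strict key matching between a checkpoint's state dict and a model's state dict. This mirrors the semantics of PyTorch's load_state_dict (strict=True rejects missing or unexpected keys; strict=False loads only the overlap); it is an illustration, not bagua's internal implementation.

```python
def load_state_dict(model_state, checkpoint_state, strict=True):
    """Copy checkpoint entries into model_state, enforcing key matching.

    With strict=True, the two key sets must be identical; otherwise a
    KeyError reports the missing and unexpected keys. With strict=False,
    only keys present in both dicts are loaded.
    """
    missing = set(model_state) - set(checkpoint_state)
    unexpected = set(checkpoint_state) - set(model_state)
    if strict and (missing or unexpected):
        raise KeyError(
            f"missing keys: {sorted(missing)}, unexpected keys: {sorted(unexpected)}"
        )
    for key in model_state.keys() & checkpoint_state.keys():
        model_state[key] = checkpoint_state[key]
    return model_state
```

With strict=False a checkpoint saved from a slightly different architecture can still be partially loaded, which is the usual reason to relax the check.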
- bagua.torch_api.checkpoint.save_checkpoint(iteration, checkpoints_path, model, optimizer=None, lr_scheduler=None)¶
Save a model checkpoint.
- Parameters:
  - iteration (int) – Training iteration.
  - checkpoints_path (str) – Path of the checkpoint directory.
  - model (BaguaModule) – The model to save.
  - optimizer (torch.optim.Optimizer, optional) – The optimizer to save. Default: None.
  - lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional) – The LR scheduler to save. Default: None.
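The save/load pair forms a round trip: save_checkpoint writes the training state tagged with an iteration number, and load_checkpoint restores the latest state and returns that iteration so training can resume. The sketch below illustrates this contract with plain dicts and pickle; the file naming scheme (iter_*.pkl) and the use of pickle are assumptions for illustration, not bagua's on-disk format.

```python
import pickle
from pathlib import Path


def save_checkpoint(iteration, checkpoints_path, model_state,
                    optimizer_state=None, lr_scheduler_state=None):
    """Write one checkpoint file named by its zero-padded iteration."""
    path = Path(checkpoints_path)
    path.mkdir(parents=True, exist_ok=True)
    payload = {
        "iteration": iteration,
        "model": model_state,
        "optimizer": optimizer_state,
        "lr_scheduler": lr_scheduler_state,
    }
    with open(path / f"iter_{iteration:07d}.pkl", "wb") as f:
        pickle.dump(payload, f)


def load_checkpoint(checkpoints_path, model_state):
    """Restore the most recent checkpoint and return its iteration."""
    files = sorted(Path(checkpoints_path).glob("iter_*.pkl"))
    if not files:
        return 0  # no checkpoint yet; start from iteration 0
    with open(files[-1], "rb") as f:  # zero-padding keeps sort order numeric
        payload = pickle.load(f)
    model_state.update(payload["model"])
    return payload["iteration"]
```

In the real API the iteration returned by load_checkpoint is what lets a training loop pick up exactly where the last save_checkpoint call left off.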