bagua.torch_api.checkpoint.checkpointing

Module Contents

bagua.torch_api.checkpoint.checkpointing.load_checkpoint(checkpoints_path, model, optimizer=None, lr_scheduler=None, strict=True)

Load a model checkpoint and return the iteration at which it was saved.

Parameters:
  • checkpoints_path (str) – Path of checkpoints.

  • model (BaguaModule) – The model to load the checkpoint into.

  • optimizer (torch.optim.Optimizer, optional) – The optimizer to load state into. Default: None.

  • lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional) – The LR scheduler to load state into. Default: None.

  • strict (bool, optional) – Whether to strictly enforce that the keys in the checkpoint’s state_dict match the keys returned by this module’s state_dict() function. Default: True.
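The strict flag follows the usual PyTorch load_state_dict key-matching semantics. A minimal pure-Python sketch of that rule, assuming plain string keys; the function name check_strict_load is illustrative and not part of Bagua's API:

```python
def check_strict_load(module_keys, checkpoint_keys, strict=True):
    """Sketch of strict=True semantics: the checkpoint must contain
    exactly the keys the module's state_dict() would return."""
    missing = sorted(set(module_keys) - set(checkpoint_keys))
    unexpected = sorted(set(checkpoint_keys) - set(module_keys))
    if strict and (missing or unexpected):
        raise KeyError(
            f"missing keys: {missing}; unexpected keys: {unexpected}"
        )
    # With strict=False, mismatches are tolerated and simply reported.
    return missing, unexpected

# A key absent from the checkpoint is tolerated when strict=False:
missing, unexpected = check_strict_load(
    ["layer.weight", "layer.bias"],
    ["layer.weight"],
    strict=False,
)
print(missing)  # ['layer.bias']
```

With strict=True, the same mismatch raises a KeyError instead of being returned.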

bagua.torch_api.checkpoint.checkpointing.save_checkpoint(iteration, checkpoints_path, model, optimizer=None, lr_scheduler=None)

Save a model checkpoint.

Parameters:
  • iteration (int) – Training iteration.

  • checkpoints_path (str) – Path of checkpoints.

  • model (BaguaModule) – The model to save.

  • optimizer (torch.optim.Optimizer, optional) – The optimizer to save. Default: None.

  • lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional) – The LR scheduler to save. Default: None.
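A checkpoint bundles the training iteration with the state of each component that was provided, which is what lets load_checkpoint return the iteration later. A self-contained sketch of that round trip, assuming a pickle-based on-disk format and illustrative helper names (the real functions operate on BaguaModule, optimizer, and scheduler objects, and their file layout may differ):

```python
import os
import pickle
import tempfile


def save_checkpoint_sketch(iteration, checkpoints_path, model_state,
                           optimizer_state=None, lr_scheduler_state=None):
    """Bundle the iteration with whichever states were provided,
    then write the payload under checkpoints_path."""
    payload = {"iteration": iteration, "model": model_state}
    if optimizer_state is not None:
        payload["optimizer"] = optimizer_state
    if lr_scheduler_state is not None:
        payload["lr_scheduler"] = lr_scheduler_state
    path = os.path.join(checkpoints_path, f"iter_{iteration}.ckpt")
    with open(path, "wb") as f:
        pickle.dump(payload, f)
    return path


def load_checkpoint_sketch(path):
    """Mirror of load_checkpoint: restore the payload and
    return the iteration it was saved at."""
    with open(path, "rb") as f:
        payload = pickle.load(f)
    return payload["iteration"], payload


with tempfile.TemporaryDirectory() as ckpt_dir:
    path = save_checkpoint_sketch(100, ckpt_dir, {"w": [0.1, 0.2]})
    iteration, payload = load_checkpoint_sketch(path)
    print(iteration)  # 100
```

Omitting optimizer_state and lr_scheduler_state mirrors the Default: None parameters above: those keys are simply absent from the saved payload.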