diff --git a/train.py b/train.py index 1eda27a4..c0690d73 100644 --- a/train.py +++ b/train.py @@ -448,7 +448,7 @@ def burnin_schedule(i): pass save_path = os.path.join(config.checkpoints, f'{save_prefix}{epoch + 1}.pth') if isinstance(model, torch.nn.DataParallel): - torch.save(model.moduel,state_dict(), save_path) + torch.save(model.module.state_dict(), save_path) else: torch.save(model.state_dict(), save_path) logging.info(f'Checkpoint {epoch + 1} saved !')