From 96fd12bb9601128615de18fe0041a3d742808a5c Mon Sep 17 00:00:00 2001
From: Abu Bakr
Date: Fri, 17 Dec 2021 06:33:55 +0200
Subject: [PATCH] add a --save_model_to parameter to save the model to a
 specific path
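
train.py hard-codes './saved_models' as the output directory for its
logs and checkpoints. Add a '--save_model_to' argument and use it in
every place that path is built, so a run can write its logs and .pth
files anywhere. The default './saved_models' keeps the old behaviour
when the flag is omitted; os.makedirs(..., exist_ok=True) still creates
the target directory if it does not exist.

Example invocation (the dataset paths and output directory below are
illustrative only, not part of this patch):

    python3 train.py --train_data data/train --valid_data data/valid \
        --Transformation TPS --FeatureExtraction ResNet \
        --SequenceModeling BiLSTM --Prediction Attn \
        --save_model_to /mnt/experiments/ocr_runs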
---
 train.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/train.py b/train.py
index 7f3a66bf00..edf5855eb2 100644
--- a/train.py
+++ b/train.py
@@ -30,7 +30,7 @@ def train(opt):
     opt.batch_ratio = opt.batch_ratio.split('-')
     train_dataset = Batch_Balanced_Dataset(opt)
 
-    log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
+    log = open(f'{opt.save_model_to}/{opt.exp_name}/log_dataset.txt', 'a')
     AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
     valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
     valid_loader = torch.utils.data.DataLoader(
@@ -119,7 +119,7 @@ def train(opt):
 
     """ final options """
     # print(opt)
-    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
+    with open(f'{opt.save_model_to}/{opt.exp_name}/opt.txt', 'a') as opt_file:
         opt_log = '------------ Options -------------\n'
         args = vars(opt)
         for k, v in args.items():
@@ -175,7 +175,7 @@ def train(opt):
         if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' 
             elapsed_time = time.time() - start_time
             # for log
-            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
+            with open(f'{opt.save_model_to}/{opt.exp_name}/log_train.txt', 'a') as log:
                 model.eval()
                 with torch.no_grad():
                     valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
@@ -191,10 +191,10 @@ def train(opt):
                 # keep best accuracy model (on valid dataset)
                 if current_accuracy > best_accuracy:
                     best_accuracy = current_accuracy
-                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth')
+                    torch.save(model.state_dict(), f'{opt.save_model_to}/{opt.exp_name}/best_accuracy.pth')
                 if current_norm_ED > best_norm_ED:
                     best_norm_ED = current_norm_ED
-                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
+                    torch.save(model.state_dict(), f'{opt.save_model_to}/{opt.exp_name}/best_norm_ED.pth')
                 best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'
 
                 loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
@@ -218,7 +218,7 @@ def train(opt):
         # save model per 1e+5 iter.
         if (iteration + 1) % 1e+5 == 0:
             torch.save(
-                model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')
+                model.state_dict(), f'{opt.save_model_to}/{opt.exp_name}/iter_{iteration+1}.pth')
 
         if (iteration + 1) == opt.num_iter:
             print('end the training')
@@ -237,6 +237,7 @@ def train(opt):
     parser.add_argument('--num_iter', type=int, default=300000, help='number of iterations to train for')
     parser.add_argument('--valInterval', type=int, default=2000, help='Interval between each validation')
     parser.add_argument('--saved_model', default='', help="path to model to continue training")
+    parser.add_argument('--save_model_to', default='./saved_models', help="directory in which to save logs and model checkpoints")
     parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning')
     parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is Adadelta)')
     parser.add_argument('--lr', type=float, default=1, help='learning rate, default=1.0 for Adadelta')
@@ -281,7 +282,7 @@ def train(opt):
         opt.exp_name += f'-Seed{opt.manualSeed}'
         # print(opt.exp_name)
 
-    os.makedirs(f'./saved_models/{opt.exp_name}', exist_ok=True)
+    os.makedirs(f'{opt.save_model_to}/{opt.exp_name}', exist_ok=True)
 
     """ vocab / character number configuration """
     if opt.sensitive: