diff --git a/Makefile b/Makefile
index 99bd1f2..d2f5a68 100644
--- a/Makefile
+++ b/Makefile
@@ -11,6 +11,12 @@ build-all: .require-path
 run: .require-command
	bash -c "`cat ${ROOT}/${command}`"
 
+results: .require-path
+	@for p in $(shell ls ${path});do \
+	echo "$$p-> `grep Validation ${path}/$$p/instance.log 2> /dev/null|tail -1`" | \
+	sed -e "s/\\(.*\\)->.*(\\([0-9.]*\\) -->.*).*/\1,\2/"; \
+	done
+
 .require-config:
 ifndef config
	$(error config is required)
diff --git a/experiments/base.py b/experiments/base.py
index 45d5b8c..802c51c 100644
--- a/experiments/base.py
+++ b/experiments/base.py
@@ -21,6 +21,7 @@ class Experiment(ABC):
     def __init__(self, config_path: str):
         self.config_path = config_path
         self.root = Path(config_path).parent
+        self.isEval = False
         gin.parse_config_file(self.config_path)
 
     @gin.configurable()
@@ -81,8 +82,10 @@ def run(self, timer: Optional[int] = 0):
         time.sleep(random.uniform(0, timer))
         running_flag = os.path.join(self.root, '_RUNNING')
         success_flag = os.path.join(self.root, '_SUCCESS')
-        if os.path.isfile(success_flag) or os.path.isfile(running_flag):
+        if os.path.isfile(running_flag):
             return
+        elif os.path.isfile(success_flag):
+            self.isEval = True
         elif not os.path.isfile(running_flag):
             Path(running_flag).touch()
 
@@ -96,7 +99,7 @@ def run(self, timer: Optional[int] = 0):
             raise Exception('KeyboardInterrupt')
 
         # mark experiment as finished.
-        Path(running_flag).unlink()
+        Path(running_flag).unlink() if os.path.isfile(running_flag) else None
         Path(success_flag).touch()
 
     def build_experiment(self):
diff --git a/experiments/forecast.py b/experiments/forecast.py
index e08685a..adf0a14 100644
--- a/experiments/forecast.py
+++ b/experiments/forecast.py
@@ -36,8 +36,12 @@ def instance(self,
                           datetime_feats=train_set.timestamps.shape[-1]).to(default_device())
         checkpoint = Checkpoint(self.root)
 
-        # train forecasting task
-        model = train(model, checkpoint, train_loader, val_loader, test_loader)
+        if self.isEval:
+            model.load_state_dict(torch.load(checkpoint.model_path))
+            model.eval()
+        else:
+            # train forecasting task
+            model = train(model, checkpoint, train_loader, val_loader, test_loader)
 
         # testing
         val_metrics = validate(model, loader=val_loader, report_metrics=True)
@@ -47,8 +51,17 @@ def instance(self,
         val_metrics = {f'ValMetric/{k}': v for k, v in val_metrics.items()}
         test_metrics = {f'TestMetric/{k}': v for k, v in test_metrics.items()}
 
+        logging.info(f"Validation test set loss from previous model ({test_metrics['TestMetric/mse']} --> {val_metrics['ValMetric/mse']}) test vs val mse.")
         checkpoint.close({**val_metrics, **test_metrics})
 
+    def load_model(self):
+        train_set, _ = get_data(flag='train')
+        model = get_model('deeptime', datetime_feats=train_set.timestamps.shape[-1]).to(default_device())
+        checkpoint = Checkpoint(self.root)
+        model.load_state_dict(torch.load(checkpoint.model_path))
+        model.eval()
+        self.isEval = True
+        return model
 
 @gin.configurable()
 def get_optimizer(model: nn.Module,