diff --git a/main.py b/main.py index 9578208..12e39a7 100644 --- a/main.py +++ b/main.py @@ -66,6 +66,10 @@ parser.add_argument('--num_freq_disp', type=int, default=10, dest='num_freq_disp') parser.add_argument('--num_freq_save', type=int, default=50, dest='num_freq_save') +parser.add_argument('--size_window_x', type=int, default=5, dest='size_window_x') + +parser.add_argument('--size_window_y', type=int, default=5, dest='size_window_y') + PARSER = Parser(parser) def main(): diff --git a/train.py b/train.py index 78c34c0..10e5894 100644 --- a/train.py +++ b/train.py @@ -51,6 +51,9 @@ def __init__(self, args): self.num_freq_disp = args.num_freq_disp self.num_freq_save = args.num_freq_save + self.size_window_x = args.size_window_x + self.size_window_y = args.size_window_y + self.gpu_ids = args.gpu_ids if self.gpu_ids and torch.cuda.is_available(): @@ -114,7 +117,7 @@ def train(self): nch_ker = self.nch_ker size_data = (self.ny_in, self.nx_in, self.nch_in) - size_window = (5, 5) + size_window = (self.size_window_x, self.size_window_y) norm = self.norm name_data = self.name_data @@ -329,7 +332,7 @@ def test(self): nch_ker = self.nch_ker size_data = (self.ny_in, self.nx_in, self.nch_in) - size_window = (5, 5) + size_window = (self.size_window_x, self.size_window_y) norm = self.norm