Skip to content

Commit e3ef192

Browse files
gruebelvfdev-5
andauthored
Add None check for max_epochs (#1519)
* Add None check for max_epochs * Small CR fix Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent cfea1e4 commit e3ef192

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

ignite/contrib/handlers/param_scheduler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None
6565
if isinstance(value, list):
6666
if len(value) != len(self.optimizer_param_groups):
6767
raise ValueError(
68-
f"size of value is different than optimizer_param_groups {len(value)} != {len(self.optimizer_param_groups)}"
68+
"size of value is different than optimizer_param_groups "
69+
f"{len(value)} != {len(self.optimizer_param_groups)}"
6970
)
7071

7172
for i, param_group in enumerate(self.optimizer_param_groups):

ignite/engine/engine.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,8 @@ def _is_done(state: State) -> bool:
561561
is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters
562562
is_done_count = (
563563
state.epoch_length is not None
564-
and state.iteration >= state.epoch_length * state.max_epochs # type: ignore[operator]
564+
and state.max_epochs is not None
565+
and state.iteration >= state.epoch_length * state.max_epochs
565566
)
566567
is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs
567568
return is_done_iters or is_done_count or is_done_epochs
@@ -833,12 +834,17 @@ def _run_once_on_dataset(self) -> float:
833834
# Should exit while loop if we can not iterate
834835
if should_exit:
835836
if not self._is_done(self.state):
837+
total_iters = (
838+
self.state.epoch_length * self.state.max_epochs
839+
if self.state.max_epochs is not None
840+
else self.state.max_iters
841+
)
842+
836843
warnings.warn(
837844
"Data iterator can not provide data anymore but required total number of "
838845
"iterations to run is not reached. "
839846
"Current iteration: {} vs Total iterations to run : {}".format(
840-
self.state.iteration,
841-
self.state.epoch_length * self.state.max_epochs, # type: ignore[operator]
847+
self.state.iteration, total_iters,
842848
)
843849
)
844850
break

0 commit comments

Comments
 (0)