Inference inside callback #19752
Closed
quentinblampey
started this conversation in
General
Replies: 1 comment
-
|
class PlotCallback(Callback):
    """Run inference on the predict dataloader at the end of every training epoch.

    Workaround for calling ``trainer.predict`` from inside a ``fit`` hook.
    Doing so breaks the surrounding fit loop: when control returns to it,
    ``_LoggerConnector.on_epoch_end`` fails on
    ``assert self.trainer._results is not None``. This callback therefore
    drives the LightningModule's own predict hooks manually.
    """

    def on_train_epoch_end(self, trainer: L.Trainer, model: Model) -> None:
        """Collect predictions for plotting, leaving the model as it was found.

        Args:
            trainer: The running trainer (unused; kept for the hook signature).
            model: The LightningModule being fitted; it must provide
                ``predict_dataloader`` and ``predict_step``.
        """
        import torch  # local import: the snippet has no module-level import block

        loader = model.predict_dataloader()
        was_training = model.training
        # eval mode: layers such as Dropout/BatchNorm (if any) behave as at
        # inference time, so plotting does not perturb training state.
        model.eval()
        predictions = []
        try:
            with torch.no_grad():  # plotting needs no autograd graph
                for batch_idx, batch in enumerate(loader):
                    batch = model.transfer_batch_to_device(batch, model.device, 0)
                    # Pass batch_idx: the documented predict_step(batch, batch_idx, ...)
                    # form requires it.
                    predictions.append(model.predict_step(batch, batch_idx))
        finally:
            # Restore whatever mode the fit loop had set, even if inference fails.
            model.train(was_training)
        ...  # save figure to wandb using `predictions`
Beta Was this translation helpful? Give feedback.
0 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Uh oh!
There was an error while loading. Please reload this page.
Uh oh!
There was an error while loading. Please reload this page.
-
Hello,
I have a weird issue when running inference inside a callback. I guess I'm not using PyTorch Lightning the intended way.
After each epoch, inside a callback, I run
`trainer.predict` on a small dataset to plot some figures that I save with Weights & Biases. I thought this was pretty standard, but I got the error below. I wrote this minimal (dummy) reproducible example:
Error log
--------------------------------------------------------------------------- AssertionError Traceback (most recent call last) Cell In[3], [line 5] [2] callback = PlotCallback() [3] trainer = L.Trainer(max_epochs=2, callbacks=callback, accelerator="cpu") ----> [5] trainer.fit(model) File [~/Library/Caches/pypoetry/virtualenvs/novae-ezkWKrh6-py3.9/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:544](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:544), in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) [542](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:542) self.state.status = TrainerStatus.RUNNING [543](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:543) self.training = True --> [544](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:544) call._call_and_handle_interrupt( [545](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:545) self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path [546](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:546) ) File [~/Library/Caches/pypoetry/virtualenvs/novae-ezkWKrh6-py3.9/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:44](XXX/python3.9/site-packages/lightning/pytorch/trainer/call.py:44), in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs) [42](XXX/python3.9/site-packages/lightning/pytorch/trainer/call.py:42) if trainer.strategy.launcher is not None: [43](XXX/python3.9/site-packages/lightning/pytorch/trainer/call.py:43) return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs) ---> [44](XXX/python3.9/site-packages/lightning/pytorch/trainer/call.py:44) return trainer_fn(*args, **kwargs) [46](XXX/python3.9/site-packages/lightning/pytorch/trainer/call.py:46) except _TunerExitException: [47](XXX/python3.9/site-packages/lightning/pytorch/trainer/call.py:47) _call_teardown_hook(trainer) File 
[~/Library/Caches/pypoetry/virtualenvs/novae-ezkWKrh6-py3.9/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:580](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:580), in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) [573](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:573) assert self.state.fn is not None [574](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:574) ckpt_path = self._checkpoint_connector._select_ckpt_path( [575](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:575) self.state.fn, [576](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:576) ckpt_path, [577](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:577) model_provided=True, [578](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:578) model_connected=self.lightning_module is not None, [579](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:579) ) --> [580](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:580) self._run(model, ckpt_path=ckpt_path) [582](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:582) assert self.state.stopped [583](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:583) self.training = False File [~/Library/Caches/pypoetry/virtualenvs/novae-ezkWKrh6-py3.9/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:987](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:987), in Trainer._run(self, model, ckpt_path) [982](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:982) self._signal_connector.register_signal_handlers() [984](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:984) # ---------------------------- [985](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:985) # RUN THE TRAINER [986](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:986) # 
---------------------------- --> [987](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:987) results = self._run_stage() [989](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:989) # ---------------------------- [990](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:990) # POST-Training CLEAN UP [991](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:991) # ---------------------------- [992](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:992) log.debug(f"{self.__class__.__name__}: trainer tearing down") File [~/Library/Caches/pypoetry/virtualenvs/novae-ezkWKrh6-py3.9/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1033](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1033), in Trainer._run_stage(self) [1031](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1031) self._run_sanity_check() [1032](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1032) with torch.autograd.set_detect_anomaly(self._detect_anomaly): -> [1033](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1033) self.fit_loop.run() [1034](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1034) return None [1035](XXX/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1035) raise RuntimeError(f"Unexpected state {self.state}") File [~/Library/Caches/pypoetry/virtualenvs/novae-ezkWKrh6-py3.9/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:206](XXX/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:206), in _FitLoop.run(self) [204](XXX/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:204) self.on_advance_start() [205](XXX/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:205) self.advance() --> [206](XXX/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:206) self.on_advance_end() [207](XXX/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:207) self._restarting 
= False [208](XXX/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:208) except StopIteration: File [~/Library/Caches/pypoetry/virtualenvs/novae-ezkWKrh6-py3.9/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:380](XXX/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:380), in _FitLoop.on_advance_end(self) [377](XXX/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:377) call._call_lightning_module_hook(trainer, "on_train_epoch_end") [378](XXX/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:378) call._call_callback_hooks(trainer, "on_train_epoch_end", monitoring_callbacks=True) --> [380](XXX/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:380) trainer._logger_connector.on_epoch_end() [382](XXX/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:382) if self.epoch_loop._num_ready_batches_reached(): [383](XXX/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:383) # if we are restarting and the above condition holds, it's because we are reloading an epoch-end checkpoint. 
[384](XXX/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:384) # since metric-based schedulers require access to metrics and those are not currently saved in the [385](XXX/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:385) # checkpoint, the plateau schedulers shouldn't be updated [386](XXX/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:386) self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=not self.restarting) File [~/Library/Caches/pypoetry/virtualenvs/novae-ezkWKrh6-py3.9/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:195](XXX/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:195), in _LoggerConnector.on_epoch_end(self) [193](XXX/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:193) def on_epoch_end(self) -> None: [194](XXX/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:194) assert self._first_loop_iter is None --> [195](XXX/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:195) metrics = self.metrics [196](XXX/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:196) self._progress_bar_metrics.update(metrics["pbar"]) [197](XXX/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:197) self._callback_metrics.update(metrics["callback"]) File [~/Library/Caches/pypoetry/virtualenvs/novae-ezkWKrh6-py3.9/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:233](XXX/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:233), in _LoggerConnector.metrics(self) [231](XXX/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:231) """This 
function returns either batch or epoch metrics.""" [232](XXX/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:232) on_step = self._first_loop_iter is not None --> [233](XXX/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:233) assert self.trainer._results is not None [234](XXX/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:234) return self.trainer._results.metrics(on_step) AssertionError:
Versions:
Beta Was this translation helpful? Give feedback.
All reactions