@@ -104,6 +104,7 @@ def training_step(self, data, batch_idx):
         loss = model_out.pop('loss')
         loss = torch.topk(loss, k=int(loss.shape[0] * (1 - self.opt.drop_worst_rate)), largest=False)[0].mean()

+        # Prepare for logging info
         data_time = self.trainer.profiler.recorded_durations["get_train_batch"][-1]
         data_time = torch.tensor(data_time)

@@ -117,13 +118,11 @@ def training_step(self, data, batch_idx):
         logger_logs['training_loss'] = loss
         logger_logs['data_time'] = data_time

-        output = {
-            'loss': loss,
-            'log': logger_logs,
-            'progress_bar': {'data_time': data_time}
-        }
+        for k, v in logger_logs.items():
+            self.log(k, v, on_epoch=(k == 'training_loss'), prog_bar=(k == 'data_time'))
+        # logged

-        return output
+        return loss

     def validation_step(self, data, batch_idx):
         model = self.model
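The hunk above is the PyTorch Lightning >= 1.0 logging migration: instead of returning a dict with 'log' and 'progress_bar' keys, metrics go through `self.log()` and `training_step` returns only the loss. A minimal, self-contained sketch of the same pattern (the module, data, and metric values are illustrative, not the project's own):

    import torch
    import pytorch_lightning as pl

    class SketchModule(pl.LightningModule):
        """Toy module showing the self.log pattern used in the diff."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            data_time = torch.tensor(0.0)  # stand-in for the real timing metric
            # Replaces returning {'loss': ..., 'log': ..., 'progress_bar': ...}
            self.log('training_loss', loss, on_epoch=True)   # also aggregated per epoch
            self.log('data_time', data_time, prog_bar=True)  # shown on the progress bar
            return loss  # only the loss is returned now

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)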
@@ -202,7 +201,7 @@ def validation_step(self, data, batch_idx):
             fc_feats, att_feats, att_masks, data], eval_kwargs)

         output = {
-            'val_loss': loss,
+            'loss': loss,
             'predictions': predictions,
             'n_predictions': n_predictions,
         }
@@ -211,15 +210,15 @@ def validation_step(self, data, batch_idx):
     def test_step(self, *args, **kwargs):
         return self.validation_step(*args, **kwargs)

-    def validation_epoch_end(self, outputs, split='val'):
+    def split_epoch_end(self, outputs, split='val'):
         outputs = d2comm.gather(outputs)
         # master node
         if d2comm.is_main_process():
             assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0
             outputs = sum(outputs, [])

             opt = self.opt
-            val_loss_mean = sum([_['val_loss']
+            loss_mean = sum([_['loss'].item()
                                  for _ in outputs]) / len(outputs)

             predictions = sum([_['predictions'] for _ in outputs], [])
@@ -247,13 +246,13 @@ def validation_epoch_end(self, outputs, split='val'):
                 if 'CIDEr' in lang_stats:
                     optimizer.scheduler_step(-lang_stats['CIDEr'])
                 else:
-                    optimizer.scheduler_step(val_loss_mean)
+                    optimizer.scheduler_step(loss_mean)

             out = {
-                'val_loss': val_loss_mean
+                'loss': loss_mean
             }
             out.update(lang_stats)
-            out['to_monitor'] = lang_stats['CIDEr'] if lang_stats is not None else -val_loss_mean
+            out['to_monitor'] = lang_stats['CIDEr'] if lang_stats is not None else -loss_mean
         else:
             out = {}

@@ -263,23 +262,25 @@ def validation_epoch_end(self, outputs, split='val'):
         # must all be tensors
         out = {k: torch.tensor(v) if not torch.is_tensor(
             v) else v for k, v in out.items()}
-        return {
-            'progress_bar': {'val_loss': out['val_loss']},
-            'log': out,
-        }
+
+        return out
+
+    def validation_epoch_end(self, outputs):
+        out = self.split_epoch_end(outputs, 'val')
+        out['val_loss'] = out.pop('loss')
+        for k, v in out.items():
+            self.log(k, v)
+        return out

     def test_epoch_end(self, outputs):
-        out = self.validation_epoch_end(outputs, 'test')
-        out['progress_bar'] = {
-            'test_loss': out['progress_bar']['val_loss']
-        }
-        out['log']['test_loss'] = out['log']['val_loss']
-        del out['log']['val_loss']
-        del out['log']['to_monitor']
-        out['log'] = {'test_' + k if 'test' not in k else k: v \
-                      for k, v in out['log'].items()}
+        out = self.split_epoch_end(outputs, 'test')
+        out['test_loss'] = out.pop('loss')
+        out = {'test_' + k if 'test' not in k else k: v
+               for k, v in out.items()}
+        for k, v in out.items():
+            self.log(k, v)
         return out
-
+
     def configure_optimizers(self):
         opt = self.opt
         model = self.model
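For context on the epoch-end refactor above: the shared `split_epoch_end` first gathers every rank's list of step outputs onto the main process and only then averages, with the per-step loss tensors reduced to Python floats via `.item()`. A rough sketch of that aggregation, assuming `d2comm` is `detectron2.utils.comm` (which is what the `gather`/`is_main_process` calls suggest):

    from detectron2.utils import comm as d2comm  # assumed alias behind d2comm

    def aggregate_losses(step_outputs):
        """step_outputs: this rank's list of dicts, e.g. [{'loss': tensor(0.7)}, ...]."""
        gathered = d2comm.gather(step_outputs)   # list of per-rank lists on rank 0, [] elsewhere
        if not d2comm.is_main_process():
            return None                          # non-master ranks skip the reduction
        flat = sum(gathered, [])                 # flatten [[rank0 steps], [rank1 steps], ...]
        return sum(o['loss'].item() for o in flat) / len(flat)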
@@ -309,11 +310,11 @@ def optimizer_step(self, epoch, batch_idx, optimizer,
         super().optimizer_step(epoch, batch_idx, optimizer,
                                optimizer_idx, *args, **kwargs)

-    def state_dict(self):
+    def state_dict(self, *args, **kwargs):
         """
         Save the model state dict as well as opt and vocab
         """
-        state_dict = self.model.state_dict()
+        state_dict = self.model.state_dict(*args, **kwargs)
         device = next(iter(state_dict.values())).device
         assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case'
         state_dict.update({
@@ -345,10 +346,16 @@ def load_state_dict(self, state_dict=None, strict=True):
             raise KeyError
         self.model.load_state_dict(state_dict, strict)

+    def get_progress_bar_dict(self):
+        # don't show the version number
+        items = super().get_progress_bar_dict()
+        items.pop("v_num", None)
+        return items
+

 class OnEpochStartCallback(pl.Callback):

-    def on_epoch_start(self, trainer, pl_module):
+    def on_train_epoch_start(self, trainer, pl_module):
         # Update lr/training stage/scheduled sampling prob etc.
         opt = pl_module.opt
         model = pl_module.model
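The callback above now implements `on_train_epoch_start`, the training-specific hook in newer Lightning versions. A minimal sketch of a callback using the same hook, with a purely illustrative body:

    import pytorch_lightning as pl

    class EpochBanner(pl.Callback):
        """Illustrative callback; reports where training is at each epoch start."""

        def on_train_epoch_start(self, trainer, pl_module):
            # Fires once per training epoch, before the first batch of the epoch.
            print(f"epoch {trainer.current_epoch} starting at global step {trainer.global_step}")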
@@ -402,21 +409,32 @@ class ModelCheckpoint(pl.callbacks.ModelCheckpoint):

     def on_keyboard_interrupt(self, trainer, pl_module):
         # Save model when keyboard interrupt
-        filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
-        self._save_model(filepath)
+        filepath = os.path.join(self.dirpath, pl_module.opt.id + '_interrupt.ckpt')
+        self._save_model(trainer, filepath=filepath)


 opt = opts.parse_opt()

 checkpoint_callback = ModelCheckpoint(
-    filepath=opt.checkpoint_path,
+    dirpath=opt.checkpoint_path,
+    filename=opt.id + '_{epoch}-{step}',
     save_last=True,
     save_top_k=1,
     verbose=True,
     monitor='to_monitor',
     mode='max',
-    prefix=opt.id + '_',
 )
+checkpoint_callback.CHECKPOINT_NAME_LAST = opt.id + "_last"
+
+
+tb_logger = pl.loggers.TensorBoardLogger(opt.checkpoint_path +
+                                         '/lightning_logs/',
+                                         name='',
+                                         version=0)
+wandb_logger = pl.loggers.WandbLogger(name=opt.id,
+                                      id=opt.id,
+                                      project='captioning',
+                                      log_model=True)

 print("""
 val_image_use,
@@ -438,29 +456,32 @@ def on_keyboard_interrupt(self, trainer, pl_module):
 lit = LitModel(opt)
 # warning grad_clip_mode is ignored.
 trainer = pl.Trainer(
+    logger=[tb_logger, wandb_logger],
     callbacks=[
         OnEpochStartCallback(),
-        pl.callbacks.lr_logger.LearningRateLogger()
+        pl.callbacks.LearningRateMonitor(),
+        checkpoint_callback,
     ],
     default_root_dir=opt.checkpoint_path,
     resume_from_checkpoint=resume_from,
-    distributed_backend='ddp',
+    accelerator='ddp',
     check_val_every_n_epoch=1,
     max_epochs=opt.max_epochs,
+    gradient_clip_algorithm=opt.grad_clip_mode,
     gradient_clip_val=opt.grad_clip_value,
     gpus=torch.cuda.device_count(),
-    checkpoint_callback=checkpoint_callback,
     log_gpu_memory='min_max',
-    log_save_interval=opt.losses_log_every,
-    profiler=True,
-    row_log_interval=10,  # what is it?
+    log_every_n_steps=opt.losses_log_every,
+    profiler='simple',
     num_sanity_val_steps=0,
-    # limit_train_batches=500,
+    # limit_train_batches=100,
     # progress_bar_refresh_rate=0,
     # fast_dev_run=True,
 )

 if os.getenv('EVALUATE', '0') == '1':
+    lit.load_state_dict(
+        torch.load(resume_from, map_location='cpu')['state_dict'], strict=False)
     trainer.test(lit)
 else:
     trainer.fit(lit)
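Pulling the last two hunks together: the checkpoint setup now uses `dirpath`/`filename` instead of the older `filepath`/`prefix` arguments, renames the `save_last` copy via `CHECKPOINT_NAME_LAST`, and the Trainer receives both loggers as a list so logged metrics reach TensorBoard and Weights & Biases. A stripped-down sketch with placeholder paths and run names standing in for the `opt.*` values:

    import pytorch_lightning as pl

    run_id = 'my_run'          # stands in for opt.id
    ckpt_dir = 'checkpoints/'  # stands in for opt.checkpoint_path

    ckpt = pl.callbacks.ModelCheckpoint(
        dirpath=ckpt_dir,
        filename=run_id + '_{epoch}-{step}',  # templated checkpoint names
        monitor='to_monitor',
        mode='max',
        save_top_k=1,
        save_last=True,
    )
    ckpt.CHECKPOINT_NAME_LAST = run_id + '_last'  # name of the save_last copy

    loggers = [
        pl.loggers.TensorBoardLogger(ckpt_dir + '/lightning_logs/', name='', version=0),
        pl.loggers.WandbLogger(name=run_id, id=run_id, project='captioning', log_model=True),
    ]

    trainer = pl.Trainer(logger=loggers, callbacks=[ckpt], max_epochs=1)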