@@ -104,6 +104,7 @@ def training_step(self, data, batch_idx):
104104 loss = model_out .pop ('loss' )
105105 loss = torch .topk (loss , k = int (loss .shape [0 ] * (1 - self .opt .drop_worst_rate )), largest = False )[0 ].mean ()
106106
107+ # Prepare for logging info
107108 data_time = self .trainer .profiler .recorded_durations ["get_train_batch" ][- 1 ]
108109 data_time = torch .tensor (data_time )
109110
@@ -117,13 +118,11 @@ def training_step(self, data, batch_idx):
117118 logger_logs ['training_loss' ] = loss
118119 logger_logs ['data_time' ] = data_time
119120
120- output = {
121- 'loss' : loss ,
122- 'log' : logger_logs ,
123- 'progress_bar' : {'data_time' : data_time }
124- }
121+ for k , v in logger_logs .items ():
122+ self .log (k , v , on_epoch = (k == 'training_loss' ), prog_bar = (k == 'data_time' ))
123+ # logged
125124
126- return output
125+ return loss
127126
128127 def validation_step (self , data , batch_idx ):
129128 model = self .model
@@ -202,24 +201,25 @@ def validation_step(self, data, batch_idx):
202201 fc_feats , att_feats , att_masks , data ], eval_kwargs )
203202
204203 output = {
205- 'val_loss ' : loss ,
204+ 'loss ' : loss ,
206205 'predictions' : predictions ,
207206 'n_predictions' : n_predictions ,
208207 }
208+ self .log ('loss' , loss )
209209 return output
210210
    def test_step(self, *args, **kwargs):
        """Run one test step; test-time behavior is identical to a validation step."""
        return self.validation_step(*args, **kwargs)
213213
214- def validation_epoch_end (self , outputs , split = 'val' ):
214+ def split_epoch_end (self , outputs , split = 'val' ):
215215 outputs = d2comm .gather (outputs )
216216 # master node
217217 if d2comm .is_main_process ():
218218 assert self .trainer .node_rank == 0 and self .trainer .local_rank == 0
219219 outputs = sum (outputs , [])
220220
221221 opt = self .opt
222- val_loss_mean = sum ([_ ['val_loss' ]
222+ loss_mean = sum ([_ ['loss' ]. item ()
223223 for _ in outputs ]) / len (outputs )
224224
225225 predictions = sum ([_ ['predictions' ] for _ in outputs ], [])
@@ -247,13 +247,13 @@ def validation_epoch_end(self, outputs, split='val'):
247247 if 'CIDEr' in lang_stats :
248248 optimizer .scheduler_step (- lang_stats ['CIDEr' ])
249249 else :
250- optimizer .scheduler_step (val_loss_mean )
250+ optimizer .scheduler_step (loss_mean )
251251
252252 out = {
253- 'val_loss ' : val_loss_mean
253+ 'loss ' : loss_mean
254254 }
255255 out .update (lang_stats )
256- out ['to_monitor' ] = lang_stats ['CIDEr' ] if lang_stats is not None else - val_loss_mean
256+ out ['to_monitor' ] = lang_stats ['CIDEr' ] if lang_stats is not None else - loss_mean
257257 else :
258258 out = {}
259259
@@ -263,23 +263,25 @@ def validation_epoch_end(self, outputs, split='val'):
263263 # must all be tensors
264264 out = {k : torch .tensor (v ) if not torch .is_tensor (
265265 v ) else v for k , v in out .items ()}
266- return {
267- 'progress_bar' : {'val_loss' : out ['val_loss' ]},
268- 'log' : out ,
269- }
266+
267+ return out
268+
269+ def validation_epoch_end (self , outputs ):
270+ out = self .split_epoch_end (outputs , 'val' )
271+ out ['val_loss' ] = out .pop ('loss' )
272+ for k ,v in out .items ():
273+ self .log (k , v )
274+ return out
270275
271276 def test_epoch_end (self , outputs ):
272- out = self .validation_epoch_end (outputs , 'test' )
273- out ['progress_bar' ] = {
274- 'test_loss' : out ['progress_bar' ]['val_loss' ]
275- }
276- out ['log' ]['test_loss' ] = out ['log' ]['val_loss' ]
277- del out ['log' ]['val_loss' ]
278- del out ['log' ]['to_monitor' ]
279- out ['log' ] = {'test_' + k if 'test' not in k else k :v \
280- for k ,v in out ['log' ].items ()}
277+ out = self .split_epoch_end (outputs , 'test' )
278+ out ['test_loss' ] = out .pop ('loss' )
279+ out = {'test_' + k if 'test' not in k else k : v
280+ for k , v in out .items ()}
281+ for k ,v in out .items ():
282+ self .log (k , v )
281283 return out
282-
284+
283285 def configure_optimizers (self ):
284286 opt = self .opt
285287 model = self .model
@@ -309,11 +311,11 @@ def optimizer_step(self, epoch, batch_idx, optimizer,
309311 super ().optimizer_step (epoch , batch_idx , optimizer ,
310312 optimizer_idx , * args , ** kwargs )
311313
312- def state_dict (self ):
314+ def state_dict (self , * args , ** kwargs ):
313315 """
314316 Save the model state dict as well as opt and vocab
315317 """
316- state_dict = self .model .state_dict ()
318+ state_dict = self .model .state_dict (* args , ** kwargs )
317319 device = next (iter (state_dict .values ())).device
318320 assert '_vocab' not in state_dict and '_opt' not in state_dict , 'Just in case'
319321 state_dict .update ({
@@ -345,10 +347,16 @@ def load_state_dict(self, state_dict=None, strict=True):
345347 raise KeyError
346348 self .model .load_state_dict (state_dict , strict )
347349
350+ def get_progress_bar_dict (self ):
351+ # don't show the version number
352+ items = super ().get_progress_bar_dict ()
353+ items .pop ("v_num" , None )
354+ return items
355+
348356
349357class OnEpochStartCallback (pl .Callback ):
350358
351- def on_epoch_start (self , trainer , pl_module ):
359+ def on_train_epoch_start (self , trainer , pl_module ):
352360 # Update lr/training stage/scheduled sampling prob etc.
353361 opt = pl_module .opt
354362 model = pl_module .model
@@ -402,21 +410,32 @@ class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
402410
403411 def on_keyboard_interrupt (self , trainer , pl_module ):
404412 # Save model when keyboard interrupt
405- filepath = os .path .join (self .dirpath , self . prefix + 'interrupt .ckpt' )
406- self ._save_model (filepath )
413+ filepath = os .path .join (self .dirpath , pl_module . opt . id + '_interrupt .ckpt' )
414+ self ._save_model (trainer , filepath = filepath )
407415
408416
opt = opts.parse_opt()

# Keep the single best checkpoint ranked by 'to_monitor' (CIDEr when language
# stats are available, otherwise -val_loss — see split_epoch_end), plus the
# most recent one via save_last.
checkpoint_callback = ModelCheckpoint(
    dirpath=opt.checkpoint_path,
    filename=opt.id + '_{epoch}-{step}',
    save_last=True,
    save_top_k=1,
    verbose=True,
    monitor='to_monitor',
    mode='max',
)
# Name the "last" checkpoint after the experiment id instead of the default "last".
checkpoint_callback.CHECKPOINT_NAME_LAST = opt.id + "_last"


# Log to both TensorBoard (under the checkpoint directory) and Weights & Biases.
tb_logger = pl.loggers.TensorBoardLogger(opt.checkpoint_path +
                                         '/lightning_logs/',
                                         name='',
                                         version=0)
wandb_logger = pl.loggers.WandbLogger(name=opt.id,
                                      id=opt.id,
                                      project='captioning',
                                      log_model=True)
421440print ("""
422441val_image_use,
@@ -438,29 +457,32 @@ def on_keyboard_interrupt(self, trainer, pl_module):
lit = LitModel(opt)
trainer = pl.Trainer(
    logger=[tb_logger, wandb_logger],
    callbacks=[
        OnEpochStartCallback(),
        pl.callbacks.LearningRateMonitor(),
        checkpoint_callback,
    ],
    default_root_dir=opt.checkpoint_path,
    resume_from_checkpoint=resume_from,
    accelerator='ddp',
    check_val_every_n_epoch=1,
    max_epochs=opt.max_epochs,
    # grad_clip_mode is honored here via gradient_clip_algorithm (it used to be
    # ignored when only gradient_clip_val was passed).
    gradient_clip_algorithm=opt.grad_clip_mode,
    gradient_clip_val=opt.grad_clip_value,
    gpus=torch.cuda.device_count(),
    log_gpu_memory='min_max',
    log_every_n_steps=opt.losses_log_every,
    # training_step reads trainer.profiler.recorded_durations["get_train_batch"],
    # which requires an active profiler.
    profiler='simple',
    num_sanity_val_steps=0,
    # limit_train_batches=100,
    # progress_bar_refresh_rate=0,
    # fast_dev_run=True,
)
462482
# EVALUATE=1 runs test-only on a restored checkpoint; otherwise train.
if os.getenv('EVALUATE', '0') == '1':
    # strict=False: presumably the raw Lightning checkpoint lacks the extra
    # '_vocab'/'_opt' entries that this module's state_dict adds — confirm.
    lit.load_state_dict(
        torch.load(resume_from, map_location='cpu')['state_dict'], strict=False)
    trainer.test(lit)
else:
    trainer.fit(lit)
0 commit comments