
Commit be1a526

[pl]: update to pl 1.3+
1 parent 6732836 commit be1a526

2 files changed: +67 -46 lines changed


captioning/utils/misc.py

Lines changed: 6 additions & 6 deletions
@@ -157,7 +157,7 @@ def length_average(length, logprobs, alpha=0.):
         return logprobs / length
 
 
-class NoamOpt(object):
+class NoamOpt(torch.optim.Optimizer):
     "Optim wrapper that implements rate."
     def __init__(self, model_size, factor, warmup, optimizer):
         self.optimizer = optimizer
@@ -167,14 +167,14 @@ def __init__(self, model_size, factor, warmup, optimizer):
         self.model_size = model_size
         self._rate = 0
 
-    def step(self):
+    def step(self, *args, **kwargs):
         "Update parameters and rate"
         self._step += 1
         rate = self.rate()
         for p in self.optimizer.param_groups:
             p['lr'] = rate
         self._rate = rate
-        self.optimizer.step()
+        self.optimizer.step(*args, **kwargs)
 
     def rate(self, step = None):
         "Implement `lrate` above"
@@ -198,16 +198,16 @@ def load_state_dict(self, state_dict):
         del state_dict['_step']
         self.optimizer.load_state_dict(state_dict)
 
-class ReduceLROnPlateau(object):
+class ReduceLROnPlateau(torch.optim.Optimizer):
     "Optim wrapper that implements rate."
     def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, threshold_mode, cooldown, min_lr, eps)
        self.optimizer = optimizer
        self.current_lr = get_lr(optimizer)
 
-    def step(self):
+    def step(self, *args, **kwargs):
         "Update parameters and rate"
-        self.optimizer.step()
+        self.optimizer.step(*args, **kwargs)
 
     def scheduler_step(self, val):
         self.scheduler.step(val)
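
Context for the misc.py changes above (my summary, not part of the commit): PyTorch Lightning 1.3+ drives automatic optimization itself and calls optimizer.step(closure=...), so a wrapper whose step() accepts no arguments would drop that closure; re-basing the wrappers on torch.optim.Optimizer also lets Lightning's optimizer type checks accept them. A minimal sketch of the forwarding part, using a hypothetical stand-alone wrapper rather than the repo's classes:

import torch

class MinimalOptWrapper:
    """Hypothetical wrapper showing why step() must forward its arguments."""
    def __init__(self, optimizer):
        self.optimizer = optimizer  # the real torch optimizer being wrapped

    def step(self, *args, **kwargs):
        # Lightning 1.3+ may call step(closure=...); forward everything so the
        # closure (forward + backward pass) still reaches the real optimizer.
        return self.optimizer.step(*args, **kwargs)

model = torch.nn.Linear(4, 2)
opt = MinimalOptWrapper(torch.optim.SGD(model.parameters(), lr=0.1))

def closure():
    # the kind of closure Lightning builds: zero grads, forward, backward
    opt.optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(torch.randn(8, 4)), torch.zeros(8, 2))
    loss.backward()
    return loss

opt.step(closure=closure)  # succeeds only because the kwarg is forwarded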

tools/train_pl.py

Lines changed: 61 additions & 40 deletions
@@ -104,6 +104,7 @@ def training_step(self, data, batch_idx):
         loss = model_out.pop('loss')
         loss = torch.topk(loss, k=int(loss.shape[0] * (1-self.opt.drop_worst_rate)), largest=False)[0].mean()
 
+        # Prepare for logging info
         data_time = self.trainer.profiler.recorded_durations["get_train_batch"][-1]
         data_time = torch.tensor(data_time)
 
@@ -117,13 +118,11 @@ def training_step(self, data, batch_idx):
         logger_logs['training_loss'] = loss
         logger_logs['data_time'] = data_time
 
-        output = {
-            'loss': loss,
-            'log': logger_logs,
-            'progress_bar': {'data_time': data_time}
-        }
+        for k, v in logger_logs.items():
+            self.log(k, v, on_epoch=(k=='training_loss'), prog_bar=(k=='data_time'))
+        # logged
 
-        return output
+        return loss
 
     def validation_step(self, data, batch_idx):
         model = self.model
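
Background for the training_step rewrite above (my illustration, not repo code): Lightning 1.3 no longer reads 'log' or 'progress_bar' keys from the step output; metrics are reported through LightningModule.log, which handles step/epoch aggregation and progress-bar display, while the step simply returns the loss. A minimal sketch assuming a toy model:

import torch
import pytorch_lightning as pl

class TinyLitModel(pl.LightningModule):
    """Hypothetical module; only the logging calls matter here."""
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).pow(2).mean()
        # on_epoch=True also logs an epoch-aggregated value;
        # prog_bar=True shows the metric in the progress bar.
        self.log('training_loss', loss, on_epoch=True)
        self.log('data_time', torch.tensor(0.0), prog_bar=True)
        return loss  # the step output is now just the loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)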
@@ -202,7 +201,7 @@ def validation_step(self, data, batch_idx):
             fc_feats, att_feats, att_masks, data], eval_kwargs)
 
         output = {
-            'val_loss': loss,
+            'loss': loss,
             'predictions': predictions,
             'n_predictions': n_predictions,
         }
@@ -211,15 +210,15 @@ def validation_step(self, data, batch_idx):
     def test_step(self, *args, **kwargs):
         return self.validation_step(*args, **kwargs)
 
-    def validation_epoch_end(self, outputs, split='val'):
+    def split_epoch_end(self, outputs, split='val'):
         outputs = d2comm.gather(outputs)
         # master node
         if d2comm.is_main_process():
             assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0
             outputs = sum(outputs, [])
 
             opt = self.opt
-            val_loss_mean = sum([_['val_loss']
+            loss_mean = sum([_['loss'].item()
                                  for _ in outputs]) / len(outputs)
 
             predictions = sum([_['predictions'] for _ in outputs], [])
@@ -247,13 +246,13 @@ def validation_epoch_end(self, outputs, split='val'):
             if 'CIDEr' in lang_stats:
                 optimizer.scheduler_step(-lang_stats['CIDEr'])
             else:
-                optimizer.scheduler_step(val_loss_mean)
+                optimizer.scheduler_step(loss_mean)
 
             out = {
-                'val_loss': val_loss_mean
+                'loss': loss_mean
             }
             out.update(lang_stats)
-            out['to_monitor'] = lang_stats['CIDEr'] if lang_stats is not None else -val_loss_mean
+            out['to_monitor'] = lang_stats['CIDEr'] if lang_stats is not None else -loss_mean
         else:
             out = {}
 
@@ -263,23 +262,25 @@ def validation_epoch_end(self, outputs, split='val'):
         # must all be tensors
         out = {k: torch.tensor(v) if not torch.is_tensor(
             v) else v for k, v in out.items()}
-        return {
-            'progress_bar': {'val_loss': out['val_loss']},
-            'log': out,
-        }
+
+        return out
+
+    def validation_epoch_end(self, outputs):
+        out = self.split_epoch_end(outputs, 'val')
+        out['val_loss'] = out.pop('loss')
+        for k,v in out.items():
+            self.log(k, v)
+        return out
 
     def test_epoch_end(self, outputs):
-        out = self.validation_epoch_end(outputs, 'test')
-        out['progress_bar'] = {
-            'test_loss': out['progress_bar']['val_loss']
-        }
-        out['log']['test_loss'] = out['log']['val_loss']
-        del out['log']['val_loss']
-        del out['log']['to_monitor']
-        out['log'] = {'test_'+k if 'test' not in k else k:v \
-                      for k,v in out['log'].items()}
+        out = self.split_epoch_end(outputs, 'test')
+        out['test_loss'] = out.pop('loss')
+        out = {'test_'+k if 'test' not in k else k: v
+               for k, v in out.items()}
+        for k,v in out.items():
+            self.log(k, v)
         return out
-
+
     def configure_optimizers(self):
         opt = self.opt
         model = self.model
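
The epoch-end refactor above funnels validation and test through one aggregation helper, because in Lightning 1.3 validation_epoch_end/test_epoch_end take only the collected outputs and metrics are emitted via self.log rather than returned in a 'log' dict. A stripped-down sketch of the same pattern (hypothetical class, no distributed gathering):

import torch
import pytorch_lightning as pl

class SplitEndExample(pl.LightningModule):
    """Hypothetical skeleton showing a shared *_epoch_end helper."""
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def validation_step(self, batch, batch_idx):
        return {'loss': self.layer(batch).pow(2).mean()}

    def test_step(self, *args, **kwargs):
        return self.validation_step(*args, **kwargs)

    def split_epoch_end(self, outputs, split='val'):
        # shared aggregation; returns plain values instead of 'log'/'progress_bar' dicts
        return {'loss': torch.stack([o['loss'] for o in outputs]).mean()}

    def validation_epoch_end(self, outputs):
        self.log('val_loss', self.split_epoch_end(outputs, 'val')['loss'])

    def test_epoch_end(self, outputs):
        self.log('test_loss', self.split_epoch_end(outputs, 'test')['loss'])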
@@ -309,11 +310,11 @@ def optimizer_step(self, epoch, batch_idx, optimizer,
         super().optimizer_step(epoch, batch_idx, optimizer,
                                optimizer_idx, *args, **kwargs)
 
-    def state_dict(self):
+    def state_dict(self, *args, **kwargs):
         """
         Save the model state dict as well as opt and vocab
         """
-        state_dict = self.model.state_dict()
+        state_dict = self.model.state_dict(*args, **kwargs)
         device = next(iter(state_dict.values())).device
         assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case'
         state_dict.update({
@@ -345,10 +346,16 @@ def load_state_dict(self, state_dict=None, strict=True):
             raise KeyError
         self.model.load_state_dict(state_dict, strict)
 
+    def get_progress_bar_dict(self):
+        # don't show the version number
+        items = super().get_progress_bar_dict()
+        items.pop("v_num", None)
+        return items
+
 
 class OnEpochStartCallback(pl.Callback):
 
-    def on_epoch_start(self, trainer, pl_module):
+    def on_train_epoch_start(self, trainer, pl_module):
         # Update lr/training stage/scheduled sampling prob etc.
         opt = pl_module.opt
         model = pl_module.model
@@ -402,21 +409,32 @@ class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
 
     def on_keyboard_interrupt(self, trainer, pl_module):
         # Save model when keyboard interrupt
-        filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
-        self._save_model(filepath)
+        filepath = os.path.join(self.dirpath, pl_module.opt.id+'_interrupt.ckpt')
+        self._save_model(trainer, filepath=filepath)
 
 
 opt = opts.parse_opt()
 
 checkpoint_callback = ModelCheckpoint(
-    filepath=opt.checkpoint_path,
+    dirpath=opt.checkpoint_path,
+    filename=opt.id+'_{epoch}-{step}',
     save_last=True,
     save_top_k=1,
     verbose=True,
     monitor='to_monitor',
     mode='max',
-    prefix=opt.id+'_',
 )
+checkpoint_callback.CHECKPOINT_NAME_LAST = opt.id+"_last"
+
+
+tb_logger = pl.loggers.TensorBoardLogger(opt.checkpoint_path +
+                                         '/lightning_logs/',
+                                         name='',
+                                         version=0)
+wandb_logger = pl.loggers.WandbLogger(name=opt.id,
+                                      id=opt.id,
+                                      project='captioning',
+                                      log_model=True)
 
 print("""
 val_image_use,
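
For reference (my reading of the Lightning 1.3 API, not text from the commit): ModelCheckpoint now takes dirpath plus a filename template instead of a single filepath, the prefix argument is gone (hence the run id embedded in filename), and the class attribute CHECKPOINT_NAME_LAST controls the file written by save_last=True; Trainer also accepts a list of loggers. A hedged sketch with placeholder values standing in for opt.checkpoint_path and opt.id:

import pytorch_lightning as pl

ckpt_dir, run_id = './checkpoints', 'demo'   # placeholders, not repo settings

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=ckpt_dir,                      # directory (replaces filepath)
    filename=run_id + '_{epoch}-{step}',   # template expanded from logged values
    monitor='to_monitor',
    mode='max',
    save_top_k=1,
    save_last=True,
)
# save_last=True would otherwise always write 'last.ckpt'
checkpoint_callback.CHECKPOINT_NAME_LAST = run_id + '_last'

tb_logger = pl.loggers.TensorBoardLogger(ckpt_dir + '/lightning_logs/', name='', version=0)
# this and a WandbLogger can both be handed to Trainer(logger=[...]) as in the diff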
@@ -438,29 +456,32 @@ def on_keyboard_interrupt(self, trainer, pl_module):
 lit = LitModel(opt)
 # warning grad_clip_mode is ignored.
 trainer = pl.Trainer(
+    logger=[tb_logger, wandb_logger],
     callbacks=[
         OnEpochStartCallback(),
-        pl.callbacks.lr_logger.LearningRateLogger()
+        pl.callbacks.LearningRateMonitor(),
+        checkpoint_callback,
     ],
     default_root_dir=opt.checkpoint_path,
     resume_from_checkpoint=resume_from,
-    distributed_backend='ddp',
+    accelerator='ddp',
     check_val_every_n_epoch=1,
     max_epochs=opt.max_epochs,
+    gradient_clip_algorithm=opt.grad_clip_mode,
     gradient_clip_val=opt.grad_clip_value,
     gpus=torch.cuda.device_count(),
-    checkpoint_callback=checkpoint_callback,
     log_gpu_memory='min_max',
-    log_save_interval=opt.losses_log_every,
-    profiler=True,
-    row_log_interval=10, # what is it?
+    log_every_n_steps=opt.losses_log_every,
+    profiler='simple',
     num_sanity_val_steps=0,
-    # limit_train_batches=500,
+    # limit_train_batches=100,
     # progress_bar_refresh_rate=0,
     # fast_dev_run=True,
 )
 
 if os.getenv('EVALUATE', '0') == '1':
+    lit.load_state_dict(
+        torch.load(resume_from, map_location='cpu')['state_dict'], strict=False)
     trainer.test(lit)
 else:
     trainer.fit(lit)
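
Summary of the Trainer-side migration visible in this hunk (restating the diff, not new behaviour): the checkpoint callback moves from the checkpoint_callback argument into callbacks=[...], distributed_backend becomes accelerator, log_save_interval becomes log_every_n_steps, row_log_interval is dropped, profiler=True becomes profiler='simple', pl.callbacks.lr_logger.LearningRateLogger becomes pl.callbacks.LearningRateMonitor, and the new gradient_clip_algorithm carries the previously ignored opt.grad_clip_mode. In EVALUATE mode the checkpoint's state_dict is now loaded explicitly before trainer.test(lit).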
