
Commit 5076fbc

[pl]: update to pl 1.3+
1 parent: b2fe2bf, commit: 5076fbc

File tree

2 files changed (+68, -46 lines)


captioning/utils/misc.py

Lines changed: 6 additions & 6 deletions
@@ -157,7 +157,7 @@ def length_average(length, logprobs, alpha=0.):
     return logprobs / length
 
 
-class NoamOpt(object):
+class NoamOpt(torch.optim.Optimizer):
     "Optim wrapper that implements rate."
     def __init__(self, model_size, factor, warmup, optimizer):
         self.optimizer = optimizer
@@ -167,14 +167,14 @@ def __init__(self, model_size, factor, warmup, optimizer):
         self.model_size = model_size
         self._rate = 0
 
-    def step(self):
+    def step(self, *args, **kwargs):
         "Update parameters and rate"
         self._step += 1
         rate = self.rate()
         for p in self.optimizer.param_groups:
             p['lr'] = rate
         self._rate = rate
-        self.optimizer.step()
+        self.optimizer.step(*args, **kwargs)
 
     def rate(self, step = None):
         "Implement `lrate` above"
@@ -198,16 +198,16 @@ def load_state_dict(self, state_dict):
             del state_dict['_step']
         self.optimizer.load_state_dict(state_dict)
 
-class ReduceLROnPlateau(object):
+class ReduceLROnPlateau(torch.optim.Optimizer):
     "Optim wrapper that implements rate."
     def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
         self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, threshold_mode, cooldown, min_lr, eps)
         self.optimizer = optimizer
         self.current_lr = get_lr(optimizer)
 
-    def step(self):
+    def step(self, *args, **kwargs):
         "Update parameters and rate"
-        self.optimizer.step()
+        self.optimizer.step(*args, **kwargs)
 
     def scheduler_step(self, val):
         self.scheduler.step(val)
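
Aside: both wrappers now subclass torch.optim.Optimizer and forward whatever step() receives, because PL 1.3+ expects optimizer objects it can recognize and may pass a closure into optimizer.step(). A minimal sketch of that pattern for context (my own illustration, not code from this commit; WrappedOpt is a made-up name):

import torch

class WrappedOpt(torch.optim.Optimizer):
    """Thin pass-through around a real torch optimizer."""

    def __init__(self, optimizer):
        # Intentionally skip torch.optim.Optimizer.__init__: the inner
        # optimizer owns the param_groups and state, same trick as NoamOpt above.
        self.optimizer = optimizer

    def step(self, *args, **kwargs):
        # Lightning may pass closure=... here; hand it straight through.
        return self.optimizer.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        return self.optimizer.zero_grad(*args, **kwargs)

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)

# usage sketch: WrappedOpt(torch.optim.Adam(model.parameters(), lr=5e-4))
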

tools/train_pl.py

Lines changed: 62 additions & 40 deletions
@@ -104,6 +104,7 @@ def training_step(self, data, batch_idx):
         loss = model_out.pop('loss')
         loss = torch.topk(loss, k=int(loss.shape[0] * (1-self.opt.drop_worst_rate)), largest=False)[0].mean()
 
+        # Prepare for logging info
         data_time = self.trainer.profiler.recorded_durations["get_train_batch"][-1]
         data_time = torch.tensor(data_time)
 
@@ -117,13 +118,11 @@ def training_step(self, data, batch_idx):
         logger_logs['training_loss'] = loss
         logger_logs['data_time'] = data_time
 
-        output = {
-            'loss': loss,
-            'log': logger_logs,
-            'progress_bar': {'data_time': data_time}
-        }
+        for k, v in logger_logs.items():
+            self.log(k, v, on_epoch=(k=='training_loss'), prog_bar=(k=='data_time'))
+        # logged
 
-        return output
+        return loss
 
     def validation_step(self, data, batch_idx):
         model = self.model
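
For readers migrating similar code: since PL 1.0 the 'log'/'progress_bar' keys returned from training_step are no longer honoured, and metrics go through LightningModule.log instead, which is what the hunk above switches to. A small self-contained sketch of that API (illustrative only, TinyModule is not part of this repo; assumes pytorch_lightning>=1.3):

import torch
import pytorch_lightning as pl

class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # replaces the old return {'loss': ..., 'log': ..., 'progress_bar': ...}
        self.log('training_loss', loss, on_step=True, on_epoch=True)  # goes to the logger
        self.log('data_time', 0.0, prog_bar=True)                     # shown on the progress bar
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
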
@@ -202,24 +201,25 @@ def validation_step(self, data, batch_idx):
             fc_feats, att_feats, att_masks, data], eval_kwargs)
 
         output = {
-            'val_loss': loss,
+            'loss': loss,
             'predictions': predictions,
             'n_predictions': n_predictions,
         }
+        self.log('loss', loss)
         return output
 
     def test_step(self, *args, **kwargs):
         return self.validation_step(*args, **kwargs)
 
-    def validation_epoch_end(self, outputs, split='val'):
+    def split_epoch_end(self, outputs, split='val'):
         outputs = d2comm.gather(outputs)
         # master node
         if d2comm.is_main_process():
             assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0
             outputs = sum(outputs, [])
 
             opt = self.opt
-            val_loss_mean = sum([_['val_loss']
+            loss_mean = sum([_['loss'].item()
                                  for _ in outputs]) / len(outputs)
 
             predictions = sum([_['predictions'] for _ in outputs], [])
@@ -247,13 +247,13 @@ def validation_epoch_end(self, outputs, split='val'):
                 if 'CIDEr' in lang_stats:
                     optimizer.scheduler_step(-lang_stats['CIDEr'])
                 else:
-                    optimizer.scheduler_step(val_loss_mean)
+                    optimizer.scheduler_step(loss_mean)
 
             out = {
-                'val_loss': val_loss_mean
+                'loss': loss_mean
             }
             out.update(lang_stats)
-            out['to_monitor'] = lang_stats['CIDEr'] if lang_stats is not None else -val_loss_mean
+            out['to_monitor'] = lang_stats['CIDEr'] if lang_stats is not None else -loss_mean
         else:
             out = {}
 
@@ -263,23 +263,25 @@ def validation_epoch_end(self, outputs, split='val'):
         # must all be tensors
         out = {k: torch.tensor(v) if not torch.is_tensor(
             v) else v for k, v in out.items()}
-        return {
-            'progress_bar': {'val_loss': out['val_loss']},
-            'log': out,
-        }
+
+        return out
+
+    def validation_epoch_end(self, outputs):
+        out = self.split_epoch_end(outputs, 'val')
+        out['val_loss'] = out.pop('loss')
+        for k,v in out.items():
+            self.log(k, v)
+        return out
 
     def test_epoch_end(self, outputs):
-        out = self.validation_epoch_end(outputs, 'test')
-        out['progress_bar'] = {
-            'test_loss': out['progress_bar']['val_loss']
-        }
-        out['log']['test_loss'] = out['log']['val_loss']
-        del out['log']['val_loss']
-        del out['log']['to_monitor']
-        out['log'] = {'test_'+k if 'test' not in k else k:v \
-                      for k,v in out['log'].items()}
+        out = self.split_epoch_end(outputs, 'test')
+        out['test_loss'] = out.pop('loss')
+        out = {'test_'+k if 'test' not in k else k: v
+               for k, v in out.items()}
+        for k,v in out.items():
+            self.log(k, v)
         return out
-
+
     def configure_optimizers(self):
         opt = self.opt
         model = self.model
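
Related note on the hunk above: validation_epoch_end now has the fixed Lightning signature, the old split argument moves into the shared split_epoch_end helper, and aggregated metrics are published with self.log rather than returned under 'log'. A minimal sketch of that aggregation pattern (my illustration, not repo code):

import torch
import pytorch_lightning as pl

class TinyEvalModule(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        # placeholder metric for the sketch
        return {'loss': torch.tensor(0.0)}

    def validation_epoch_end(self, outputs):
        # aggregate the per-batch dicts yourself, then log; the return value
        # is no longer turned into logger/progress-bar entries by Lightning
        val_loss = torch.stack([o['loss'] for o in outputs]).mean()
        self.log('val_loss', val_loss)
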
@@ -309,11 +311,11 @@ def optimizer_step(self, epoch, batch_idx, optimizer,
         super().optimizer_step(epoch, batch_idx, optimizer,
                                optimizer_idx, *args, **kwargs)
 
-    def state_dict(self):
+    def state_dict(self, *args, **kwargs):
         """
         Save the model state dict as well as opt and vocab
         """
-        state_dict = self.model.state_dict()
+        state_dict = self.model.state_dict(*args, **kwargs)
         device = next(iter(state_dict.values())).device
         assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case'
         state_dict.update({
@@ -345,10 +347,16 @@ def load_state_dict(self, state_dict=None, strict=True):
             raise KeyError
         self.model.load_state_dict(state_dict, strict)
 
+    def get_progress_bar_dict(self):
+        # don't show the version number
+        items = super().get_progress_bar_dict()
+        items.pop("v_num", None)
+        return items
+
 
 class OnEpochStartCallback(pl.Callback):
 
-    def on_epoch_start(self, trainer, pl_module):
+    def on_train_epoch_start(self, trainer, pl_module):
        # Update lr/training stage/scheduled sampling prob etc.
         opt = pl_module.opt
         model = pl_module.model
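
The hook rename above follows the PL 1.3 Callback interface, where the per-epoch training hook is on_train_epoch_start. Tiny illustration (not repo code, EpochBanner is a made-up name):

import pytorch_lightning as pl

class EpochBanner(pl.Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        # runs once at the start of every training epoch
        print(f"starting epoch {trainer.current_epoch}")
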
@@ -402,21 +410,32 @@ class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
 
     def on_keyboard_interrupt(self, trainer, pl_module):
         # Save model when keyboard interrupt
-        filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
-        self._save_model(filepath)
+        filepath = os.path.join(self.dirpath, pl_module.opt.id+'_interrupt.ckpt')
+        self._save_model(trainer, filepath=filepath)
 
 
 opt = opts.parse_opt()
 
 checkpoint_callback = ModelCheckpoint(
-    filepath=opt.checkpoint_path,
+    dirpath=opt.checkpoint_path,
+    filename=opt.id+'_{epoch}-{step}',
     save_last=True,
     save_top_k=1,
     verbose=True,
     monitor='to_monitor',
     mode='max',
-    prefix=opt.id+'_',
 )
+checkpoint_callback.CHECKPOINT_NAME_LAST = opt.id+"_last"
+
+
+tb_logger = pl.loggers.TensorBoardLogger(opt.checkpoint_path +
+                                         '/lightning_logs/',
+                                         name='',
+                                         version=0)
+wandb_logger = pl.loggers.WandbLogger(name=opt.id,
+                                      id=opt.id,
+                                      project='captioning',
+                                      log_model=True)
 
 print("""
 val_image_use,
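
For context on the checkpoint changes above: in PL 1.2+/1.3 ModelCheckpoint takes dirpath and filename instead of filepath and prefix, and the class attribute CHECKPOINT_NAME_LAST controls the name of the rolling "last" checkpoint. A hedged sketch with made-up paths and run id ('ckpts/', 'myrun'), not the repo's actual settings:

import pytorch_lightning as pl

ckpt_cb = pl.callbacks.ModelCheckpoint(
    dirpath='ckpts/',                  # directory, formerly filepath=
    filename='myrun_{epoch}-{step}',   # templated file name, replaces prefix=
    monitor='to_monitor',
    mode='max',
    save_top_k=1,
    save_last=True,
)
ckpt_cb.CHECKPOINT_NAME_LAST = 'myrun_last'  # file name used for the "last" checkpoint

tb_logger = pl.loggers.TensorBoardLogger('ckpts/lightning_logs/', name='', version=0)
wandb_logger = pl.loggers.WandbLogger(name='myrun', id='myrun', project='captioning', log_model=True)
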
@@ -438,29 +457,32 @@ def on_keyboard_interrupt(self, trainer, pl_module):
 lit = LitModel(opt)
 # warning grad_clip_mode is ignored.
 trainer = pl.Trainer(
+    logger=[tb_logger, wandb_logger],
     callbacks=[
         OnEpochStartCallback(),
-        pl.callbacks.lr_logger.LearningRateLogger()
+        pl.callbacks.LearningRateMonitor(),
+        checkpoint_callback,
     ],
     default_root_dir=opt.checkpoint_path,
     resume_from_checkpoint=resume_from,
-    distributed_backend='ddp',
+    accelerator='ddp',
     check_val_every_n_epoch=1,
     max_epochs=opt.max_epochs,
+    gradient_clip_algorithm=opt.grad_clip_mode,
     gradient_clip_val=opt.grad_clip_value,
     gpus=torch.cuda.device_count(),
-    checkpoint_callback=checkpoint_callback,
     log_gpu_memory='min_max',
-    log_save_interval=opt.losses_log_every,
-    profiler=True,
-    row_log_interval=10, # what is it?
+    log_every_n_steps=opt.losses_log_every,
+    profiler='simple',
     num_sanity_val_steps=0,
-    # limit_train_batches=500,
+    # limit_train_batches=100,
     # progress_bar_refresh_rate=0,
     # fast_dev_run=True,
 )
 
 if os.getenv('EVALUATE', '0') == '1':
+    lit.load_state_dict(
+        torch.load(resume_from, map_location='cpu')['state_dict'], strict=False)
     trainer.test(lit)
 else:
     trainer.fit(lit)
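
Finally, a compact sketch of the Trainer argument migration this hunk performs: distributed_backend becomes accelerator, log_save_interval becomes log_every_n_steps, profiler=True becomes profiler='simple', gradient_clip_algorithm is new in 1.3, and the checkpoint callback is passed through callbacks= instead of checkpoint_callback=. Values below are placeholders, not the repo's settings:

import torch
import pytorch_lightning as pl

trainer = pl.Trainer(
    gpus=torch.cuda.device_count(),
    # accelerator='ddp',               # was distributed_backend='ddp' (multi-GPU runs)
    max_epochs=10,                     # placeholder
    gradient_clip_val=0.1,             # placeholder
    gradient_clip_algorithm='norm',    # new in PL 1.3: 'norm' or 'value'
    log_every_n_steps=25,              # was log_save_interval
    profiler='simple',                 # was profiler=True
    callbacks=[pl.callbacks.LearningRateMonitor()],  # was pl.callbacks.lr_logger.LearningRateLogger()
)
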
