
Commit f225ae8

Update README with model results and attribution. Make the scheduler factory a bit more robust to arg differences, add noise to the plateau LR scheduler, and fix its min/max mode.
1 parent d1b5ddd commit f225ae8

File tree: 4 files changed (+102, -43 lines)

README.md

Lines changed: 10 additions & 22 deletions
```diff
@@ -64,28 +64,6 @@ Bunch of changes:
 ### Feb 12, 2020
 * Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet)
 
-### Feb 6, 2020
-* Add RandAugment trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin) (see Training section for hparams)
-
-### Feb 1/2, 2020
-* Port new EfficientNet-B8 (RandAugment) weights, these are different than the B8 AdvProp, different input normalization.
-* Update results csv files on all models for ImageNet validation and three other test sets
-* Push PyPi package update
-
-### Jan 31, 2020
-* Update ResNet50 weights with a new 79.038 result from further JSD / AugMix experiments. Full command line for reproduction in training section below.
-
-### Jan 11/12, 2020
-* Master may be a bit unstable wrt to training, these changes have been tested but not all combos
-* Implementations of AugMix added to existing RA and AA. Including numerous supporting pieces like JSD loss (Jensen-Shannon divergence + CE), and AugMixDataset
-* SplitBatchNorm adaptation layer added for implementing Auxiliary BN as per AdvProp paper
-* ResNet-50 AugMix trained model w/ 79% top-1 added
-* `seresnext26tn_32x4d` - 77.99 top-1, 93.75 top-5 added to tiered experiment, higher img/s than 't' and 'd'
-
-### Jan 3, 2020
-* Add RandAugment trained EfficientNet-B0 weight with 77.7 top-1. Trained by [Michael Klachko](https://github.com/michaelklachko) with this code and recent hparams (see Training section)
-* Add `avg_checkpoints.py` script for post training weight averaging and update all scripts with header docstrings and shebangs.
-
 ## Introduction
 
 For each competition, personal, or freelance project involving images + Convolution Neural Networks, I build on top of an evolving collection of code and models. This repo contains a (somewhat) cleaned up and pared down iteration of that code. Hopefully it'll be of use to others.
@@ -119,6 +97,7 @@ Included models:
 * DenseNet-121, DenseNet-169, DenseNet-201, DenseNet-161
 * Squeeze-and-Excitation ResNet/ResNeXt (from [Cadene](https://github.com/Cadene/pretrained-models.pytorch) with some pretrained weight additions by myself)
   * SENet-154, SE-ResNet-18, SE-ResNet-34, SE-ResNet-50, SE-ResNet-101, SE-ResNet-152, SE-ResNeXt-26 (32x4d), SE-ResNeXt50 (32x4d), SE-ResNeXt101 (32x4d)
+* Inception-V3 (from [torchvision](https://github.com/pytorch/vision/tree/master/torchvision/models))
 * Inception-ResNet-V2 and Inception-V4 (from [Cadene](https://github.com/Cadene/pretrained-models.pytorch))
 * Xception
   * Original variant from [Cadene](https://github.com/Cadene/pretrained-models.pytorch)
@@ -143,6 +122,12 @@ Included models:
   * code from https://github.com/mehtadushy/SelecSLS-Pytorch, paper https://arxiv.org/abs/1907.00837
 * TResNet
   * code from https://github.com/mrT23/TResNet, paper https://arxiv.org/abs/2003.13630
+* RegNet
+  * paper `Designing Network Design Spaces` - https://arxiv.org/abs/2003.13678
+  * reference code at https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py
+* VovNet V2 (with V1 support)
+  * paper `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
+  * reference code at https://github.com/youngwanLEE/vovnet-detectron2
 
 Use the `--model` arg to specify model for train, validation, inference scripts. Match the all lowercase
 creation fn for the model you'd like.
@@ -187,6 +172,8 @@ I've leveraged the training scripts in this repository to train a few of the models
 | skresnext50d_32x4d | 80.156 (19.844) | 94.642 (5.358) | 27.5M | bicubic | 224 |
 | resnext50_32x4d | 79.762 (20.238) | 94.600 (5.400) | 25M | bicubic | 224 |
 | resnext50d_32x4d | 79.674 (20.326) | 94.868 (5.132) | 25.1M | bicubic | 224 |
+| ese_vovnet39b | 79.320 (20.680) | 94.710 (5.290) | 24.6M | bicubic | 224 |
+| resnetblur50 | 79.290 (20.710) | 94.632 (5.368) | 25.6M | bicubic | 224 |
 | resnet50 | 79.038 (20.962) | 94.390 (5.610) | 25.6M | bicubic | 224 |
 | mixnet_l | 78.976 (21.024) | 94.184 (5.816) | 7.33M | bicubic | 224 |
 | efficientnet_b1 | 78.692 (21.308) | 94.086 (5.914) | 7.79M | bicubic | 240 |
@@ -200,6 +187,7 @@ I've leveraged the training scripts in this repository to train a few of the models
 | seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8M | bicubic | 224 |
 | skresnet34 | 76.912 (23.088) | 93.322 (6.678) | 22.2M | bicubic | 224 |
 | resnet26d | 76.68 (23.32) | 93.166 (6.834) | 16M | bicubic | 224 |
+| densenetblur121d | 76.576 (23.424) | 93.190 (6.810) | 8.0M | bicubic | 224 |
 | mobilenetv2_140 | 76.524 (23.476) | 92.990 (7.010) | 6.1M | bicubic | 224 |
 | mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13M | bicubic | 224 |
 | mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5M | bicubic | 224 |
```
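
Since the hunk above points readers at the `--model` arg and the matching all-lowercase creation fn names, here is a hedged sketch of the equivalent programmatic route. It assumes `timm`'s top-level `create_model` factory and that pretrained weights exist for `ese_vovnet39b`, consistent with the results table above:

```python
import timm  # assumed installed from PyPI, as the README describes

# The all-lowercase model name matches the creation fn, e.g. the newly
# benchmarked VovNet V2 variant from the results table above.
model = timm.create_model('ese_vovnet39b', pretrained=True)
model.eval()
```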

sotabench.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -396,6 +396,24 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
 _entry('selecsls60b', 'SelecSLS-60_B', '1907.00837',
        model_desc='Originally from https://github.com/mehtadushy/SelecSLS-Pytorch'),
 
+## ResNeSt official impl weights
+_entry('resnest14d', 'ResNeSt-14', '2004.08955',
+       model_desc='Originally from GluonCV'),
+_entry('resnest26d', 'ResNeSt-26', '2004.08955',
+       model_desc='Originally from GluonCV'),
+_entry('resnest50d', 'ResNeSt-50', '2004.08955',
+       model_desc='Originally from https://github.com/zhanghang1989/ResNeSt'),
+_entry('resnest101e', 'ResNeSt-101', '2004.08955',
+       model_desc='Originally from https://github.com/zhanghang1989/ResNeSt'),
+_entry('resnest200e', 'ResNeSt-200', '2004.08955',
+       model_desc='Originally from https://github.com/zhanghang1989/ResNeSt'),
+_entry('resnest269e', 'ResNeSt-269', '2004.08955', batch_size=BATCH_SIZE // 2,
+       model_desc='Originally from https://github.com/zhanghang1989/ResNeSt'),
+_entry('resnest50d_4s2x40d', 'ResNeSt-50 4s2x40d', '2004.08955',
+       model_desc='Originally from https://github.com/zhanghang1989/ResNeSt'),
+_entry('resnest50d_1s4x24d', 'ResNeSt-50 1s4x24d', '2004.08955',
+       model_desc='Originally from https://github.com/zhanghang1989/ResNeSt'),
+
 ## RegNet official impl weights
 _entry('regnetx_002', 'RegNetX-200MF', '2003.13678'),
 _entry('regnetx_004', 'RegNetX-400MF', '2003.13678'),
```
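
The hunk header shows only part of `_entry`'s signature. As a hedged illustration of what such a helper plausibly does, the body, the `BATCH_SIZE` value, and the dict field names below are all assumptions, not the file's actual code:

```python
BATCH_SIZE = 256  # assumed value; only the name is visible in the hunk header

def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
           model_desc=None):
    # Hypothetical reconstruction: bundle one model's benchmark metadata so the
    # sotabench harness can evaluate it against the cited paper's result.
    return dict(
        model=model_name,
        paper_model_name=paper_model_name,
        paper_arxiv_id=paper_arxiv_id,
        batch_size=batch_size,
        model_desc=model_desc,
    )
```

Note how the larger `resnest269e` entry above passes `batch_size=BATCH_SIZE // 2`, presumably to fit its bigger eval resolution in memory.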

timm/scheduler/plateau_lr.py

Lines changed: 49 additions & 2 deletions
```diff
@@ -16,7 +16,12 @@ def __init__(self,
                  warmup_t=0,
                  warmup_lr_init=0,
                  lr_min=0,
-                 mode='min',
+                 mode='max',
+                 noise_range_t=None,
+                 noise_type='normal',
+                 noise_pct=0.67,
+                 noise_std=1.0,
+                 noise_seed=None,
                  initialize=True,
                  ):
         super().__init__(optimizer, 'lr', initialize=initialize)
@@ -32,13 +37,19 @@ def __init__(self,
             min_lr=lr_min
         )
 
+        self.noise_range = noise_range_t
+        self.noise_pct = noise_pct
+        self.noise_type = noise_type
+        self.noise_std = noise_std
+        self.noise_seed = noise_seed if noise_seed is not None else 42
         self.warmup_t = warmup_t
         self.warmup_lr_init = warmup_lr_init
         if self.warmup_t:
             self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
             super().update_groups(self.warmup_lr_init)
         else:
             self.warmup_steps = [1 for _ in self.base_values]
+        self.restore_lr = None
 
     def state_dict(self):
         return {
@@ -57,4 +68,40 @@ def step(self, epoch, metric=None):
             lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
             super().update_groups(lrs)
         else:
-            self.lr_scheduler.step(metric, epoch)
+            if self.restore_lr is not None:
+                # restore actual LR from before our last noise perturbation before stepping base
+                for i, param_group in enumerate(self.optimizer.param_groups):
+                    param_group['lr'] = self.restore_lr[i]
+                self.restore_lr = None
+
+            self.lr_scheduler.step(metric, epoch)  # step the base scheduler
+
+            if self.noise_range is not None:
+                if isinstance(self.noise_range, (list, tuple)):
+                    apply_noise = self.noise_range[0] <= epoch < self.noise_range[1]
+                else:
+                    apply_noise = epoch >= self.noise_range
+                if apply_noise:
+                    self._apply_noise(epoch)
+
+    def _apply_noise(self, epoch):
+        g = torch.Generator()
+        g.manual_seed(self.noise_seed + epoch)
+        if self.noise_type == 'normal':
+            while True:
+                # resample if noise out of percent limit, brute force but shouldn't spin much
+                noise = torch.randn(1, generator=g).item()
+                if abs(noise) < self.noise_pct:
+                    break
+        else:
+            noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
+
+        # apply the noise on top of previous LR, cache the old value so we can restore for normal
+        # stepping of base scheduler
+        restore_lr = []
+        for i, param_group in enumerate(self.optimizer.param_groups):
+            old_lr = float(param_group['lr'])
+            restore_lr.append(old_lr)
+            new_lr = old_lr + old_lr * noise
+            param_group['lr'] = new_lr
+        self.restore_lr = restore_lr
```
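
To make the new noise logic easier to follow outside the class, here is a minimal self-contained sketch of the same perturbation (the standalone function and its name are mine; the sampling logic mirrors `_apply_noise` above). The noise is seeded per epoch for reproducibility, bounded by `noise_pct`, and applied multiplicatively:

```python
import torch

def perturbed_lr(lr, epoch, seed=42, noise_pct=0.67, noise_type='normal'):
    # Deterministic per (seed, epoch): re-running an epoch reproduces its noise.
    g = torch.Generator()
    g.manual_seed(seed + epoch)
    if noise_type == 'normal':
        # Rejection-sample a standard normal until it lands within +/- noise_pct.
        while True:
            noise = torch.randn(1, generator=g).item()
            if abs(noise) < noise_pct:
                break
    else:
        # Uniform in [-noise_pct, +noise_pct).
        noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * noise_pct
    # Multiplicative perturbation around the current LR.
    return lr + lr * noise

print(perturbed_lr(0.1, epoch=30))  # always within 0.1 * (1 +/- 0.67)
```

The `restore_lr` caching in `step` exists because the wrapped `ReduceLROnPlateau` decays whatever LR is currently in the param groups; restoring the pre-noise value before each base step keeps the noise from compounding into the decay schedule.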

timm/scheduler/scheduler_factory.py

Lines changed: 25 additions & 19 deletions
```diff
@@ -7,49 +7,49 @@
 def create_scheduler(args, optimizer):
     num_epochs = args.epochs
 
-    if args.lr_noise is not None:
-        if isinstance(args.lr_noise, (list, tuple)):
-            noise_range = [n * num_epochs for n in args.lr_noise]
+    if getattr(args, 'lr_noise', None) is not None:
+        lr_noise = getattr(args, 'lr_noise')
+        if isinstance(lr_noise, (list, tuple)):
+            noise_range = [n * num_epochs for n in lr_noise]
             if len(noise_range) == 1:
                 noise_range = noise_range[0]
         else:
-            noise_range = args.lr_noise * num_epochs
+            noise_range = lr_noise * num_epochs
     else:
         noise_range = None
 
     lr_scheduler = None
-    #FIXME expose cycle parms of the scheduler config to arguments
     if args.sched == 'cosine':
         lr_scheduler = CosineLRScheduler(
             optimizer,
             t_initial=num_epochs,
-            t_mul=args.lr_cycle_mul,
+            t_mul=getattr(args, 'lr_cycle_mul', 1.),
             lr_min=args.min_lr,
             decay_rate=args.decay_rate,
             warmup_lr_init=args.warmup_lr,
             warmup_t=args.warmup_epochs,
-            cycle_limit=args.lr_cycle_limit,
+            cycle_limit=getattr(args, 'lr_cycle_limit', 0),
             t_in_epochs=True,
             noise_range_t=noise_range,
-            noise_pct=args.lr_noise_pct,
-            noise_std=args.lr_noise_std,
-            noise_seed=args.seed,
+            noise_pct=getattr(args, 'lr_noise_pct', 0.67),
+            noise_std=getattr(args, 'lr_noise_std', 1.),
+            noise_seed=getattr(args, 'seed', 42),
         )
         num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
     elif args.sched == 'tanh':
         lr_scheduler = TanhLRScheduler(
             optimizer,
             t_initial=num_epochs,
-            t_mul=args.lr_cycle_mul,
+            t_mul=getattr(args, 'lr_cycle_mul', 1.),
             lr_min=args.min_lr,
             warmup_lr_init=args.warmup_lr,
             warmup_t=args.warmup_epochs,
-            cycle_limit=args.lr_cycle_limit,
+            cycle_limit=getattr(args, 'lr_cycle_limit', 0),
             t_in_epochs=True,
             noise_range_t=noise_range,
-            noise_pct=args.lr_noise_pct,
-            noise_std=args.lr_noise_std,
-            noise_seed=args.seed,
+            noise_pct=getattr(args, 'lr_noise_pct', 0.67),
+            noise_std=getattr(args, 'lr_noise_std', 1.),
+            noise_seed=getattr(args, 'seed', 42),
         )
         num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
     elif args.sched == 'step':
@@ -60,19 +60,25 @@ def create_scheduler(args, optimizer):
             warmup_lr_init=args.warmup_lr,
             warmup_t=args.warmup_epochs,
             noise_range_t=noise_range,
-            noise_pct=args.lr_noise_pct,
-            noise_std=args.lr_noise_std,
-            noise_seed=args.seed,
+            noise_pct=getattr(args, 'lr_noise_pct', 0.67),
+            noise_std=getattr(args, 'lr_noise_std', 1.),
+            noise_seed=getattr(args, 'seed', 42),
         )
     elif args.sched == 'plateau':
+        mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
         lr_scheduler = PlateauLRScheduler(
             optimizer,
             decay_rate=args.decay_rate,
             patience_t=args.patience_epochs,
             lr_min=args.min_lr,
+            mode=mode,
             warmup_lr_init=args.warmup_lr,
             warmup_t=args.warmup_epochs,
-            cooldown_t=args.cooldown_epochs,
+            cooldown_t=0,
+            noise_range_t=noise_range,
+            noise_pct=getattr(args, 'lr_noise_pct', 0.67),
+            noise_std=getattr(args, 'lr_noise_std', 1.),
+            noise_seed=getattr(args, 'seed', 42),
         )
 
     return lr_scheduler, num_epochs
```
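
A hedged usage sketch of the pattern this diff introduces: because the factory now reads the optional options via `getattr` with defaults, an args namespace that predates the noise and cycle flags still works. The `SimpleNamespace` fields below are illustrative, covering only the attributes the factory accesses directly; `create_scheduler` is assumed importable from `timm.scheduler`:

```python
import torch
from types import SimpleNamespace
from timm.scheduler import create_scheduler  # assumed import path

optimizer = torch.optim.SGD(torch.nn.Linear(10, 2).parameters(), lr=0.1)

# Deliberately omits lr_noise_pct, lr_noise_std, seed, lr_cycle_mul and
# lr_cycle_limit: the getattr(...) fallbacks above supply their defaults.
args = SimpleNamespace(
    epochs=100, sched='cosine', min_lr=1e-5, decay_rate=0.1,
    warmup_lr=1e-4, warmup_epochs=3, cooldown_epochs=10, lr_noise=None)

scheduler, num_epochs = create_scheduler(args, optimizer)
for epoch in range(num_epochs):
    # ... train and validate ...
    scheduler.step(epoch)  # timm schedulers step on the epoch index
```

Note also the plateau fix the commit message mentions: `mode` is now derived from the eval metric ('min' for loss-like metrics, 'max' otherwise), where the scheduler previously defaulted to 'min' even when tracking accuracy.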
