Skip to content

Commit 8a44016

Browse files
committed
[Intel GPU] enable use of timm dinov2 models for offload benchmark_low_bit_adam
1 parent ed4cd34 commit 8a44016

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

benchmarks/benchmark_low_bit_adam.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,18 @@ def get_parser():
120120

121121
def get_dloader(args, training: bool):
122122
transforms = [v2.ToImage()]
123+
input_size = (
124+
args.data_config["input_size"][1] if args.data_config is not None else 224
125+
) # Standard ViT input size
123126

124127
if training:
125-
transforms.extend([v2.RandomResizedCrop(224), v2.RandomHorizontalFlip()])
128+
transforms.extend([v2.RandomResizedCrop(input_size), v2.RandomHorizontalFlip()])
126129
else:
127-
transforms.extend([v2.Resize(256), v2.CenterCrop(224)])
130+
# For validation, resize to slightly larger then center crop
131+
if "dinov2" in args.model.lower():
132+
input_size = 518 # DINOv2 models expect 518x518
133+
resize_size = int(input_size * 256 / 224) # Scale proportionally (584 for 518)
134+
transforms.extend([v2.Resize(resize_size), v2.CenterCrop(input_size)])
128135

129136
transforms.append(v2.ToDtype(torch.float32, scale=True))
130137
transforms.append(
@@ -207,12 +214,15 @@ def evaluate_model(model, args):
207214
dir="/tmp",
208215
mode="disabled" if args.project is None else None,
209216
)
210-
dloader = get_dloader(args, True)
211-
print(f"Train dataset: {len(dloader.dataset):,} images")
212217

213218
model = timm.create_model(
214219
args.model, pretrained=True, num_classes=45, **args.model_kwargs
215220
)
221+
args.data_config = timm.data.resolve_model_data_config(model)
222+
223+
dloader = get_dloader(args, True)
224+
print(f"Train dataset: {len(dloader.dataset):,} images")
225+
216226
if args.checkpoint_activations:
217227
model.set_grad_checkpointing()
218228
if args.full_bf16:

0 commit comments

Comments
 (0)