@@ -120,11 +120,18 @@ def get_parser():
120120
121121def get_dloader (args , training : bool ):
122122 transforms = [v2 .ToImage ()]
123+ input_size = (
124+ args .data_config ["input_size" ][1 ] if args .data_config is not None else 224
125+ ) # Standard ViT input size
123126
124127 if training :
125- transforms .extend ([v2 .RandomResizedCrop (224 ), v2 .RandomHorizontalFlip ()])
128+ transforms .extend ([v2 .RandomResizedCrop (input_size ), v2 .RandomHorizontalFlip ()])
126129 else :
127- transforms .extend ([v2 .Resize (256 ), v2 .CenterCrop (224 )])
130+ # For validation, resize to slightly larger then center crop
131+ if "dinov2" in args .model .lower ():
132+ input_size = 518 # DINOv2 models expect 518x518
133+ resize_size = int (input_size * 256 / 224 ) # Scale proportionally (592 for 518)
134+ transforms .extend ([v2 .Resize (resize_size ), v2 .CenterCrop (input_size )])
128135
129136 transforms .append (v2 .ToDtype (torch .float32 , scale = True ))
130137 transforms .append (
@@ -207,12 +214,15 @@ def evaluate_model(model, args):
207214 dir = "/tmp" ,
208215 mode = "disabled" if args .project is None else None ,
209216 )
210- dloader = get_dloader (args , True )
211- print (f"Train dataset: { len (dloader .dataset ):,} images" )
212217
213218 model = timm .create_model (
214219 args .model , pretrained = True , num_classes = 45 , ** args .model_kwargs
215220 )
221+ args .data_config = timm .data .resolve_model_data_config (model )
222+
223+ dloader = get_dloader (args , True )
224+ print (f"Train dataset: { len (dloader .dataset ):,} images" )
225+
216226 if args .checkpoint_activations :
217227 model .set_grad_checkpointing ()
218228 if args .full_bf16 :
0 commit comments