
Commit 93b6b32

[Conformance] Ultralytics yolov8n and yolo11n
1 parent f61aa89 commit 93b6b32

File tree

4 files changed: +220 -0 lines changed

tests/post_training/data/ptq_reference_data.yaml

Lines changed: 12 additions & 0 deletions

@@ -58,6 +58,18 @@ torchvision/swin_v2_s_backend_OV:
   metric_value: 0.83638
 torchvision/swin_v2_s_backend_FX_TORCH:
   metric_value: 0.8360
+ultralytics/yolov8n_backend_FP32:
+  metric_value: 0.6056
+ultralytics/yolov8n_backend_FX_TORCH:
+  metric_value: 0.61417
+ultralytics/yolov8n_backend_OV:
+  metric_value: 0.6188
+ultralytics/yolo11n_backend_FP32:
+  metric_value: 0.6770
+ultralytics/yolo11n_backend_FX_TORCH:
+  metric_value: 0.6735
+ultralytics/yolo11n_backend_OV:
+  metric_value: 0.6752
 timm/crossvit_9_240_backend_CUDA_TORCH:
   metric_value: 0.7275
 timm/crossvit_9_240_backend_FP32:
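
These reference entries are what the conformance run compares against: each "<reported_name>_backend_<BACKEND>" key stores the expected mAP50(B) for that model/backend pair. A minimal sketch of such a check, assuming PyYAML; the helper name and the 0.003 tolerance are hypothetical, since the real comparison logic lives in the test harness, not in this commit:

import yaml

# Hedged sketch: "check_against_reference" and the default tolerance
# are hypothetical stand-ins for the harness's actual comparison code.
def check_against_reference(report_key: str, measured: float, atol: float = 0.003) -> None:
    with open("tests/post_training/data/ptq_reference_data.yaml") as f:
        reference = yaml.safe_load(f)
    expected = reference[report_key]["metric_value"]
    if abs(measured - expected) > atol:
        raise AssertionError(f"{report_key}: measured {measured}, expected {expected} +/- {atol}")

check_against_reference("ultralytics/yolov8n_backend_OV", measured=0.6188)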

tests/post_training/model_scope.py

Lines changed: 80 additions & 0 deletions

@@ -32,6 +32,7 @@
 from tests.post_training.pipelines.image_classification_torchvision import ImageClassificationTorchvision
 from tests.post_training.pipelines.lm_weight_compression import LMWeightCompression
 from tests.post_training.pipelines.masked_language_modeling import MaskedLanguageModelingHF
+from tests.post_training.pipelines.ultralytics_detection import UltralyticsDetection

 QUANTIZATION_MODELS = [
     # HF models
@@ -123,6 +124,85 @@
         "backends": [BackendType.FX_TORCH, BackendType.OV],
         "batch_size": 1,
     },
+    # Ultralytics models
+    {
+        "reported_name": "ultralytics/yolov8n",
+        "model_id": "yolov8n",
+        "pipeline_cls": UltralyticsDetection,
+        "compression_params": {
+            "preset": nncf.QuantizationPreset.MIXED,
+            "ignored_scope": nncf.IgnoredScope(
+                types=["mul", "sub", "sigmoid", "__getitem__"],
+                subgraphs=[
+                    nncf.Subgraph(
+                        inputs=["cat_13", "cat_14", "cat_15"],
+                        outputs=["output"],
+                    )
+                ],
+            ),
+        },
+        "backends": [BackendType.FX_TORCH],
+        "batch_size": 1,
+    },
+    {
+        "reported_name": "ultralytics/yolov8n",
+        "model_id": "yolov8n",
+        "pipeline_cls": UltralyticsDetection,
+        "compression_params": {
+            "preset": QuantizationPreset.MIXED,
+            "ignored_scope": nncf.IgnoredScope(
+                types=["Multiply", "Subtract", "Sigmoid"],
+                subgraphs=[
+                    nncf.Subgraph(
+                        inputs=["/model.22/Concat", "/model.22/Concat_1", "/model.22/Concat_2"],
+                        outputs=["output0/sink_port_0"],
+                    )
+                ],
+            ),
+        },
+        "backends": [BackendType.OV],
+        "batch_size": 1,
+    },
+    {
+        "reported_name": "ultralytics/yolo11n",
+        "model_id": "yolo11n",
+        "pipeline_cls": UltralyticsDetection,
+        "compression_params": {
+            "model_type": nncf.ModelType.TRANSFORMER,
+            "preset": QuantizationPreset.MIXED,
+            "ignored_scope": nncf.IgnoredScope(
+                types=["mul", "sub", "sigmoid", "__getitem__"],
+                subgraphs=[
+                    nncf.Subgraph(
+                        inputs=["cat_13", "cat_14", "cat_15"],
+                        outputs=["output"],
+                    )
+                ],
+            ),
+        },
+        "backends": [BackendType.FX_TORCH],
+        "batch_size": 1,
+    },
+    {
+        "reported_name": "ultralytics/yolo11n",
+        "model_id": "yolo11n",
+        "pipeline_cls": UltralyticsDetection,
+        "compression_params": {
+            "model_type": nncf.ModelType.TRANSFORMER,
+            "preset": QuantizationPreset.MIXED,
+            "ignored_scope": nncf.IgnoredScope(
+                types=["Multiply", "Subtract", "Sigmoid"],
+                subgraphs=[
+                    nncf.Subgraph(
+                        inputs=["/model.23/Concat", "/model.23/Concat_1", "/model.23/Concat_2"],
+                        outputs=["output0/sink_port_0"],
+                    )
+                ],
+            ),
+        },
+        "backends": [BackendType.OV],
+        "batch_size": 1,
+    },
     # Timm models
     {
         "reported_name": "timm/crossvit_9_240",
tests/post_training/pipelines/ultralytics_detection.py (new file)

Lines changed: 127 additions & 0 deletions
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Dict, Tuple

import openvino as ov
import torch
from ultralytics import YOLO
from ultralytics.data.utils import check_det_dataset
from ultralytics.engine.validator import BaseValidator as Validator
from ultralytics.utils.torch_utils import de_parallel

import nncf
from nncf.torch import disable_patching
from tests.post_training.pipelines.base import OV_BACKENDS
from tests.post_training.pipelines.base import BackendType
from tests.post_training.pipelines.base import PTQTestPipeline


class UltralyticsDetection(PTQTestPipeline):
    """Pipeline for YOLO detection models from the Ultralytics repository."""

    def prepare_model(self) -> None:
        if self.batch_size != 1:
            raise RuntimeError("Batch size > 1 is not supported")

        model_path = f"{self.fp32_model_dir}/{self.model_id}"
        yolo = YOLO(f"{model_path}.pt")
        self.validator, self.data_loader = self._prepare_validation(yolo, "coco128.yaml")
        self.dummy_tensor = torch.ones((1, 3, 640, 640))

        if self.backend in OV_BACKENDS + [BackendType.FP32]:
            onnx_model_path = Path(f"{model_path}.onnx")
            ir_model_path = self.fp32_model_dir / "model_fp32.xml"
            yolo.export(format="onnx", dynamic=True, half=False)
            ov.save_model(ov.convert_model(onnx_model_path), ir_model_path)
            self.model = ov.Core().read_model(ir_model_path)

        if self.backend == BackendType.FX_TORCH:
            pt_model = yolo.model
            # Run the model once to initialize all internal variables.
            pt_model(self.dummy_tensor)

            with torch.no_grad():
                with disable_patching():
                    self.model = torch.export.export(pt_model, args=(self.dummy_tensor,), strict=False).module()

    def prepare_preprocessor(self) -> None:
        pass

    @staticmethod
    def _validate_fx(
        model: torch.nn.Module, data_loader: torch.utils.data.DataLoader, validator: Validator, num_samples: int = None
    ) -> Tuple[Dict, int, int]:
        for batch_i, batch in enumerate(data_loader):
            if num_samples is not None and batch_i == num_samples:
                break
            batch = validator.preprocess(batch)
            preds = model(batch["img"])
            preds = validator.postprocess(preds)
            validator.update_metrics(preds, batch)
        stats = validator.get_stats()
        return stats, validator.seen, validator.nt_per_class.sum()

    @staticmethod
    def _validate_ov(
        model: ov.Model, data_loader: torch.utils.data.DataLoader, validator: Validator, num_samples: int = None
    ) -> Tuple[Dict, int, int]:
        # Batch size 1 with dynamic spatial dimensions for validation.
        model.reshape({0: [1, 3, -1, -1]})
        compiled_model = ov.compile_model(model)
        output_layer = compiled_model.output(0)
        for batch_i, batch in enumerate(data_loader):
            if num_samples is not None and batch_i == num_samples:
                break
            batch = validator.preprocess(batch)
            preds = torch.from_numpy(compiled_model(batch["img"])[output_layer])
            preds = validator.postprocess(preds)
            validator.update_metrics(preds, batch)
        stats = validator.get_stats()
        return stats, validator.seen, validator.nt_per_class.sum()

    def get_transform_calibration_fn(self):
        def transform_func(batch):
            return self.validator.preprocess(batch)["img"]

        return transform_func

    def prepare_calibration_dataset(self):
        self.calibration_dataset = nncf.Dataset(self.data_loader, self.get_transform_calibration_fn())

    @staticmethod
    def _prepare_validation(model: YOLO, data: str) -> Tuple[Validator, torch.utils.data.DataLoader]:
        custom = {"rect": False, "batch": 1}  # method defaults
        args = {**model.overrides, **custom, "mode": "val"}  # highest priority args on the right

        validator = model._smart_load("validator")(args=args, _callbacks=model.callbacks)
        stride = 32  # default stride
        validator.stride = stride  # used in get_dataloader() for padding
        validator.data = check_det_dataset(data)
        validator.init_metrics(de_parallel(model))

        data_loader = validator.get_dataloader(validator.data.get(validator.args.split), validator.args.batch)

        return validator, data_loader

    def _validate(self):
        if self.backend == BackendType.FP32:
            stats, _, _ = self._validate_ov(self.model, self.data_loader, self.validator)
        elif self.backend in OV_BACKENDS:
            stats, _, _ = self._validate_ov(self.compressed_model, self.data_loader, self.validator)
        elif self.backend == BackendType.FX_TORCH:
            stats, _, _ = self._validate_fx(self.compressed_model, self.data_loader, self.validator)
        else:
            raise RuntimeError(f"Backend {self.backend} is not supported in UltralyticsDetection")

        self.run_info.metric_name = "mAP50(B)"
        self.run_info.metric_value = stats["metrics/mAP50(B)"]
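
For reference, the static helpers above can be exercised outside the harness. A hedged standalone sketch, assuming Ultralytics downloads yolov8n.pt on first use and that "coco128.yaml" resolves to a local dataset; it calls the protected helpers directly, which the framework normally does through _validate():

import openvino as ov
from ultralytics import YOLO

from tests.post_training.pipelines.ultralytics_detection import UltralyticsDetection

yolo = YOLO("yolov8n.pt")  # fetched by Ultralytics if not present locally
yolo.export(format="onnx", dynamic=True, half=False)  # writes yolov8n.onnx next to the .pt

validator, data_loader = UltralyticsDetection._prepare_validation(yolo, "coco128.yaml")
ov_model = ov.convert_model("yolov8n.onnx")
# num_samples caps the validated batches, mirroring the helper's signature.
stats, seen, num_targets = UltralyticsDetection._validate_ov(ov_model, data_loader, validator, num_samples=8)
print(stats["metrics/mAP50(B)"], seen, num_targets)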

tests/post_training/requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -19,3 +19,4 @@ timm==0.9.2
 transformers==4.38.2
 whowhatbench @ git+https://github.com/andreyanufr/who_what_benchmark@456d3584ce628f6c8605f37cd9a3ab2db1ebf933
 datasets==2.21.0
+ultralytics==8.3.38
