Skip to content

Commit 257afbb

Browse files
committed
Updated deploy method of version class to support more model types.
1 parent e3605db commit 257afbb

File tree

1 file changed

+43
-23
lines changed

1 file changed

+43
-23
lines changed

roboflow/core/version.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -312,16 +312,18 @@ def deploy(self, model_type: str, model_path: str) -> None:
312312
model_path (str): File path to model weights to be uploaded
313313
"""
314314

315-
supported_models = ["yolov8", "yolov5", "yolov7-seg"]
315+
supported_models = ["yolov5", "yolov7", "yolov8"]
316316

317-
if model_type not in supported_models:
317+
if not any(
318+
supported_model in model_type for supported_model in supported_models
319+
):
318320
raise (
319321
ValueError(
320322
f"Model type {model_type} not supported. Supported models are {supported_models}"
321323
)
322324
)
323325

324-
if model_type == "yolov8":
326+
if "yolov8" in model_type:
325327
try:
326328
import torch
327329
import ultralytics
@@ -335,7 +337,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
335337
[("ultralytics", "<=", "8.0.20")]
336338
)
337339

338-
elif model_type in ["yolov5", "yolov7-seg"]:
340+
elif "yolov5" in model_type or "yolov7" in model_type:
339341
try:
340342
import torch
341343
except ImportError as e:
@@ -345,22 +347,31 @@ def deploy(self, model_type: str, model_path: str) -> None:
345347

346348
model = torch.load(os.path.join(model_path, "weights/best.pt"))
347349

348-
class_names = []
349-
for i, val in enumerate(model["model"].names):
350-
class_names.append((val, model["model"].names[val]))
350+
if isinstance(model["model"].names, list):
351+
class_names = model["model"].names
352+
else:
353+
class_names = []
354+
for i, val in enumerate(model["model"].names):
355+
class_names.append((val, model["model"].names[val]))
351356
class_names.sort(key=lambda x: x[0])
352357
class_names = [x[1] for x in class_names]
353358

354-
if model_type == "yolov8":
359+
if "yolov8" in model_type:
355360
# try except for backwards compatibility with older versions of ultralytics
361+
if "-cls" in model_type:
362+
nc = model["model"].yaml["nc"]
363+
args = model["train_args"]
364+
else:
365+
nc = model["model"].nc
366+
args = model["model"].args
356367
try:
357368
model_artifacts = {
358369
"names": class_names,
359370
"yaml": model["model"].yaml,
360-
"nc": model["model"].nc,
371+
"nc": nc,
361372
"args": {
362373
k: val
363-
for k, val in model["model"].args.items()
374+
for k, val in args.items()
364375
if ((k == "model") or (k == "imgsz") or (k == "batch"))
365376
},
366377
"ultralytics_version": ultralytics.__version__,
@@ -370,33 +381,39 @@ def deploy(self, model_type: str, model_path: str) -> None:
370381
model_artifacts = {
371382
"names": class_names,
372383
"yaml": model["model"].yaml,
373-
"nc": model["model"].nc,
384+
"nc": nc,
374385
"args": {
375386
k: val
376-
for k, val in model["model"].args.__dict__.items()
387+
for k, val in args.__dict__.items()
377388
if ((k == "model") or (k == "imgsz") or (k == "batch"))
378389
},
379390
"ultralytics_version": ultralytics.__version__,
380391
"model_type": model_type,
381392
}
382-
elif model_type in ["yolov5", "yolov7-seg"]:
393+
elif "yolov5" in model_type or "yolov7" in model_type:
383394
# parse from yaml for yolov5
384395

385396
with open(os.path.join(model_path, "opt.yaml"), "r") as stream:
386397
opts = yaml.safe_load(stream)
387398

388399
model_artifacts = {
389400
"names": class_names,
390-
"yaml": model["model"].yaml,
391401
"nc": model["model"].nc,
392-
"args": {"imgsz": opts["imgsz"], "batch": opts["batch_size"]},
402+
"args": {
403+
"imgsz": opts["imgsz"] if "imgsz" in opts else opts["img_size"],
404+
"batch": opts["batch_size"],
405+
},
393406
"model_type": model_type,
394407
}
408+
if hasattr(model["model"], "yaml"):
409+
model_artifacts["yaml"] = model["model"].yaml
395410

396-
with open(model_path + "model_artifacts.json", "w") as fp:
411+
with open(os.path.join(model_path, "model_artifacts.json"), "w") as fp:
397412
json.dump(model_artifacts, fp)
398413

399-
torch.save(model["model"].state_dict(), model_path + "state_dict.pt")
414+
torch.save(
415+
model["model"].state_dict(), os.path.join(model_path, "state_dict.pt")
416+
)
400417

401418
lista_files = [
402419
"results.csv",
@@ -405,11 +422,13 @@ def deploy(self, model_type: str, model_path: str) -> None:
405422
"state_dict.pt",
406423
]
407424

408-
with zipfile.ZipFile(model_path + "roboflow_deploy.zip", "w") as zipMe:
425+
with zipfile.ZipFile(
426+
os.path.join(model_path, "roboflow_deploy.zip"), "w"
427+
) as zipMe:
409428
for file in lista_files:
410-
if os.path.exists(model_path + file):
429+
if os.path.exists(os.path.join(model_path, file)):
411430
zipMe.write(
412-
model_path + file,
431+
os.path.join(model_path, file),
413432
arcname=file,
414433
compress_type=zipfile.ZIP_DEFLATED,
415434
)
@@ -422,7 +441,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
422441
)
423442

424443
res = requests.get(
425-
f"{API_URL}/{self.workspace}/{self.project}/{self.version}/uploadModel?api_key={self.__api_key}"
444+
f"{API_URL}/{self.workspace}/{self.project}/{self.version}/uploadModel?api_key={self.__api_key}&modelType={model_type}"
426445
)
427446
try:
428447
if res.status_code == 429:
@@ -437,7 +456,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
437456

438457
res = requests.put(
439458
res.json()["url"],
440-
data=open(os.path.join(model_path + "roboflow_deploy.zip"), "rb"),
459+
data=open(os.path.join(model_path, "roboflow_deploy.zip"), "rb"),
441460
)
442461
try:
443462
res.raise_for_status()
@@ -590,7 +609,8 @@ def callback(content: dict) -> dict:
590609
pass
591610
return content
592611

593-
amend_data_yaml(path=data_path, callback=callback)
612+
if os.path.exists(data_path):
613+
amend_data_yaml(path=data_path, callback=callback)
594614

595615
def __str__(self):
596616
"""string representation of version object."""

0 commit comments

Comments
 (0)