Skip to content

Commit de4090f

Browse files
authored
Merge pull request #123 from roboflow/feature/yolo-model-uploads
YOLO Model Deploys
2 parents b22780d + 091f4c4 commit de4090f

File tree

2 files changed

+36
-20
lines changed

2 files changed

+36
-20
lines changed

roboflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from roboflow.core.workspace import Workspace
1212
from roboflow.util.general import write_line
1313

14-
__version__ = "1.0.0"
14+
__version__ = "1.0.1"
1515

1616

1717
def check_key(api_key, model, notebook, num_retries=0):

roboflow/core/version.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -441,16 +441,18 @@ def deploy(self, model_type: str, model_path: str) -> None:
441441
model_path (str): File path to model weights to be uploaded
442442
"""
443443

444-
supported_models = ["yolov8", "yolov5", "yolov7-seg"]
444+
supported_models = ["yolov5", "yolov7-seg", "yolov8"]
445445

446-
if model_type not in supported_models:
446+
if not any(
447+
supported_model in model_type for supported_model in supported_models
448+
):
447449
raise (
448450
ValueError(
449451
f"Model type {model_type} not supported. Supported models are {supported_models}"
450452
)
451453
)
452454

453-
if model_type == "yolov8":
455+
if "yolov8" in model_type:
454456
try:
455457
import torch
456458
import ultralytics
@@ -464,7 +466,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
464466
[("ultralytics", "<=", "8.0.20")]
465467
)
466468

467-
elif model_type in ["yolov5", "yolov7-seg"]:
469+
elif "yolov5" in model_type or "yolov7" in model_type:
468470
try:
469471
import torch
470472
except ImportError as e:
@@ -483,16 +485,22 @@ def deploy(self, model_type: str, model_path: str) -> None:
483485
class_names.sort(key=lambda x: x[0])
484486
class_names = [x[1] for x in class_names]
485487

486-
if model_type == "yolov8":
488+
if "yolov8" in model_type:
487489
# try except for backwards compatibility with older versions of ultralytics
490+
if "-cls" in model_type:
491+
nc = model["model"].yaml["nc"]
492+
args = model["train_args"]
493+
else:
494+
nc = model["model"].nc
495+
args = model["model"].args
488496
try:
489497
model_artifacts = {
490498
"names": class_names,
491499
"yaml": model["model"].yaml,
492-
"nc": model["model"].nc,
500+
"nc": nc,
493501
"args": {
494502
k: val
495-
for k, val in model["model"].args.items()
503+
for k, val in args.items()
496504
if ((k == "model") or (k == "imgsz") or (k == "batch"))
497505
},
498506
"ultralytics_version": ultralytics.__version__,
@@ -502,33 +510,39 @@ def deploy(self, model_type: str, model_path: str) -> None:
502510
model_artifacts = {
503511
"names": class_names,
504512
"yaml": model["model"].yaml,
505-
"nc": model["model"].nc,
513+
"nc": nc,
506514
"args": {
507515
k: val
508-
for k, val in model["model"].args.__dict__.items()
516+
for k, val in args.__dict__.items()
509517
if ((k == "model") or (k == "imgsz") or (k == "batch"))
510518
},
511519
"ultralytics_version": ultralytics.__version__,
512520
"model_type": model_type,
513521
}
514-
elif model_type in ["yolov5", "yolov7-seg"]:
522+
elif "yolov5" in model_type or "yolov7" in model_type:
515523
# parse from yaml for yolov5
516524

517525
with open(os.path.join(model_path, "opt.yaml"), "r") as stream:
518526
opts = yaml.safe_load(stream)
519527

520528
model_artifacts = {
521529
"names": class_names,
522-
"yaml": model["model"].yaml,
523530
"nc": model["model"].nc,
524-
"args": {"imgsz": opts["imgsz"], "batch": opts["batch_size"]},
531+
"args": {
532+
"imgsz": opts["imgsz"] if "imgsz" in opts else opts["img_size"],
533+
"batch": opts["batch_size"],
534+
},
525535
"model_type": model_type,
526536
}
537+
if hasattr(model["model"], "yaml"):
538+
model_artifacts["yaml"] = model["model"].yaml
527539

528-
with open(model_path + "model_artifacts.json", "w") as fp:
540+
with open(os.path.join(model_path, "model_artifacts.json"), "w") as fp:
529541
json.dump(model_artifacts, fp)
530542

531-
torch.save(model["model"].state_dict(), model_path + "state_dict.pt")
543+
torch.save(
544+
model["model"].state_dict(), os.path.join(model_path, "state_dict.pt")
545+
)
532546

533547
lista_files = [
534548
"results.csv",
@@ -537,11 +551,13 @@ def deploy(self, model_type: str, model_path: str) -> None:
537551
"state_dict.pt",
538552
]
539553

540-
with zipfile.ZipFile(model_path + "roboflow_deploy.zip", "w") as zipMe:
554+
with zipfile.ZipFile(
555+
os.path.join(model_path, "roboflow_deploy.zip"), "w"
556+
) as zipMe:
541557
for file in lista_files:
542-
if os.path.exists(model_path + file):
558+
if os.path.exists(os.path.join(model_path, file)):
543559
zipMe.write(
544-
model_path + file,
560+
os.path.join(model_path, file),
545561
arcname=file,
546562
compress_type=zipfile.ZIP_DEFLATED,
547563
)
@@ -554,7 +570,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
554570
)
555571

556572
res = requests.get(
557-
f"{API_URL}/{self.workspace}/{self.project}/{self.version}/uploadModel?api_key={self.__api_key}"
573+
f"{API_URL}/{self.workspace}/{self.project}/{self.version}/uploadModel?api_key={self.__api_key}&modelType={model_type}"
558574
)
559575
try:
560576
if res.status_code == 429:
@@ -569,7 +585,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
569585

570586
res = requests.put(
571587
res.json()["url"],
572-
data=open(os.path.join(model_path + "roboflow_deploy.zip"), "rb"),
588+
data=open(os.path.join(model_path, "roboflow_deploy.zip"), "rb"),
573589
)
574590
try:
575591
res.raise_for_status()

0 commit comments

Comments
 (0)