@@ -312,16 +312,18 @@ def deploy(self, model_type: str, model_path: str) -> None:
312
312
model_path (str): File path to model weights to be uploaded
313
313
"""
314
314
315
- supported_models = ["yolov8 " , "yolov5 " , "yolov7-seg " ]
315
+ supported_models = ["yolov5 " , "yolov7 " , "yolov8 " ]
316
316
317
- if model_type not in supported_models :
317
+ if not any (
318
+ supported_model in model_type for supported_model in supported_models
319
+ ):
318
320
raise (
319
321
ValueError (
320
322
f"Model type { model_type } not supported. Supported models are { supported_models } "
321
323
)
322
324
)
323
325
324
- if model_type == "yolov8" :
326
+ if "yolov8" in model_type :
325
327
try :
326
328
import torch
327
329
import ultralytics
@@ -335,7 +337,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
335
337
[("ultralytics" , "<=" , "8.0.20" )]
336
338
)
337
339
338
- elif model_type in [ "yolov5" , "yolov7-seg" ] :
340
+ elif "yolov5" in model_type or "yolov7" in model_type :
339
341
try :
340
342
import torch
341
343
except ImportError as e :
@@ -345,22 +347,31 @@ def deploy(self, model_type: str, model_path: str) -> None:
345
347
346
348
model = torch .load (os .path .join (model_path , "weights/best.pt" ))
347
349
348
- class_names = []
349
- for i , val in enumerate (model ["model" ].names ):
350
- class_names .append ((val , model ["model" ].names [val ]))
350
+ if isinstance (model ["model" ].names , list ):
351
+ class_names = model ["model" ].names
352
+ else :
353
+ class_names = []
354
+ for i , val in enumerate (model ["model" ].names ):
355
+ class_names .append ((val , model ["model" ].names [val ]))
351
356
class_names .sort (key = lambda x : x [0 ])
352
357
class_names = [x [1 ] for x in class_names ]
353
358
354
- if model_type == "yolov8" :
359
+ if "yolov8" in model_type :
355
360
# try except for backwards compatibility with older versions of ultralytics
361
+ if "-cls" in model_type :
362
+ nc = model ["model" ].yaml ["nc" ]
363
+ args = model ["train_args" ]
364
+ else :
365
+ nc = model ["model" ].nc
366
+ args = model ["model" ].args
356
367
try :
357
368
model_artifacts = {
358
369
"names" : class_names ,
359
370
"yaml" : model ["model" ].yaml ,
360
- "nc" : model [ "model" ]. nc ,
371
+ "nc" : nc ,
361
372
"args" : {
362
373
k : val
363
- for k , val in model [ "model" ]. args .items ()
374
+ for k , val in args .items ()
364
375
if ((k == "model" ) or (k == "imgsz" ) or (k == "batch" ))
365
376
},
366
377
"ultralytics_version" : ultralytics .__version__ ,
@@ -370,33 +381,39 @@ def deploy(self, model_type: str, model_path: str) -> None:
370
381
model_artifacts = {
371
382
"names" : class_names ,
372
383
"yaml" : model ["model" ].yaml ,
373
- "nc" : model [ "model" ]. nc ,
384
+ "nc" : nc ,
374
385
"args" : {
375
386
k : val
376
- for k , val in model [ "model" ]. args .__dict__ .items ()
387
+ for k , val in args .__dict__ .items ()
377
388
if ((k == "model" ) or (k == "imgsz" ) or (k == "batch" ))
378
389
},
379
390
"ultralytics_version" : ultralytics .__version__ ,
380
391
"model_type" : model_type ,
381
392
}
382
- elif model_type in [ "yolov5" , "yolov7-seg" ] :
393
+ elif "yolov5" in model_type or "yolov7" in model_type :
383
394
# parse from yaml for yolov5
384
395
385
396
with open (os .path .join (model_path , "opt.yaml" ), "r" ) as stream :
386
397
opts = yaml .safe_load (stream )
387
398
388
399
model_artifacts = {
389
400
"names" : class_names ,
390
- "yaml" : model ["model" ].yaml ,
391
401
"nc" : model ["model" ].nc ,
392
- "args" : {"imgsz" : opts ["imgsz" ], "batch" : opts ["batch_size" ]},
402
+ "args" : {
403
+ "imgsz" : opts ["imgsz" ] if "imgsz" in opts else opts ["img_size" ],
404
+ "batch" : opts ["batch_size" ],
405
+ },
393
406
"model_type" : model_type ,
394
407
}
408
+ if hasattr (model ["model" ], "yaml" ):
409
+ model_artifacts ["yaml" ] = model ["model" ].yaml
395
410
396
- with open (model_path + "model_artifacts.json" , "w" ) as fp :
411
+ with open (os . path . join ( model_path , "model_artifacts.json" ) , "w" ) as fp :
397
412
json .dump (model_artifacts , fp )
398
413
399
- torch .save (model ["model" ].state_dict (), model_path + "state_dict.pt" )
414
+ torch .save (
415
+ model ["model" ].state_dict (), os .path .join (model_path , "state_dict.pt" )
416
+ )
400
417
401
418
lista_files = [
402
419
"results.csv" ,
@@ -405,11 +422,13 @@ def deploy(self, model_type: str, model_path: str) -> None:
405
422
"state_dict.pt" ,
406
423
]
407
424
408
- with zipfile .ZipFile (model_path + "roboflow_deploy.zip" , "w" ) as zipMe :
425
+ with zipfile .ZipFile (
426
+ os .path .join (model_path , "roboflow_deploy.zip" ), "w"
427
+ ) as zipMe :
409
428
for file in lista_files :
410
- if os .path .exists (model_path + file ):
429
+ if os .path .exists (os . path . join ( model_path , file ) ):
411
430
zipMe .write (
412
- model_path + file ,
431
+ os . path . join ( model_path , file ) ,
413
432
arcname = file ,
414
433
compress_type = zipfile .ZIP_DEFLATED ,
415
434
)
@@ -422,7 +441,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
422
441
)
423
442
424
443
res = requests .get (
425
- f"{ API_URL } /{ self .workspace } /{ self .project } /{ self .version } /uploadModel?api_key={ self .__api_key } "
444
+ f"{ API_URL } /{ self .workspace } /{ self .project } /{ self .version } /uploadModel?api_key={ self .__api_key } &modelType= { model_type } "
426
445
)
427
446
try :
428
447
if res .status_code == 429 :
@@ -437,7 +456,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
437
456
438
457
res = requests .put (
439
458
res .json ()["url" ],
440
- data = open (os .path .join (model_path + "roboflow_deploy.zip" ), "rb" ),
459
+ data = open (os .path .join (model_path , "roboflow_deploy.zip" ), "rb" ),
441
460
)
442
461
try :
443
462
res .raise_for_status ()
@@ -590,7 +609,8 @@ def callback(content: dict) -> dict:
590
609
pass
591
610
return content
592
611
593
- amend_data_yaml (path = data_path , callback = callback )
612
+ if os .path .exists (data_path ):
613
+ amend_data_yaml (path = data_path , callback = callback )
594
614
595
615
def __str__ (self ):
596
616
"""string representation of version object."""
0 commit comments