
Commit 6422b02

Fix most mypy errors
1 parent 212c7c8 commit 6422b02

11 files changed (+65 -55 lines changed)

roboflow/adapters/rfapi.py

Lines changed: 2 additions & 2 deletions
@@ -43,8 +43,8 @@ def upload_image(
     split: str = "train",
     batch_name: str = DEFAULT_BATCH_NAME,
     tag_names: list = [],
-    sequence_number: int = None,
-    sequence_size: int = None,
+    sequence_number: int | None = None,
+    sequence_size: int | None = None,
     **kwargs,
 ):
     """

roboflow/core/project.py

Lines changed: 21 additions & 20 deletions
@@ -22,15 +22,15 @@ def custom_formatwarning(msg, *args, **kwargs):
     return str(msg) + "\n"
 
 
-warnings.formatwarning = custom_formatwarning
+warnings.formatwarning = custom_formatwarning  # type: ignore[assignment]
 
 
 class Project:
     """
     A Roboflow Project.
     """
 
-    def __init__(self, api_key: str, a_project: str, model_format: str = None):
+    def __init__(self, api_key: str, a_project: dict, model_format: str | None = None):
         """
         Create a Project object that represents a Project associated with a Workspace.
 
@@ -283,7 +283,7 @@ def train(
 
         return new_model
 
-    def version(self, version_number: int, local: str = None):
+    def version(self, version_number: int, local: str | None = None):
         """
         Retrieves information about a specific version and returns a Version() object.
 
@@ -311,6 +311,7 @@ def version(self, version_number: int, local: str = None):
                 local=None,
                 workspace="",
                 project="",
+                public=True,
             )
 
             version_info = self.get_version_information()
@@ -356,13 +357,13 @@ def check_valid_image(self, image_path: str):
 
     def upload(
         self,
-        image_path: str = None,
-        annotation_path: str = None,
+        image_path: str,
+        annotation_path: str | None = None,
         hosted_image: bool = False,
-        image_id: str = None,
+        image_id: str | None = None,
         split: str = "train",
         num_retry_uploads: int = 0,
-        batch_name: str = None,
+        batch_name: str | None = None,
         tag_names: list = [],
         is_prediction: bool = False,
         **kwargs,
@@ -548,15 +549,15 @@ def _annotation_params(self, annotation_path):
 
     def search(
         self,
-        like_image: str = None,
-        prompt: str = None,
+        like_image: str | None = None,
+        prompt: str | None = None,
         offset: int = 0,
         limit: int = 100,
-        tag: str = None,
-        class_name: str = None,
-        in_dataset: str = None,
+        tag: str | None = None,
+        class_name: str | None = None,
+        in_dataset: str | None = None,
         batch: bool = False,
-        batch_id: str = None,
+        batch_id: str | None = None,
         fields: list = ["id", "created", "name", "labels"],
     ):
         """
@@ -586,7 +587,7 @@ def search(
 
         >>> results = project.search(query="cat", limit=10)
         """  # noqa: E501 // docs
-        payload = {}
+        payload: dict[str, str | int | list[str]] = {}
 
         if like_image is not None:
             payload["like_image"] = like_image
@@ -626,15 +627,15 @@ def search(
 
     def search_all(
         self,
-        like_image: str = None,
-        prompt: str = None,
+        like_image: str | None = None,
+        prompt: str | None = None,
        offset: int = 0,
         limit: int = 100,
-        tag: str = None,
-        class_name: str = None,
-        in_dataset: str = None,
+        tag: str | None = None,
+        class_name: str | None = None,
+        in_dataset: str | None = None,
         batch: bool = False,
-        batch_id: str = None,
+        batch_id: str | None = None,
         fields: list = ["id", "created"],
     ):
         """

roboflow/core/version.py

Lines changed: 9 additions & 5 deletions
@@ -359,7 +359,7 @@ def live_plot(epochs, mAP, loss, title=""):
             plt.show()
 
         first_graph_write = False
-        previous_epochs = []
+        previous_epochs: np.ndarray | list = []
         num_machine_spin_dots = []
 
         while status == "training" or status == "running":
@@ -381,6 +381,10 @@ def live_plot(epochs, mAP, loss, title=""):
                 write_line(line="Training failed")
                 break
 
+            epochs: np.ndarray | list
+            mAP: np.ndarray | list
+            loss: np.ndarray | list
+
             if "roboflow-train" in models.keys():
                 # training has started
                 epochs = np.array([int(epoch["epoch"]) for epoch in models["roboflow-train"]["epochs"]])
@@ -452,7 +456,7 @@ def deploy(self, model_type: str, model_path: str, filename: str = "weights/best
                 import ultralytics
 
             except ImportError:
-                raise (
+                raise Exception(
                     "The ultralytics python package is required to deploy yolov8"
                     " models. Please install it with `pip install ultralytics`"
                 )
@@ -465,7 +469,7 @@ def deploy(self, model_type: str, model_path: str, filename: str = "weights/best
                 import ultralytics
 
             except ImportError:
-                raise (
+                raise Exception(
                     "The ultralytics python package is required to deploy yolov10"
                     " models. Please install it with `pip install ultralytics`"
                 )
@@ -474,7 +478,7 @@ def deploy(self, model_type: str, model_path: str, filename: str = "weights/best
             try:
                 import torch
             except ImportError:
-                raise (
+                raise Exception(
                     "The torch python package is required to deploy yolov5 models."
                     " Please install it with `pip install torch`"
                 )
@@ -619,7 +623,7 @@ def deploy_yolonas(self, model_type: str, model_path: str, filename: str = "weig
             try:
                 import torch
             except ImportError:
-                raise (
+                raise Exception(
                     "The torch python package is required to deploy yolonas models."
                     " Please install it with `pip install torch`"
                 )
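
The `raise (` to `raise Exception(` changes fix a runtime bug as well as a type error: the parentheses only group the implicitly concatenated string literals, so the old code tried to raise a plain str. A small runnable sketch, not from the repository:

    # Illustrative only: raising a bare string fails at runtime and under mypy.
    try:
        raise ("The torch python package is required" " to deploy this model")
    except TypeError as err:
        print(err)  # exceptions must derive from BaseException

    # The commit wraps the message in an exception type instead:
    # raise Exception("The torch python package is required to deploy this model")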

roboflow/core/workspace.py

Lines changed: 7 additions & 4 deletions
@@ -117,7 +117,7 @@ def create_project(self, project_name, project_type, project_license, annotation
 
         return self.project(r.json()["id"].split("/")[-1])
 
-    def clip_compare(self, dir: str = "", image_ext: str = ".png", target_image: str = "") -> dict:
+    def clip_compare(self, dir: str = "", image_ext: str = ".png", target_image: str = "") -> list[dict]:
         """
         Compare all images in a directory to a target image using CLIP
@@ -127,6 +127,7 @@ def clip_compare(self, dir: str = "", image_ext: str = ".png", target_image: str
             target_image (str): name reference for target image to compare individual images from directory against
 
         Returns:
+            # TODO: fix docs
             dict: a key:value mapping of image_name:comparison_score_to_target
         """  # noqa: E501 // docs
 
@@ -135,7 +136,7 @@ def clip_compare(self, dir: str = "", image_ext: str = ".png", target_image: str
         # grab all images in a given directory with ext type
         for image in glob.glob(f"./{dir}/*{image_ext}"):
             # compare image
-            similarity = clip_encode(image, target_image)
+            similarity = clip_encode(image, target_image, CLIP_FEATURIZE_URL)
             # map image name to similarity score
             comparisons.append({image: similarity})
         comparisons = sorted(comparisons, key=lambda item: -list(item.values())[0])
@@ -148,7 +149,7 @@ def two_stage(
         first_stage_model_version: int = 0,
         second_stage_model_name: str = "",
         second_stage_model_version: int = 0,
-    ) -> dict:
+    ) -> list[dict]:
         """
         For each prediction in a first stage detection, perform detection with the second stage model
 
@@ -160,6 +161,7 @@ def two_stage(
             second_stage_model_version (int): version number for the second stage model
 
         Returns:
+            # TODO: fix docs
             dict: a json obj containing the results of the second stage detection
         """  # noqa: E501 // docs
         results = []
@@ -218,7 +220,7 @@ def two_stage_ocr(
         image: str = "",
         first_stage_model_name: str = "",
         first_stage_model_version: int = 0,
-    ) -> dict:
+    ) -> list[dict]:
         """
         For each prediction in the first stage object detection, perform OCR as second stage.
 
@@ -228,6 +230,7 @@ def two_stage_ocr(
             first_stage_model_version (int): version number for the first stage model
 
         Returns:
+            # TODO: fix docs
             dict: a json obj containing the results of the second stage detection
         """  # noqa: E501 // docs
         results = []
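
The return annotations move from `dict` to `list[dict]` because `clip_compare`, `two_stage`, and `two_stage_ocr` all build and return a list of per-item dicts. A hypothetical example of the shape `clip_compare` yields (file names and scores are made up):

    # One single-key dict per image, sorted by similarity score in descending order.
    comparisons = [
        {"./frames/img_07.png": 0.93},
        {"./frames/img_02.png": 0.87},
        {"./frames/img_11.png": 0.54},
    ]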

roboflow/models/classification.py

Lines changed: 5 additions & 5 deletions
@@ -23,11 +23,11 @@ def __init__(
         self,
         api_key: str,
         id: str,
-        name: str = None,
-        version: int = None,
+        name: str | None = None,
+        version: int | None = None,
         local: bool = False,
-        colors: dict = None,
-        preprocessing: dict = None,
+        colors: dict | None = None,
+        preprocessing: dict | None = None,
     ):
         """
         Create a ClassificationModel object through which you can run inference.
@@ -59,7 +59,7 @@ def __init__(
         self.preprocessing = {} if preprocessing is None else preprocessing
 
         if local:
-            print("initalizing local classification model hosted at :" + local)
+            print(f"initalizing local classification model hosted at : {local}")
             self.base_url = local
 
     def predict(self, image_path, hosted=False):
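
The f-string change is not only cosmetic: under the declared `local: bool = False` annotation, `"..." + local` is a type error (and a runtime TypeError for an actual bool), while an f-string formats any value. Illustrative sketch:

    local = True
    # print("hosted at :" + local)   # TypeError: can only concatenate str (not "bool") to str
    print(f"hosted at : {local}")    # hosted at : True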

roboflow/models/inference.py

Lines changed: 2 additions & 2 deletions
@@ -140,7 +140,7 @@ def predict_video(
         fps: int = 5,
         additional_models: list = [],
         prediction_type: str = "batch-video",
-    ) -> List[str]:
+    ) -> tuple[str, str, str | None]:
         """
         Infers detections based on image from specified model and image path.
@@ -280,7 +280,7 @@ def predict_video(
 
         return job_id, signed_url, signed_url_expires
 
-    def poll_for_video_results(self, job_id: str = None) -> dict:
+    def poll_for_video_results(self, job_id: str | None = None) -> dict:
         """
         Polls the Roboflow API to check if video inference is complete.
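
`predict_video` returns `job_id, signed_url, signed_url_expires`, which is a tuple rather than a list, so `List[str]` was wrong on both counts; the new annotation also records that the third element may be None. A minimal sketch with hypothetical values:

    def predict_video_sketch() -> tuple[str, str, str | None]:
        job_id = "job-123"                             # hypothetical value
        signed_url = "https://storage.example/upload"  # hypothetical value
        signed_url_expires = None                      # may be absent, hence str | None
        return job_id, signed_url, signed_url_expires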

roboflow/models/instance_segmentation.py

Lines changed: 2 additions & 2 deletions
@@ -12,8 +12,8 @@ def __init__(
         self,
         api_key: str,
         version_id: str,
-        colors: dict = None,
-        preprocessing: dict = None,
+        colors: dict | None = None,
+        preprocessing: dict | None = None,
         local: bool = None,
     ):
         """

roboflow/models/keypoint_detection.py

Lines changed: 3 additions & 3 deletions
@@ -23,8 +23,8 @@ def __init__(
         self,
         api_key: str,
         id: str,
-        name: str = None,
-        version: int = None,
+        name: str | None = None,
+        version: int | None = None,
         local: bool = False,
     ):
         """
@@ -54,7 +54,7 @@ def __init__(
         self.__generate_url()
 
         if local:
-            print("initalizing local keypoint detection model hosted at :" + local)
+            print(f"initalizing local keypoint detection model hosted at : {local}")
             self.base_url = local
 
     def predict(self, image_path, hosted=False):

roboflow/models/video.py

Lines changed: 10 additions & 7 deletions
@@ -61,8 +61,8 @@ def predict(
         video_path: str,
         inference_type: str,
         fps: int = 5,
-        additional_models: list = None,
-    ) -> List[str, str]:
+        additional_models: list | None = None,
+    ) -> tuple[str, str]:
         """
         Infers detections based on image from specified model and image path.
@@ -86,11 +86,14 @@ def predict(
         >>> prediction = model.predict("video.mp4", fps=5, inference_type="object-detection")
         """  # noqa: E501 // docs
 
-        url = urljoin(API_URL, "/video_upload_signed_url/?api_key=", self.__api_key)
+        url = urljoin(API_URL, f"/video_upload_signed_url/?api_key={self.__api_key}")
 
         if fps > 30:
             raise Exception("FPS must be less than or equal to 30.")
 
+        if additional_models is None:
+            additional_models = []
+
         for model in additional_models:
             if model not in SUPPORTED_ADDITIONAL_MODELS:
                 raise Exception(f"Model {model} is not supported for video inference.")
@@ -115,7 +118,7 @@ def predict(
 
         print("Uploaded video to signed url: " + signed_url)
 
-        url = urljoin(API_URL, "/videoinfer/?api_key=", self.__api_key)
+        url = urljoin(API_URL, f"/videoinfer/?api_key={self.__api_key}")
 
         models = [
             {
@@ -138,7 +141,7 @@ def predict(
 
         return job_id, signed_url
 
-    def poll_for_results(self, job_id: str = None) -> dict:
+    def poll_for_results(self, job_id: str | None = None) -> dict:
         """
         Polls the Roboflow API to check if video inference is complete.
@@ -162,7 +165,7 @@ def poll_for_results(self, job_id: str = None) -> dict:
         if job_id is None:
             job_id = self.job_id
 
-        url = urljoin(API_URL, "/videoinfer/?api_key=", self.__api_key, "&job_id=", self.job_id)
+        url = urljoin(API_URL, f"/videoinfer/?api_key={self.__api_key}&job_id={self.job_id}")
 
         try:
             response = requests.get(url, headers={"Content-Type": "application/json"})
@@ -216,7 +219,7 @@ def poll_until_results(self, job_id) -> dict:
         attempts = 0
 
         while True:
-            response = self.poll_for_response()
+            response = self.poll_for_results()
 
             attempts += 1
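
Why the `urljoin` calls changed: `urllib.parse.urljoin` accepts only `(base, url, allow_fragments=True)`, so the extra positional strings in the old calls were either swallowed as the `allow_fragments` flag (silently dropping the API key from the URL) or, with five arguments, an immediate TypeError. The `if additional_models is None` guard pairs with the widened `list | None` annotation so the later loop never iterates over None, and the `poll_for_response` call referred to a method that does not exist, hence the rename to `poll_for_results`. A quick illustrative check (the API_URL value and key are assumptions, not from the diff):

    from urllib.parse import urljoin

    API_URL = "https://api.roboflow.com"  # assumed for illustration
    api_key = "abc123"                    # hypothetical key

    # Old style: the third argument becomes allow_fragments, so the key is dropped.
    print(urljoin(API_URL, "/videoinfer/?api_key=", api_key))
    # -> https://api.roboflow.com/videoinfer/?api_key=

    # New style: build the query string first, then join once.
    print(urljoin(API_URL, f"/videoinfer/?api_key={api_key}"))
    # -> https://api.roboflow.com/videoinfer/?api_key=abc123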

roboflow/util/prediction.py

Lines changed: 1 addition & 4 deletions
@@ -500,15 +500,12 @@ def create_prediction_group(json_response, image_path, prediction_type, image_di
                 colors=colors,
             )
             prediction_list.append(prediction)
-            img_dims = image_dims
         elif prediction_type == CLASSIFICATION_MODEL:
             prediction = Prediction(json_response, image_path, prediction_type, colors=colors)
             prediction_list.append(prediction)
-            img_dims = image_dims
         elif prediction_type == SEMANTIC_SEGMENTATION_MODEL:
             prediction = Prediction(json_response, image_path, prediction_type, colors=colors)
             prediction_list.append(prediction)
-            img_dims = image_dims
 
     # Seperate list and return as a prediction group
-    return PredictionGroup(img_dims, image_path, *prediction_list)
+    return PredictionGroup(image_dims, image_path, *prediction_list)
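
Dropping the per-branch `img_dims = image_dims` assignments and returning `image_dims` directly removes a possibly-undefined local: when none of the branches matched, `img_dims` was never bound. A minimal sketch, not from the repository, of the pattern mypy flags:

    # Illustrative only: a name bound solely inside conditional branches may be
    # unbound at the point of use.
    def build_group(prediction_type: str, image_dims: dict):
        if prediction_type == "object-detection":
            img_dims = image_dims
        elif prediction_type == "classification":
            img_dims = image_dims
        return img_dims  # mypy: "img_dims" may be undefined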
