Skip to content

Fix many mypy errors #264

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ convention = "google"
max-args = 20

[tool.mypy]
python_version = "3.8"
exclude = [
"^build/"
]
Expand Down
5 changes: 3 additions & 2 deletions roboflow/adapters/rfapi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import urllib
from typing import Optional

import requests
from requests_toolbelt.multipart.encoder import MultipartEncoder
Expand Down Expand Up @@ -43,8 +44,8 @@ def upload_image(
split: str = "train",
batch_name: str = DEFAULT_BATCH_NAME,
tag_names: list = [],
sequence_number: int = None,
sequence_size: int = None,
sequence_number: Optional[int] = None,
sequence_size: Optional[int] = None,
**kwargs,
):
"""
Expand Down
41 changes: 21 additions & 20 deletions roboflow/core/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
import time
import warnings
from typing import Dict, List, Optional, Union

import requests
from PIL import Image, UnidentifiedImageError
Expand All @@ -22,15 +23,15 @@ def custom_formatwarning(msg, *args, **kwargs):
return str(msg) + "\n"


warnings.formatwarning = custom_formatwarning
warnings.formatwarning = custom_formatwarning # type: ignore[assignment]


class Project:
"""
A Roboflow Project.
"""

def __init__(self, api_key: str, a_project: str, model_format: str = None):
def __init__(self, api_key: str, a_project: dict, model_format: Optional[str] = None):
"""
Create a Project object that represents a Project associated with a Workspace.

Expand Down Expand Up @@ -283,7 +284,7 @@ def train(

return new_model

def version(self, version_number: int, local: str = None):
def version(self, version_number: int, local: Optional[str] = None):
"""
Retrieves information about a specific version and returns a Version() object.

Expand Down Expand Up @@ -357,13 +358,13 @@ def check_valid_image(self, image_path: str):

def upload(
self,
image_path: str = None,
annotation_path: str = None,
image_path: str,
annotation_path: Optional[str] = None,
hosted_image: bool = False,
image_id: str = None,
image_id: Optional[str] = None,
split: str = "train",
num_retry_uploads: int = 0,
batch_name: str = None,
batch_name: Optional[str] = None,
tag_names: list = [],
is_prediction: bool = False,
**kwargs,
Expand Down Expand Up @@ -549,15 +550,15 @@ def _annotation_params(self, annotation_path):

def search(
self,
like_image: str = None,
prompt: str = None,
like_image: Optional[str] = None,
prompt: Optional[str] = None,
offset: int = 0,
limit: int = 100,
tag: str = None,
class_name: str = None,
in_dataset: str = None,
tag: Optional[str] = None,
class_name: Optional[str] = None,
in_dataset: Optional[str] = None,
batch: bool = False,
batch_id: str = None,
batch_id: Optional[str] = None,
fields: list = ["id", "created", "name", "labels"],
):
"""
Expand Down Expand Up @@ -587,7 +588,7 @@ def search(

>>> results = project.search(query="cat", limit=10)
""" # noqa: E501 // docs
payload = {}
payload: Dict[str, Union[str, int, List[str]]] = {}

if like_image is not None:
payload["like_image"] = like_image
Expand Down Expand Up @@ -627,15 +628,15 @@ def search(

def search_all(
self,
like_image: str = None,
prompt: str = None,
like_image: Optional[str] = None,
prompt: Optional[str] = None,
offset: int = 0,
limit: int = 100,
tag: str = None,
class_name: str = None,
in_dataset: str = None,
tag: Optional[str] = None,
class_name: Optional[str] = None,
in_dataset: Optional[str] = None,
batch: bool = False,
batch_id: str = None,
batch_id: Optional[str] = None,
fields: list = ["id", "created"],
):
"""
Expand Down
7 changes: 6 additions & 1 deletion roboflow/core/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
import zipfile
from importlib import import_module
from typing import Union

import numpy as np
import requests
Expand Down Expand Up @@ -359,7 +360,7 @@ def live_plot(epochs, mAP, loss, title=""):
plt.show()

first_graph_write = False
previous_epochs = []
previous_epochs: Union[np.ndarray, list] = []
num_machine_spin_dots = []

while status == "training" or status == "running":
Expand All @@ -381,6 +382,10 @@ def live_plot(epochs, mAP, loss, title=""):
write_line(line="Training failed")
break

epochs: Union[np.ndarray, list]
mAP: Union[np.ndarray, list]
loss: Union[np.ndarray, list]

if "roboflow-train" in models.keys():
# training has started
epochs = np.array([int(epoch["epoch"]) for epoch in models["roboflow-train"]["epochs"]])
Expand Down
10 changes: 7 additions & 3 deletions roboflow/core/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import os
import sys
from typing import List

import numpy as np
import requests
Expand Down Expand Up @@ -117,7 +118,7 @@ def create_project(self, project_name, project_type, project_license, annotation

return self.project(r.json()["id"].split("/")[-1])

def clip_compare(self, dir: str = "", image_ext: str = ".png", target_image: str = "") -> dict:
def clip_compare(self, dir: str = "", image_ext: str = ".png", target_image: str = "") -> List[dict]:
"""
Compare all images in a directory to a target image using CLIP

Expand All @@ -127,6 +128,7 @@ def clip_compare(self, dir: str = "", image_ext: str = ".png", target_image: str
target_image (str): name reference for target image to compare individual images from directory against

Returns:
# TODO: fix docs
dict: a key:value mapping of image_name:comparison_score_to_target
""" # noqa: E501 // docs

Expand All @@ -148,7 +150,7 @@ def two_stage(
first_stage_model_version: int = 0,
second_stage_model_name: str = "",
second_stage_model_version: int = 0,
) -> dict:
) -> List[dict]:
"""
For each prediction in a first stage detection, perform detection with the second stage model

Expand All @@ -160,6 +162,7 @@ def two_stage(
second_stage_model_version (int): version number for the second stage model

Returns:
# TODO: fix docs
dict: a json obj containing the results of the second stage detection
""" # noqa: E501 // docs
results = []
Expand Down Expand Up @@ -218,7 +221,7 @@ def two_stage_ocr(
image: str = "",
first_stage_model_name: str = "",
first_stage_model_version: int = 0,
) -> dict:
) -> List[dict]:
"""
For each prediction in the first stage object detection, perform OCR as second stage.

Expand All @@ -228,6 +231,7 @@ def two_stage_ocr(
first_stage_model_version (int): version number for the first stage model

Returns:
# TODO: fix docs
dict: a json obj containing the results of the second stage detection
""" # noqa: E501 // docs
results = []
Expand Down
11 changes: 6 additions & 5 deletions roboflow/models/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import os
import urllib
from typing import Optional

import requests
from PIL import Image
Expand All @@ -23,11 +24,11 @@ def __init__(
self,
api_key: str,
id: str,
name: str = None,
version: int = None,
name: Optional[str] = None,
version: Optional[int] = None,
local: bool = False,
colors: dict = None,
preprocessing: dict = None,
colors: Optional[dict] = None,
preprocessing: Optional[dict] = None,
):
"""
Create a ClassificationModel object through which you can run inference.
Expand Down Expand Up @@ -59,7 +60,7 @@ def __init__(
self.preprocessing = {} if preprocessing is None else preprocessing

if local:
print("initalizing local classification model hosted at :" + local)
print(f"initalizing local classification model hosted at : {local}")
self.base_url = local

def predict(self, image_path, hosted=False):
Expand Down
6 changes: 3 additions & 3 deletions roboflow/models/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import time
import urllib
from typing import List
from typing import Optional, Tuple
from urllib.parse import urljoin

import requests
Expand Down Expand Up @@ -140,7 +140,7 @@ def predict_video(
fps: int = 5,
additional_models: list = [],
prediction_type: str = "batch-video",
) -> List[str]:
) -> Tuple[str, str, Optional[str]]:
"""
Infers detections based on image from specified model and image path.

Expand Down Expand Up @@ -280,7 +280,7 @@ def predict_video(

return job_id, signed_url, signed_url_expires

def poll_for_video_results(self, job_id: str = None) -> dict:
def poll_for_video_results(self, job_id: Optional[str] = None) -> dict:
"""
Polls the Roboflow API to check if video inference is complete.

Expand Down
6 changes: 4 additions & 2 deletions roboflow/models/instance_segmentation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from roboflow.config import INSTANCE_SEGMENTATION_MODEL, INSTANCE_SEGMENTATION_URL
from roboflow.models.inference import InferenceModel

Expand All @@ -12,8 +14,8 @@ def __init__(
self,
api_key: str,
version_id: str,
colors: dict = None,
preprocessing: dict = None,
colors: Optional[dict] = None,
preprocessing: Optional[dict] = None,
local: bool = None,
):
"""
Expand Down
7 changes: 4 additions & 3 deletions roboflow/models/keypoint_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import os
import urllib
from typing import Optional

import requests
from PIL import Image
Expand All @@ -23,8 +24,8 @@ def __init__(
self,
api_key: str,
id: str,
name: str = None,
version: int = None,
name: Optional[str] = None,
version: Optional[int] = None,
local: bool = False,
):
"""
Expand Down Expand Up @@ -54,7 +55,7 @@ def __init__(
self.__generate_url()

if local:
print("initalizing local keypoint detection model hosted at :" + local)
print(f"initalizing local keypoint detection model hosted at : {local}")
self.base_url = local

def predict(self, image_path, hosted=False):
Expand Down
11 changes: 7 additions & 4 deletions roboflow/models/video.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import time
from typing import List
from typing import Optional, Tuple
from urllib.parse import urljoin

import magic
Expand Down Expand Up @@ -61,8 +61,8 @@ def predict(
video_path: str,
inference_type: str,
fps: int = 5,
additional_models: list = None,
) -> List[str, str]:
additional_models: Optional[list] = None,
) -> Tuple[str, str]:
"""
Infers detections based on image from specified model and image path.

Expand Down Expand Up @@ -91,6 +91,9 @@ def predict(
if fps > 30:
raise Exception("FPS must be less than or equal to 30.")

if additional_models is None:
additional_models = []

for model in additional_models:
if model not in SUPPORTED_ADDITIONAL_MODELS:
raise Exception(f"Model {model} is not supported for video inference.")
Expand Down Expand Up @@ -138,7 +141,7 @@ def predict(

return job_id, signed_url

def poll_for_results(self, job_id: str = None) -> dict:
def poll_for_results(self, job_id: Optional[str] = None) -> dict:
"""
Polls the Roboflow API to check if video inference is complete.

Expand Down
5 changes: 1 addition & 4 deletions roboflow/util/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,15 +500,12 @@ def create_prediction_group(json_response, image_path, prediction_type, image_di
colors=colors,
)
prediction_list.append(prediction)
img_dims = image_dims
elif prediction_type == CLASSIFICATION_MODEL:
prediction = Prediction(json_response, image_path, prediction_type, colors=colors)
prediction_list.append(prediction)
img_dims = image_dims
elif prediction_type == SEMANTIC_SEGMENTATION_MODEL:
prediction = Prediction(json_response, image_path, prediction_type, colors=colors)
prediction_list.append(prediction)
img_dims = image_dims

# Seperate list and return as a prediction group
return PredictionGroup(img_dims, image_path, *prediction_list)
return PredictionGroup(image_dims, image_path, *prediction_list)
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

with open("./roboflow/__init__.py", "r") as f:
content = f.read()
version = re.search(r'__version__\s*=\s*[\'"]([^\'"]*)[\'"]', content).group(1)
_search_version = re.search(r'__version__\s*=\s*[\'"]([^\'"]*)[\'"]', content)
assert _search_version
version = _search_version.group(1)


with open("README.md", "r") as fh:
Expand Down