Skip to content

Commit 080d588

Browse files
authored
Merge pull request #200 from roboflow/add-video-inference
Feature: Video Inference
2 parents ed84065 + aa250cc commit 080d588

File tree

8 files changed

+542
-6
lines changed

8 files changed

+542
-6
lines changed

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@ supervision
1616
urllib3>=1.26.6
1717
tqdm>=4.41.0
1818
PyYAML>=5.3.1
19-
requests_toolbelt
19+
requests_toolbelt
20+
python-magic

roboflow/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from roboflow.config import API_URL, APP_URL, DEMO_KEYS, load_roboflow_api_key
1111
from roboflow.core.project import Project
1212
from roboflow.core.workspace import Workspace
13+
from roboflow.models import CLIPModel, GazeModel
1314
from roboflow.util.general import write_line
1415

1516
__version__ = "1.1.7"

roboflow/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .clip import CLIPModel
2+
from .gaze import GazeModel

roboflow/models/clip.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from .inference import InferenceModel
2+
3+
4+
class CLIPModel(InferenceModel):
    """
    Client for the CLIP model hosted on Roboflow.
    """

    def __init__(self, api_key: str):
        """
        Create a hosted CLIP model client.

        Args:
            api_key: Your Roboflow API key.
        """
        # Hosted base models have no project/version pair, so pass the
        # "BASE_MODEL" sentinel that InferenceModel.__init__ checks for.
        super().__init__(api_key=api_key, version_id="BASE_MODEL")

roboflow/models/gaze.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from .inference import InferenceModel
2+
3+
4+
class GazeModel(InferenceModel):
    """
    Run inference on a gaze detection model, hosted on Roboflow.
    """

    def __init__(self, api_key: str):
        """
        Initialize a gaze detection model.

        Args:
            api_key: Your Roboflow API key.
        """
        # Consistency fix: like CLIPModel, a hosted base model has no
        # project/version pair, so pass the "BASE_MODEL" sentinel so
        # InferenceModel.__init__ skips dataset/version parsing.
        # (Docstring also fixed: it previously said "a CLIP model".)
        super().__init__(api_key=api_key, version_id="BASE_MODEL")

roboflow/models/inference.py

Lines changed: 262 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,34 @@
11
import io
2+
import json
3+
import os
4+
import time
25
import urllib
6+
from typing import List
7+
from urllib.parse import urljoin
38

49
import requests
510
from PIL import Image
611
from requests_toolbelt.multipart.encoder import MultipartEncoder
712

13+
from roboflow.config import API_URL
814
from roboflow.util.image_utils import validate_image_path
915
from roboflow.util.prediction import PredictionGroup
1016

17+
SUPPORTED_ROBOFLOW_MODELS = ["batch-video"]
18+
19+
SUPPORTED_ADDITIONAL_MODELS = {
20+
"clip": {
21+
"model_id": "clip",
22+
"model_version": "1",
23+
"inference_type": "clip-embed-image",
24+
},
25+
"gaze": {
26+
"model_id": "gaze",
27+
"model_version": "1",
28+
"inference_type": "gaze-detection",
29+
},
30+
}
31+
1132

1233
class InferenceModel:
1334
def __init__(
@@ -25,13 +46,15 @@ def __init__(
2546
api_key (str): private roboflow api key
2647
version_id (str): the ID of the dataset version to use for inference
2748
"""
49+
2850
self.__api_key = api_key
2951
self.id = version_id
3052

31-
version_info = self.id.rsplit("/")
32-
self.dataset_id = version_info[1]
33-
self.version = version_info[2]
34-
self.colors = {} if colors is None else colors
53+
if version_id != "BASE_MODEL":
54+
version_info = self.id.rsplit("/")
55+
self.dataset_id = version_info[1]
56+
self.version = version_info[2]
57+
self.colors = {} if colors is None else colors
3558

3659
def __get_image_params(self, image_path):
3760
"""
@@ -111,3 +134,238 @@ def predict(self, image_path, prediction_type=None, **kwargs):
111134
image_dims=image_dims,
112135
colors=self.colors,
113136
)
137+
138+
def predict_video(
    self,
    video_path: str,
    fps: int = 5,
    additional_models: list = None,
    prediction_type: str = "batch-video",
) -> List[str]:
    """
    Run inference on a video, hosted locally or at a remote URL.

    Args:
        video_path (str): path to, or http(s) URL of, the video to run inference on
        fps (int): frames per second at which to run inference (max 5)
        additional_models (list): extra hosted models ("clip", "gaze") to run alongside
        prediction_type (str): type of inference job (only "batch-video" is supported)

    Returns:
        Tuple of (job_id, signed_url, signed_url_expires)

    Raises:
        Exception: on invalid fps/model/prediction_type, or any upload/API failure.

    Example:
        >>> import roboflow

        >>> rf = roboflow.Roboflow(api_key="")

        >>> project = rf.workspace().project("PROJECT_ID")

        >>> model = project.version("1").model

        >>> job_id, signed_url, signed_url_expires = model.predict_video("video.mp4", fps=5)
    """
    # Bug fix: avoid a shared mutable default argument.
    if additional_models is None:
        additional_models = []

    signed_url_expires = None

    url = urljoin(API_URL, "/video_upload_signed_url?api_key=" + self.__api_key)

    if fps > 5:
        raise Exception("FPS must be less than or equal to 5.")

    for model in additional_models:
        if model not in SUPPORTED_ADDITIONAL_MODELS:
            raise Exception(f"Model {model} is not supported for video inference.")

    if prediction_type not in SUPPORTED_ROBOFLOW_MODELS:
        raise Exception(f"{prediction_type} is not supported for video inference.")

    # The concrete subclass determines the inference type reported to the API.
    model_class = self.__class__.__name__

    if model_class == "ObjectDetectionModel":
        self.type = "object-detection"
    elif model_class == "ClassificationModel":
        self.type = "classification"
    elif model_class == "InstanceSegmentationModel":
        self.type = "instance-segmentation"
    elif model_class == "GazeModel":
        self.type = "gaze-detection"
    elif model_class == "CLIPModel":
        self.type = "clip-embed-image"
    else:
        raise Exception("Model type not supported for video inference.")

    payload = json.dumps(
        {
            "file_name": os.path.basename(video_path),
        }
    )

    if not video_path.startswith(("http://", "https://")):
        # Local file: request a signed upload URL, then PUT the raw bytes to it.
        headers = {"Content-Type": "application/json"}

        try:
            response = requests.request("POST", url, headers=headers, data=payload)
        except Exception as e:
            raise Exception(f"Error uploading video: {e}")

        if not response.ok:
            raise Exception(f"Error uploading video: {response.text}")

        signed_url = response.json()["signed_url"]

        signed_url_expires = (
            signed_url.split("&X-Goog-Expires")[1].split("&")[0].strip("=")
        )

        # Upload the raw video bytes to the signed URL.
        headers = {"Content-Type": "application/octet-stream"}

        try:
            with open(video_path, "rb") as f:
                video_data = f.read()
        except Exception as e:
            raise Exception(f"Error reading video: {e}")

        try:
            result = requests.put(signed_url, data=video_data, headers=headers)
        except Exception as e:
            raise Exception(f"There was an error uploading the video: {e}")

        if not result.ok:
            raise Exception(
                f"There was an error uploading the video: {result.text}"
            )
    else:
        # Remote video: the API fetches it directly from the URL.
        signed_url = video_path

    url = urljoin(API_URL, "/videoinfer/?api_key=" + self.__api_key)

    # Bug fix: `models` was only defined on the CLIP/Gaze branch but used
    # unconditionally below; initialize it for every model class.
    models = []

    if model_class == "CLIPModel":
        models.append(SUPPORTED_ADDITIONAL_MODELS["clip"])
    elif model_class == "GazeModel":
        models.append(SUPPORTED_ADDITIONAL_MODELS["gaze"])
    # NOTE(review): for dataset-backed models no primary model entry is added
    # here — presumably the server derives it from the API key/job; confirm.

    for model in additional_models:
        models.append(SUPPORTED_ADDITIONAL_MODELS[model])

    # Bug fix: honor the caller-supplied fps (was hard-coded to 5).
    payload = json.dumps(
        {"input_url": signed_url, "infer_fps": fps, "models": models}
    )

    headers = {"Content-Type": "application/json"}

    try:
        response = requests.request("POST", url, headers=headers, data=payload)
    except Exception as e:
        raise Exception(f"Error starting video inference: {e}")

    if not response.ok:
        raise Exception(f"Error starting video inference: {response.text}")

    job_id = response.json()["job_id"]

    # Remember the job so poll_for_video_results() can default to it.
    self.job_id = job_id

    return job_id, signed_url, signed_url_expires
283+
284+
def poll_for_video_results(self, job_id: str = None) -> dict:
    """
    Poll the Roboflow API once to check whether video inference is complete.

    Args:
        job_id (str): the job to check; defaults to the job started by the
            most recent predict_video() call on this model.

    Returns:
        Inference results as a dict, or {} if the job is not yet complete.

    Example:
        >>> import roboflow

        >>> rf = roboflow.Roboflow(api_key="")

        >>> project = rf.workspace().project("PROJECT_ID")

        >>> model = project.version("1").model

        >>> prediction = model.predict("video.mp4")

        >>> results = model.poll_for_video_results()
    """

    if job_id is None:
        job_id = self.job_id

    # Bug fix: query the requested job_id (previously always used self.job_id,
    # silently ignoring the caller's argument).
    url = urljoin(
        API_URL, "/videoinfer/?api_key=" + self.__api_key + "&job_id=" + job_id
    )

    try:
        response = requests.get(url, headers={"Content-Type": "application/json"})
    except Exception as e:
        raise Exception(f"Error getting video inference results: {e}")

    if not response.ok:
        raise Exception(f"Error getting video inference results: {response.text}")

    data = response.json()

    # status == 0 signals a finished job; anything else means still running.
    if data.get("status") != 0:
        return {}

    output_signed_url = data["output_signed_url"]

    inference_data = requests.get(
        output_signed_url, headers={"Content-Type": "application/json"}
    )

    # frame_offset and model name are top-level keys
    return inference_data.json()
333+
334+
def poll_until_video_results(self, job_id: str = None) -> dict:
    """
    Poll the Roboflow API until video inference is complete, then return results.

    Polls once per minute; blocks until the job finishes.

    Args:
        job_id (str): the job to wait on; defaults to the job started by the
            most recent predict_video() call on this model.

    Returns:
        Inference results as a dict

    Example:
        >>> import roboflow

        >>> rf = roboflow.Roboflow(api_key="")

        >>> project = rf.workspace().project("PROJECT_ID")

        >>> model = project.version("1").model

        >>> prediction = model.predict("video.mp4")

        >>> results = model.poll_until_video_results()
    """
    # Bug fix: job_id previously had no default even though the body checked
    # `if job_id is None`, making the documented no-argument usage a TypeError.
    if job_id is None:
        job_id = self.job_id

    attempts = 0

    while True:
        print(f"({attempts * 60}s): Checking for inference results")

        # Bug fix: pass job_id through (it was previously ignored).
        response = self.poll_for_video_results(job_id)

        attempts += 1

        # Bug fix: return as soon as results arrive; previously the loop
        # slept 60s even after a successful poll.
        if response != {}:
            return response

        time.sleep(60)

roboflow/models/object_detection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
from PIL import Image
1515

1616
from roboflow.config import API_URL, OBJECT_DETECTION_MODEL, OBJECT_DETECTION_URL
17+
from roboflow.models.inference import InferenceModel
1718
from roboflow.util.image_utils import check_image_url
1819
from roboflow.util.prediction import PredictionGroup
1920
from roboflow.util.versions import print_warn_for_wrong_dependencies_versions
2021

2122

22-
class ObjectDetectionModel:
23+
class ObjectDetectionModel(InferenceModel):
2324
"""
2425
Run inference on an object detection model hosted on Roboflow or served through Roboflow Inference.
2526
"""
@@ -67,6 +68,7 @@ def __init__(
6768
"""
6869
# Instantiate different API URL parameters
6970
# To be moved to predict
71+
super(ObjectDetectionModel, self).__init__(api_key, id)
7072
self.__api_key = api_key
7173
self.id = id
7274
self.name = name

0 commit comments

Comments
 (0)