1
1
import io
2
+ import json
3
+ import os
4
+ import time
2
5
import urllib
6
+ from typing import List
7
+ from urllib .parse import urljoin
3
8
4
9
import requests
5
10
from PIL import Image
6
11
from requests_toolbelt .multipart .encoder import MultipartEncoder
7
12
13
+ from roboflow .config import API_URL
8
14
from roboflow .util .image_utils import validate_image_path
9
15
from roboflow .util .prediction import PredictionGroup
10
16
17
# Prediction types that the hosted video-inference pipeline accepts.
SUPPORTED_ROBOFLOW_MODELS = ["batch-video"]

# Extra hosted models that may run alongside the primary model during
# video inference, keyed by the short name callers pass in
# ``additional_models``. Each entry is the model descriptor the API expects.
SUPPORTED_ADDITIONAL_MODELS = {
    name: {
        "model_id": name,
        "model_version": "1",
        "inference_type": inference_type,
    }
    for name, inference_type in (
        ("clip", "clip-embed-image"),
        ("gaze", "gaze-detection"),
    )
}
31
+
11
32
12
33
class InferenceModel :
13
34
def __init__ (
@@ -25,13 +46,15 @@ def __init__(
25
46
api_key (str): private roboflow api key
26
47
version_id (str): the ID of the dataset version to use for inference
27
48
"""
49
+
28
50
self .__api_key = api_key
29
51
self .id = version_id
30
52
31
- version_info = self .id .rsplit ("/" )
32
- self .dataset_id = version_info [1 ]
33
- self .version = version_info [2 ]
34
- self .colors = {} if colors is None else colors
53
+ if version_id != "BASE_MODEL" :
54
+ version_info = self .id .rsplit ("/" )
55
+ self .dataset_id = version_info [1 ]
56
+ self .version = version_info [2 ]
57
+ self .colors = {} if colors is None else colors
35
58
36
59
def __get_image_params (self , image_path ):
37
60
"""
@@ -111,3 +134,238 @@ def predict(self, image_path, prediction_type=None, **kwargs):
111
134
image_dims = image_dims ,
112
135
colors = self .colors ,
113
136
)
137
+
138
+ def predict_video (
139
+ self ,
140
+ video_path : str ,
141
+ fps : int = 5 ,
142
+ additional_models : list = [],
143
+ prediction_type : str = "batch-video" ,
144
+ ) -> List [str ]:
145
+ """
146
+ Infers detections based on image from specified model and image path.
147
+
148
+ Args:
149
+ video_path (str): path to the video you'd like to perform prediction on
150
+ prediction_type (str): type of the model to run
151
+ fps (int): frames per second to run inference
152
+
153
+ Returns:
154
+ A list of the signed url and job id
155
+
156
+ Example:
157
+ >>> import roboflow
158
+
159
+ >>> rf = roboflow.Roboflow(api_key="")
160
+
161
+ >>> project = rf.workspace().project("PROJECT_ID")
162
+
163
+ >>> model = project.version("1").model
164
+
165
+ >>> job_id, signed_url, signed_url_expires = model.predict_video("video.mp4", fps=5, inference_type="object-detection")
166
+ """
167
+
168
+ signed_url_expires = None
169
+
170
+ url = urljoin (API_URL , "/video_upload_signed_url?api_key=" + self .__api_key )
171
+
172
+ if fps > 5 :
173
+ raise Exception ("FPS must be less than or equal to 5." )
174
+
175
+ for model in additional_models :
176
+ if model not in SUPPORTED_ADDITIONAL_MODELS :
177
+ raise Exception (f"Model { model } is not supported for video inference." )
178
+
179
+ if prediction_type not in SUPPORTED_ROBOFLOW_MODELS :
180
+ raise Exception (f"{ prediction_type } is not supported for video inference." )
181
+
182
+ model_class = self .__class__ .__name__
183
+
184
+ if model_class == "ObjectDetectionModel" :
185
+ self .type = "object-detection"
186
+ elif model_class == "ClassificationModel" :
187
+ self .type = "classification"
188
+ elif model_class == "InstanceSegmentationModel" :
189
+ self .type = "instance-segmentation"
190
+ elif model_class == "GazeModel" :
191
+ self .type = "gaze-detection"
192
+ elif model_class == "CLIPModel" :
193
+ self .type = "clip-embed-image"
194
+ else :
195
+ raise Exception ("Model type not supported for video inference." )
196
+
197
+ payload = json .dumps (
198
+ {
199
+ "file_name" : os .path .basename (video_path ),
200
+ }
201
+ )
202
+
203
+ if not video_path .startswith (("http://" , "https://" )):
204
+ headers = {"Content-Type" : "application/json" }
205
+
206
+ try :
207
+ response = requests .request ("POST" , url , headers = headers , data = payload )
208
+ except Exception as e :
209
+ raise Exception (f"Error uploading video: { e } " )
210
+
211
+ if not response .ok :
212
+ raise Exception (f"Error uploading video: { response .text } " )
213
+
214
+ signed_url = response .json ()["signed_url" ]
215
+
216
+ signed_url_expires = (
217
+ signed_url .split ("&X-Goog-Expires" )[1 ].split ("&" )[0 ].strip ("=" )
218
+ )
219
+
220
+ # make a POST request to the signed URL
221
+ headers = {"Content-Type" : "application/octet-stream" }
222
+
223
+ try :
224
+ with open (video_path , "rb" ) as f :
225
+ video_data = f .read ()
226
+ except Exception as e :
227
+ raise Exception (f"Error reading video: { e } " )
228
+
229
+ try :
230
+ result = requests .put (signed_url , data = video_data , headers = headers )
231
+ except Exception as e :
232
+ raise Exception (f"There was an error uploading the video: { e } " )
233
+
234
+ if not result .ok :
235
+ raise Exception (
236
+ f"There was an error uploading the video: { result .text } "
237
+ )
238
+ else :
239
+ signed_url = video_path
240
+
241
+ url = urljoin (API_URL , "/videoinfer/?api_key=" + self .__api_key )
242
+
243
+ if model_class in ("CLIPModel" , "GazeModel" ):
244
+ if model_class == "CLIPModel" :
245
+ model = "clip"
246
+ else :
247
+ model = "gaze"
248
+
249
+ models = [
250
+ {
251
+ "model_id" : SUPPORTED_ADDITIONAL_MODELS [model ]["model_id" ],
252
+ "model_version" : SUPPORTED_ADDITIONAL_MODELS [model ][
253
+ "model_version"
254
+ ],
255
+ "inference_type" : SUPPORTED_ADDITIONAL_MODELS [model ][
256
+ "inference_type"
257
+ ],
258
+ }
259
+ ]
260
+
261
+ for model in additional_models :
262
+ models .append (SUPPORTED_ADDITIONAL_MODELS [model ])
263
+
264
+ payload = json .dumps (
265
+ {"input_url" : signed_url , "infer_fps" : 5 , "models" : models }
266
+ )
267
+
268
+ headers = {"Content-Type" : "application/json" }
269
+
270
+ try :
271
+ response = requests .request ("POST" , url , headers = headers , data = payload )
272
+ except Exception as e :
273
+ raise Exception (f"Error starting video inference: { e } " )
274
+
275
+ if not response .ok :
276
+ raise Exception (f"Error starting video inference: { response .text } " )
277
+
278
+ job_id = response .json ()["job_id" ]
279
+
280
+ self .job_id = job_id
281
+
282
+ return job_id , signed_url , signed_url_expires
283
+
284
+ def poll_for_video_results (self , job_id : str = None ) -> dict :
285
+ """
286
+ Polls the Roboflow API to check if video inference is complete.
287
+
288
+ Returns:
289
+ Inference results as a dict
290
+
291
+ Example:
292
+ >>> import roboflow
293
+
294
+ >>> rf = roboflow.Roboflow(api_key="")
295
+
296
+ >>> project = rf.workspace().project("PROJECT_ID")
297
+
298
+ >>> model = project.version("1").model
299
+
300
+ >>> prediction = model.predict("video.mp4")
301
+
302
+ >>> results = model.poll_for_video_results()
303
+ """
304
+
305
+ if job_id is None :
306
+ job_id = self .job_id
307
+
308
+ url = urljoin (
309
+ API_URL , "/videoinfer/?api_key=" + self .__api_key + "&job_id=" + self .job_id
310
+ )
311
+
312
+ try :
313
+ response = requests .get (url , headers = {"Content-Type" : "application/json" })
314
+ except Exception as e :
315
+ raise Exception (f"Error getting video inference results: { e } " )
316
+
317
+ if not response .ok :
318
+ raise Exception (f"Error getting video inference results: { response .text } " )
319
+
320
+ data = response .json ()
321
+
322
+ if data .get ("status" ) != 0 :
323
+ return {}
324
+
325
+ output_signed_url = data ["output_signed_url" ]
326
+
327
+ inference_data = requests .get (
328
+ output_signed_url , headers = {"Content-Type" : "application/json" }
329
+ )
330
+
331
+ # frame_offset and model name are top-level keys
332
+ return inference_data .json ()
333
+
334
+ def poll_until_video_results (self , job_id ) -> dict :
335
+ """
336
+ Polls the Roboflow API to check if video inference is complete.
337
+
338
+ When inference is complete, the results are returned.
339
+
340
+ Returns:
341
+ Inference results as a dict
342
+
343
+ Example:
344
+ >>> import roboflow
345
+
346
+ >>> rf = roboflow.Roboflow(api_key="")
347
+
348
+ >>> project = rf.workspace().project("PROJECT_ID")
349
+
350
+ >>> model = project.version("1").model
351
+
352
+ >>> prediction = model.predict("video.mp4")
353
+
354
+ >>> results = model.poll_until_results()
355
+ """
356
+ if job_id is None :
357
+ job_id = self .job_id
358
+
359
+ attempts = 0
360
+
361
+ while True :
362
+ print (f"({ attempts * 60 } s): Checking for inference results" )
363
+
364
+ response = self .poll_for_video_results ()
365
+
366
+ time .sleep (60 )
367
+
368
+ attempts += 1
369
+
370
+ if response != {}:
371
+ return response
0 commit comments