Skip to content

Commit 2231103

Browse files
committed
Merge branch 'main' into ordered-uploading
2 parents 294137a + 8daa6de commit 2231103

File tree

2 files changed

+42
-53
lines changed

2 files changed

+42
-53
lines changed

roboflow/models/inference.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import requests
1010
from PIL import Image
1111
from requests_toolbelt.multipart.encoder import MultipartEncoder
12+
from tqdm import tqdm
1213

1314
from roboflow.config import API_URL
1415
from roboflow.util.image_utils import validate_image_path
@@ -358,3 +359,43 @@ def poll_until_video_results(self, job_id) -> dict:
358359

359360
if response != {}:
360361
return response
362+
363+
def download(self, format="pt", location="."):
364+
"""
365+
Download the weights associated with a model.
366+
367+
Args:
368+
format (str): The format of the output.
369+
- 'pt': returns a PyTorch weights file
370+
location (str): The location to save the weights file to
371+
"""
372+
supported_formats = ["pt"]
373+
if format not in supported_formats:
374+
raise Exception(f"Unsupported format {format}. Must be one of {supported_formats}")
375+
376+
workspace, project, version = self.id.rsplit("/")
377+
378+
# get pt url
379+
pt_api_url = f"{API_URL}/{workspace}/{project}/{self.version}/ptFile"
380+
381+
r = requests.get(pt_api_url, params={"api_key": self.__api_key})
382+
383+
r.raise_for_status()
384+
385+
pt_weights_url = r.json()["weightsUrl"]
386+
387+
response = requests.get(pt_weights_url, stream=True)
388+
389+
# write the zip file to the desired location
390+
with open(location + "/weights.pt", "wb") as f:
391+
total_length = int(response.headers.get("content-length"))
392+
for chunk in tqdm(
393+
response.iter_content(chunk_size=1024),
394+
desc=f"Downloading weights to {location}/weights.pt",
395+
total=int(total_length / 1024) + 1,
396+
):
397+
if chunk:
398+
f.write(chunk)
399+
f.flush()
400+
401+
return

roboflow/models/object_detection.py

Lines changed: 1 addition & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,14 @@
44
import json
55
import os
66
import random
7-
import sys
87
import urllib
98

109
import cv2
1110
import numpy as np
1211
import requests
1312
from PIL import Image
14-
from tqdm import tqdm
1513

16-
from roboflow.config import API_URL, OBJECT_DETECTION_MODEL, OBJECT_DETECTION_URL
14+
from roboflow.config import OBJECT_DETECTION_MODEL, OBJECT_DETECTION_URL
1715
from roboflow.models.inference import InferenceModel
1816
from roboflow.util.image_utils import check_image_url
1917
from roboflow.util.prediction import PredictionGroup
@@ -461,56 +459,6 @@ def view(button):
461459
else:
462460
view(stopButton)
463461

464-
def download(self, format="pt", location="."):
465-
"""
466-
Download the weights associated with a model.
467-
468-
Args:
469-
format (str): The format of the output.
470-
- 'pt': returns a PyTorch weights file
471-
location (str): The location to save the weights file to
472-
"""
473-
supported_formats = ["pt"]
474-
if format not in supported_formats:
475-
raise Exception(f"Unsupported format {format}. Must be one of {supported_formats}")
476-
477-
workspace, project, version = self.id.rsplit("/")
478-
479-
# get pt url
480-
pt_api_url = f"{API_URL}/{workspace}/{project}/{self.version}/ptFile"
481-
482-
r = requests.get(pt_api_url, params={"api_key": self.__api_key})
483-
484-
r.raise_for_status()
485-
486-
pt_weights_url = r.json()["weightsUrl"]
487-
488-
def bar_progress(current, total, width=80):
489-
progress_message = (
490-
"Downloading weights to "
491-
+ location
492-
+ "/weights.pt"
493-
+ ": %d%% [%d / %d] bytes" % (current / total * 100, current, total)
494-
)
495-
sys.stdout.write("\r" + progress_message)
496-
sys.stdout.flush()
497-
498-
response = requests.get(pt_weights_url, stream=True)
499-
500-
# write the zip file to the desired location
501-
with open(location + "/weights.pt", "wb") as f:
502-
total_length = int(response.headers.get("content-length"))
503-
for chunk in tqdm(
504-
response.iter_content(chunk_size=1024),
505-
desc=f"Downloading weights to {location}/weights.pt",
506-
total=int(total_length / 1024) + 1,
507-
):
508-
if chunk:
509-
f.write(chunk)
510-
f.flush()
511-
512-
return
513-
514462
def __exception_check(self, image_path_check=None):
515463
# Check if Image path exists exception check
516464
# (for both hosted URL and local image)

0 commit comments

Comments
 (0)