Skip to content

Commit 1c37df3

Browse files
authored
Merge pull request #138 from roboflow/ptDownload2
Pt download2
2 parents 8c997ab + 90e787e commit 1c37df3

File tree

3 files changed

+36
-28
lines changed

3 files changed

+36
-28
lines changed

roboflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from roboflow.core.workspace import Workspace
1313
from roboflow.util.general import write_line
1414

15-
__version__ = "1.0.4"
15+
__version__ = "1.0.5"
1616

1717

1818
def check_key(api_key, model, notebook, num_retries=0):

roboflow/core/version.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -433,32 +433,6 @@ def live_plot(epochs, mAP, loss, title=""):
433433
# return the model object
434434
return self.model
435435

436-
def get_pt_weights(self, location="."):
437-
workspace, project, *_ = self.id.rsplit("/")
438-
439-
# get pt url
440-
pt_api_url = f"{API_URL}/{workspace}/{project}/{self.version}/ptFile"
441-
442-
r = requests.get(pt_api_url, params={"api_key": self.__api_key})
443-
444-
r.raise_for_status()
445-
446-
pt_weights_url = r.json()["weightsUrl"]
447-
448-
def bar_progress(current, total, width=80):
449-
progress_message = (
450-
"Downloading weights to "
451-
+ location
452-
+ "/weights.pt"
453-
+ ": %d%% [%d / %d] bytes" % (current / total * 100, current, total)
454-
)
455-
sys.stdout.write("\r" + progress_message)
456-
sys.stdout.flush()
457-
458-
wget.download(pt_weights_url, out=location + "/weights.pt", bar=bar_progress)
459-
460-
return
461-
462436
# @warn_for_wrong_dependencies_versions([("ultralytics", "<=", "8.0.20")])
463437
def deploy(self, model_type: str, model_path: str) -> None:
464438
"""Uploads provided weights file to Roboflow

roboflow/models/object_detection.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,18 @@
44
import json
55
import os
66
import random
7+
import sys
78
import urllib
89
from pathlib import Path
910

1011
import cv2
1112
import matplotlib.pyplot as plt
1213
import numpy as np
1314
import requests
15+
import wget
1416
from PIL import Image
1517

16-
from roboflow.config import OBJECT_DETECTION_MODEL
18+
from roboflow.config import API_URL, OBJECT_DETECTION_MODEL
1719
from roboflow.util.image_utils import check_image_url
1820
from roboflow.util.prediction import PredictionGroup
1921
from roboflow.util.versions import (
@@ -459,6 +461,38 @@ def view(button):
459461
else:
460462
view(stopButton)
461463

464+
def download(self, format="pt", location="."):
465+
supported_formats = ["pt"]
466+
if format not in supported_formats:
467+
raise Exception(
468+
f"Unsupported format {format}. Must be one of {supported_formats}"
469+
)
470+
471+
workspace, project, version = self.id.rsplit("/")
472+
473+
# get pt url
474+
pt_api_url = f"{API_URL}/{workspace}/{project}/{self.version}/ptFile"
475+
476+
r = requests.get(pt_api_url, params={"api_key": self.__api_key})
477+
478+
r.raise_for_status()
479+
480+
pt_weights_url = r.json()["weightsUrl"]
481+
482+
def bar_progress(current, total, width=80):
483+
progress_message = (
484+
"Downloading weights to "
485+
+ location
486+
+ "/weights.pt"
487+
+ ": %d%% [%d / %d] bytes" % (current / total * 100, current, total)
488+
)
489+
sys.stdout.write("\r" + progress_message)
490+
sys.stdout.flush()
491+
492+
wget.download(pt_weights_url, out=location + "/weights.pt", bar=bar_progress)
493+
494+
return
495+
462496
def __exception_check(self, image_path_check=None):
463497
# Check if Image path exists exception check (for both hosted URL and local image)
464498
if image_path_check is not None:

0 commit comments

Comments
 (0)