Skip to content

Commit 8c997ab

Browse files
authored
Merge pull request #137 from roboflow/ptDownload
PT Weights Download
2 parents 071bf45 + 32a220b commit 8c997ab

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

roboflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from roboflow.core.workspace import Workspace
1313
from roboflow.util.general import write_line
1414

15-
__version__ = "1.0.3"
15+
__version__ = "1.0.4"
1616

1717

1818
def check_key(api_key, model, notebook, num_retries=0):

roboflow/core/version.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,32 @@ def live_plot(epochs, mAP, loss, title=""):
433433
# return the model object
434434
return self.model
435435

436+
def get_pt_weights(self, location="."):
437+
workspace, project, *_ = self.id.rsplit("/")
438+
439+
# get pt url
440+
pt_api_url = f"{API_URL}/{workspace}/{project}/{self.version}/ptFile"
441+
442+
r = requests.get(pt_api_url, params={"api_key": self.__api_key})
443+
444+
r.raise_for_status()
445+
446+
pt_weights_url = r.json()["weightsUrl"]
447+
448+
def bar_progress(current, total, width=80):
449+
progress_message = (
450+
"Downloading weights to "
451+
+ location
452+
+ "/weights.pt"
453+
+ ": %d%% [%d / %d] bytes" % (current / total * 100, current, total)
454+
)
455+
sys.stdout.write("\r" + progress_message)
456+
sys.stdout.flush()
457+
458+
wget.download(pt_weights_url, out=location + "/weights.pt", bar=bar_progress)
459+
460+
return
461+
436462
# @warn_for_wrong_dependencies_versions([("ultralytics", "<=", "8.0.20")])
437463
def deploy(self, model_type: str, model_path: str) -> None:
438464
"""Uploads provided weights file to Roboflow

0 commit comments

Comments
 (0)