|
4 | 4 | import json
|
5 | 5 | import os
|
6 | 6 | import random
|
7 |
| -import sys |
8 | 7 | import urllib
|
9 | 8 |
|
10 | 9 | import cv2
|
11 | 10 | import numpy as np
|
12 | 11 | import requests
|
13 | 12 | from PIL import Image
|
14 |
| -from tqdm import tqdm |
15 | 13 |
|
16 |
| -from roboflow.config import API_URL, OBJECT_DETECTION_MODEL, OBJECT_DETECTION_URL |
| 14 | +from roboflow.config import OBJECT_DETECTION_MODEL, OBJECT_DETECTION_URL |
17 | 15 | from roboflow.models.inference import InferenceModel
|
18 | 16 | from roboflow.util.image_utils import check_image_url
|
19 | 17 | from roboflow.util.prediction import PredictionGroup
|
@@ -461,56 +459,6 @@ def view(button):
|
461 | 459 | else:
|
462 | 460 | view(stopButton)
|
463 | 461 |
|
464 |
| - def download(self, format="pt", location="."): |
465 |
| - """ |
466 |
| - Download the weights associated with a model. |
467 |
| -
|
468 |
| - Args: |
469 |
| - format (str): The format of the output. |
470 |
| - - 'pt': returns a PyTorch weights file |
471 |
| - location (str): The location to save the weights file to |
472 |
| - """ |
473 |
| - supported_formats = ["pt"] |
474 |
| - if format not in supported_formats: |
475 |
| - raise Exception(f"Unsupported format {format}. Must be one of {supported_formats}") |
476 |
| - |
477 |
| - workspace, project, version = self.id.rsplit("/") |
478 |
| - |
479 |
| - # get pt url |
480 |
| - pt_api_url = f"{API_URL}/{workspace}/{project}/{self.version}/ptFile" |
481 |
| - |
482 |
| - r = requests.get(pt_api_url, params={"api_key": self.__api_key}) |
483 |
| - |
484 |
| - r.raise_for_status() |
485 |
| - |
486 |
| - pt_weights_url = r.json()["weightsUrl"] |
487 |
| - |
488 |
| - def bar_progress(current, total, width=80): |
489 |
| - progress_message = ( |
490 |
| - "Downloading weights to " |
491 |
| - + location |
492 |
| - + "/weights.pt" |
493 |
| - + ": %d%% [%d / %d] bytes" % (current / total * 100, current, total) |
494 |
| - ) |
495 |
| - sys.stdout.write("\r" + progress_message) |
496 |
| - sys.stdout.flush() |
497 |
| - |
498 |
| - response = requests.get(pt_weights_url, stream=True) |
499 |
| - |
500 |
| - # write the zip file to the desired location |
501 |
| - with open(location + "/weights.pt", "wb") as f: |
502 |
| - total_length = int(response.headers.get("content-length")) |
503 |
| - for chunk in tqdm( |
504 |
| - response.iter_content(chunk_size=1024), |
505 |
| - desc=f"Downloading weights to {location}/weights.pt", |
506 |
| - total=int(total_length / 1024) + 1, |
507 |
| - ): |
508 |
| - if chunk: |
509 |
| - f.write(chunk) |
510 |
| - f.flush() |
511 |
| - |
512 |
| - return |
513 |
| - |
514 | 462 | def __exception_check(self, image_path_check=None):
|
515 | 463 | # Check if Image path exists exception check
|
516 | 464 | # (for both hosted URL and local image)
|
|
0 commit comments