Skip to content

Commit 2a39441

Browse files
client
1 parent 3208f65 commit 2a39441

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

src/autotrain/client.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import os
12
from dataclasses import dataclass
23
from typing import Optional
3-
import os
4+
45
import requests
56

7+
68
"""
79
{
810
"project_name": "string",
@@ -51,6 +53,7 @@
5153
}
5254
"""
5355

56+
5457
@dataclass
5558
class Client:
5659
host: Optional[str] = None
@@ -60,28 +63,36 @@ class Client:
6063
def __post_init__(self):
6164
if self.host is None:
6265
self.host = "https://autotrain-projects-autotrain-advanced.hf.space/"
63-
66+
6467
if self.token is None:
6568
self.token = os.environ.get("HF_TOKEN")
66-
69+
6770
if self.username is None:
6871
self.username = os.environ.get("HF_USERNAME")
6972

7073
if self.token is None or self.username is None:
7174
raise ValueError("Please provide a valid username and token")
72-
73-
self.headers = {
74-
"Authorization": f"Bearer {self.token}",
75-
"Content-Type": "application/json"
76-
}
77-
75+
76+
self.headers = {"Authorization": f"Bearer {self.token}", "Content-Type": "application/json"}
77+
7878
def __str__(self):
7979
return f"Client(host={self.host}, token=****, username={self.username})"
80-
80+
8181
def __repr__(self):
82-
return self.__str__()
83-
84-
def create(self, project_name: str, task: str, base_model: str, hardware: str, params: dict, column_mapping: dict, hub_dataset: str, train_split: str, valid_split: str):
82+
return self.__str__()
83+
84+
def create(
85+
self,
86+
project_name: str,
87+
task: str,
88+
base_model: str,
89+
hardware: str,
90+
params: dict,
91+
column_mapping: dict,
92+
hub_dataset: str,
93+
train_split: str,
94+
valid_split: str,
95+
):
8596
url = f"{self.host}/api/create_project"
8697
data = {
8798
"project_name": project_name,
@@ -93,7 +104,7 @@ def create(self, project_name: str, task: str, base_model: str, hardware: str, p
93104
"column_mapping": column_mapping,
94105
"hub_dataset": hub_dataset,
95106
"train_split": train_split,
96-
"valid_split": valid_split
107+
"valid_split": valid_split,
97108
}
98109
response = requests.post(url, headers=self.headers, json=data)
99110
return response.json()

0 commit comments

Comments
 (0)