Commit 1761163

Client (#801)
1 parent 2a39441

File tree: 7 files changed (+321, -88 lines)


setup.py

Lines changed: 4 additions & 0 deletions
@@ -33,8 +33,11 @@
 
 TESTS_REQUIRE = ["pytest"]
 
+CLIENT_REQUIRES = ["requests", "loguru"]
+
 
 EXTRAS_REQUIRE = {
+    "base": INSTALL_REQUIRES,
     "dev": INSTALL_REQUIRES + QUALITY_REQUIRE + TESTS_REQUIRE,
     "quality": INSTALL_REQUIRES + QUALITY_REQUIRE,
     "docs": INSTALL_REQUIRES
@@ -45,6 +48,7 @@
         "sphinx-rtd-theme==0.4.3",
         "sphinx-copybutton",
     ],
+    "client": CLIENT_REQUIRES,
 }
 
 setup(
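
The new "client" extra groups the lightweight client dependencies so they can be installed without the full training stack, presumably via something like `pip install autotrain-advanced[client]`. For orientation, a minimal sketch of how such an extras declaration composes in a `setup()` call; the placeholder requirement lists are assumptions, not the real values from setup.py:

```python
# Minimal sketch of an extras_require layout like the one in this commit.
# INSTALL_REQUIRES stands in for the full requirement list defined elsewhere
# in setup.py (assumed values, not shown in this diff).
from setuptools import setup

INSTALL_REQUIRES = ["fastapi", "pydantic"]   # placeholder
CLIENT_REQUIRES = ["requests", "loguru"]     # the new lightweight client deps

setup(
    name="autotrain-advanced",
    install_requires=INSTALL_REQUIRES,
    extras_require={
        "base": INSTALL_REQUIRES,
        "client": CLIENT_REQUIRES,  # enables `pip install autotrain-advanced[client]`
    },
)
```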

src/autotrain/__init__.py

Lines changed: 7 additions & 3 deletions
@@ -25,12 +25,16 @@
 
 import warnings
 
-import torch._dynamo
 
-from autotrain.logging import Logger
+try:
+    import torch._dynamo
+
+    torch._dynamo.config.suppress_errors = True
+except ImportError:
+    pass
 
+from autotrain.logging import Logger
 
-torch._dynamo.config.suppress_errors = True
 
 warnings.filterwarnings("ignore", category=UserWarning, module="tensorflow")
 warnings.filterwarnings("ignore", category=UserWarning, module="transformers")

src/autotrain/app/api_routes.py

Lines changed: 59 additions & 21 deletions
@@ -3,7 +3,8 @@
 
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from fastapi.responses import JSONResponse
-from huggingface_hub import HfApi
+from huggingface_hub import HfApi, constants
+from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
 from pydantic import BaseModel, create_model, model_validator
 
 from autotrain import __version__, logger
@@ -569,6 +570,10 @@ def validate_params(cls, values):
         return values
 
 
+class JobIDModel(BaseModel):
+    jid: str
+
+
 api_router = APIRouter()
 
 
@@ -690,33 +695,16 @@ async def api_version():
     return {"version": __version__}
 
 
-@api_router.get("/logs", response_class=JSONResponse)
-async def api_logs(job_id: str, token: bool = Depends(api_auth)):
-    """
-    Fetch logs for a specific job.
-
-    Args:
-        job_id (str): The ID of the job for which logs are to be fetched.
-        token (bool, optional): Authentication token, defaults to the result of api_auth dependency.
-
-    Returns:
-        dict: A dictionary containing the logs, success status, and a message.
-    """
-    # project = AutoTrainProject(job_id=job_id, token=token)
-    # logs = project.get_logs()
-    return {"logs": "Not implemented yet", "success": False, "message": "Not implemented yet"}
-
-
-@api_router.get("/stop_training", response_class=JSONResponse)
-async def api_stop_training(job_id: str, token: bool = Depends(api_auth)):
+@api_router.post("/stop_training", response_class=JSONResponse)
+async def api_stop_training(job: JobIDModel, token: bool = Depends(api_auth)):
     """
     Stops the training job with the given job ID.
 
     This asynchronous function pauses the training job identified by the provided job ID.
     It uses the Hugging Face API to pause the space associated with the job.
 
     Args:
-        job_id (str): The ID of the job to stop.
+        job (JobIDModel): The job model containing the job ID.
         token (bool, optional): The authentication token, provided by dependency injection.
 
     Returns:
@@ -728,9 +716,59 @@ async def api_stop_training(job_id: str, token: bool = Depends(api_auth)):
         Exception: If there is an error while attempting to stop the training job.
     """
     hf_api = HfApi(token=token)
+    job_id = job.jid
     try:
         hf_api.pause_space(repo_id=job_id)
     except Exception as e:
         logger.error(f"Failed to stop training: {e}")
         return {"message": f"Failed to stop training for {job_id}: {e}", "success": False}
     return {"message": f"Training stopped for {job_id}", "success": True}
+
+
+@api_router.post("/logs", response_class=JSONResponse)
+async def api_logs(job: JobIDModel, token: bool = Depends(api_auth)):
+    """
+    Fetch logs for a given job.
+
+    This endpoint retrieves logs for a specified job by its job ID. It first obtains a JWT token
+    to authenticate the request and then fetches the logs from the Hugging Face API.
+
+    Args:
+        job (JobIDModel): The job model containing the job ID.
+        token (bool, optional): Dependency injection for API authentication. Defaults to Depends(api_auth).
+
+    Returns:
+        JSONResponse: A JSON response containing the logs, success status, and a message.
+
+    Raises:
+        Exception: If there is an error fetching the logs, the exception message is returned in the response.
+    """
+    job_id = job.jid
+    jwt_url = f"{constants.ENDPOINT}/api/spaces/{job_id}/jwt"
+    response = get_session().get(jwt_url, headers=build_hf_headers())
+    hf_raise_for_status(response)
+    jwt_token = response.json()["token"]  # works for 24h (see "exp" field)
+
+    # fetch the logs
+    logs_url = f"https://api.hf.space/v1/{job_id}/logs/run"
+
+    _logs = []
+    try:
+        with get_session().get(logs_url, headers=build_hf_headers(token=jwt_token), stream=True) as response:
+            hf_raise_for_status(response)
+            for line in response.iter_lines():
+                if not line.startswith(b"data: "):
+                    continue
+                line_data = line[len(b"data: ") :]
+                try:
+                    event = json.loads(line_data.decode())
+                except json.JSONDecodeError:
+                    continue  # ignore (for example, empty lines or `b': keep-alive'`)
+                _logs.append((event["timestamp"], event["data"]))
+
+        # convert logs to a string
+        _logs = "\n".join([f"{timestamp}: {data}" for timestamp, data in _logs])
+
+        return {"logs": _logs, "success": True, "message": "Logs fetched successfully"}
+    except Exception as e:
+        return {"logs": str(e), "success": False, "message": "Failed to fetch logs"}

src/autotrain/app/params.py

Lines changed: 22 additions & 11 deletions
@@ -250,7 +250,8 @@ def _munge_common_params(self):
     def _munge_params_sent_transformers(self):
         _params = self._munge_common_params()
         _params["model"] = self.base_model
-        _params["log"] = "tensorboard"
+        if "log" not in _params:
+            _params["log"] = "tensorboard"
         if not self.using_hub_dataset:
             _params["sentence1_column"] = "autotrain_sentence1"
             _params["sentence2_column"] = "autotrain_sentence2"
@@ -291,7 +292,8 @@ def _munge_params_llm(self):
             "rejected_text" if not self.api else "rejected_text_column", "rejected_text"
         )
         _params["train_split"] = self.train_split
-        _params["log"] = "tensorboard"
+        if "log" not in _params:
+            _params["log"] = "tensorboard"
 
         trainer = self.task.split(":")[1]
         if trainer != "generic":
@@ -321,7 +323,8 @@ def _munge_params_vlm(self):
         )
         _params["train_split"] = self.train_split
         _params["valid_split"] = self.valid_split
-        _params["log"] = "tensorboard"
+        if "log" not in _params:
+            _params["log"] = "tensorboard"
 
         trainer = self.task.split(":")[1]
         _params["trainer"] = trainer.lower()
@@ -335,7 +338,8 @@ def _munge_params_vlm(self):
     def _munge_params_text_clf(self):
         _params = self._munge_common_params()
         _params["model"] = self.base_model
-        _params["log"] = "tensorboard"
+        if "log" not in _params:
+            _params["log"] = "tensorboard"
         if not self.using_hub_dataset:
             _params["text_column"] = "autotrain_text"
             _params["target_column"] = "autotrain_label"
@@ -350,7 +354,8 @@ def _munge_params_text_clf(self):
     def _munge_params_extractive_qa(self):
         _params = self._munge_common_params()
         _params["model"] = self.base_model
-        _params["log"] = "tensorboard"
+        if "log" not in _params:
+            _params["log"] = "tensorboard"
         if not self.using_hub_dataset:
             _params["text_column"] = "autotrain_text"
             _params["question_column"] = "autotrain_question"
@@ -369,7 +374,8 @@ def _munge_params_extractive_qa(self):
     def _munge_params_text_reg(self):
         _params = self._munge_common_params()
         _params["model"] = self.base_model
-        _params["log"] = "tensorboard"
+        if "log" not in _params:
+            _params["log"] = "tensorboard"
         if not self.using_hub_dataset:
             _params["text_column"] = "autotrain_text"
             _params["target_column"] = "autotrain_label"
@@ -384,7 +390,8 @@ def _munge_params_text_reg(self):
     def _munge_params_token_clf(self):
         _params = self._munge_common_params()
         _params["model"] = self.base_model
-        _params["log"] = "tensorboard"
+        if "log" not in _params:
+            _params["log"] = "tensorboard"
         if not self.using_hub_dataset:
             _params["tokens_column"] = "autotrain_text"
             _params["tags_column"] = "autotrain_label"
@@ -400,7 +407,8 @@ def _munge_params_token_clf(self):
     def _munge_params_seq2seq(self):
         _params = self._munge_common_params()
         _params["model"] = self.base_model
-        _params["log"] = "tensorboard"
+        if "log" not in _params:
+            _params["log"] = "tensorboard"
         if not self.using_hub_dataset:
             _params["text_column"] = "autotrain_text"
             _params["target_column"] = "autotrain_label"
@@ -416,7 +424,8 @@ def _munge_params_seq2seq(self):
     def _munge_params_img_clf(self):
         _params = self._munge_common_params()
         _params["model"] = self.base_model
-        _params["log"] = "tensorboard"
+        if "log" not in _params:
+            _params["log"] = "tensorboard"
         if not self.using_hub_dataset:
             _params["image_column"] = "autotrain_image"
             _params["target_column"] = "autotrain_label"
@@ -432,7 +441,8 @@ def _munge_params_img_clf(self):
     def _munge_params_img_reg(self):
         _params = self._munge_common_params()
         _params["model"] = self.base_model
-        _params["log"] = "tensorboard"
+        if "log" not in _params:
+            _params["log"] = "tensorboard"
         if not self.using_hub_dataset:
             _params["image_column"] = "autotrain_image"
             _params["target_column"] = "autotrain_label"
@@ -448,7 +458,8 @@ def _munge_params_img_reg(self):
     def _munge_params_img_obj_det(self):
         _params = self._munge_common_params()
         _params["model"] = self.base_model
-        _params["log"] = "tensorboard"
+        if "log" not in _params:
+            _params["log"] = "tensorboard"
         if not self.using_hub_dataset:
             _params["image_column"] = "autotrain_image"
             _params["objects_column"] = "autotrain_objects"

src/autotrain/backends/spaces.py

Lines changed: 2 additions & 0 deletions
@@ -41,6 +41,8 @@ def _create_readme(self):
         _readme += "colorTo: indigo\n"
         _readme += "sdk: docker\n"
         _readme += "pinned: false\n"
+        _readme += "tags:\n"
+        _readme += "- autotrain\n"
         _readme += "duplicated_from: autotrain-projects/autotrain-advanced\n"
         _readme += "---\n"
         _readme = io.BytesIO(_readme.encode())
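
With the two added lines, the generated Space README front matter gains a `tags` list, presumably so AutoTrain-created Spaces can be identified by the `autotrain` tag on the Hub. Roughly, the relevant portion of the front matter would look like the string below; fields emitted before `colorTo` (title, emoji, and so on) come from earlier lines of `_create_readme` that are not part of this diff:

```python
# Approximate tail of the front matter built by _create_readme after this change.
expected_front_matter = (
    "colorTo: indigo\n"
    "sdk: docker\n"
    "pinned: false\n"
    "tags:\n"
    "- autotrain\n"
    "duplicated_from: autotrain-projects/autotrain-advanced\n"
    "---\n"
)
```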
