Skip to content

Commit 53bbb1a

Browse files
committed
fixes, adapts, update sdk version
1 parent 9dce9e6 commit 53bbb1a

File tree

3 files changed

+36
-23
lines changed

3 files changed

+36
-23
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "transformerlab"
7-
version = "0.0.44"
7+
version = "0.0.45"
88
description = "Python SDK for Transformer Lab"
99
readme = "README.md"
1010
requires-python = ">=3.10"
1111
authors = [{ name = "Transformer Lab", email = "developers@transformerlab.ai" }]
1212
license = { file = "LICENSE" }
13-
dependencies = ["werkzeug", "pytest", "fsspec", "s3fs"]
13+
dependencies = ["werkzeug", "pytest", "wandb", "fsspec", "s3fs"]
1414

1515
[project.urls]
1616
"Homepage" = "https://github.com/transformerlab/transformerlab-sdk"

src/lab/dirs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,8 @@ def get_job_eval_results_dir(job_id: str | int) -> str:
223223
Return the eval_results directory for a specific job, creating it if needed.
224224
Example: ~/.transformerlab/workspace/jobs/<job_id>/eval_results
225225
"""
226-
path = os.path.join(get_job_dir(job_id), "eval_results")
227-
os.makedirs(name=path, exist_ok=True)
226+
path = storage.join(get_job_dir(job_id), "eval_results")
227+
storage.makedirs(path, exist_ok=True)
228228
return path
229229

230230

src/lab/lab_facade.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Optional, Dict, Any, Union
55
import os
66
import io
7+
import posixpath
78

89
from .experiment import Experiment
910
from .job import Job
@@ -121,7 +122,7 @@ def get_checkpoint_to_resume(self) -> Optional[str]:
121122
checkpoint_path = self.get_parent_job_checkpoint_path(parent_job_id, checkpoint_name)
122123

123124
# Verify the checkpoint exists
124-
if checkpoint_path and os.path.exists(checkpoint_path):
125+
if checkpoint_path and storage.exists(checkpoint_path):
125126
return checkpoint_path
126127

127128
return None
@@ -142,16 +143,19 @@ def get_parent_job_checkpoint_path(self, parent_job_id: str, checkpoint_name: st
142143
"""
143144
try:
144145
checkpoints_dir = dirs.get_job_checkpoints_dir(parent_job_id)
145-
checkpoint_path = os.path.join(checkpoints_dir, checkpoint_name)
146+
checkpoint_path = storage.join(checkpoints_dir, checkpoint_name)
146147

147148
# Security check: ensure the checkpoint path is within the checkpoints directory
148-
checkpoint_path_normalized = os.path.normpath(checkpoint_path)
149-
checkpoints_dir_normalized = os.path.normpath(checkpoints_dir)
149+
# Normalize paths using posixpath for cross-platform compatibility (works for both local and remote storage)
150+
checkpoint_path_normalized = posixpath.normpath(checkpoint_path).rstrip("/")
151+
checkpoints_dir_normalized = posixpath.normpath(checkpoints_dir).rstrip("/")
150152

151-
if not checkpoint_path_normalized.startswith(checkpoints_dir_normalized + os.sep):
153+
# Check if checkpoint path is strictly within checkpoints directory (not the directory itself)
154+
# For remote storage (s3://, etc.), ensure we're checking within the same bucket/path
155+
if not checkpoint_path_normalized.startswith(checkpoints_dir_normalized + "/"):
152156
return None
153157

154-
if os.path.exists(checkpoint_path_normalized):
158+
if storage.exists(checkpoint_path_normalized):
155159
return checkpoint_path_normalized
156160

157161
return None
@@ -365,8 +369,11 @@ def save_artifact(
365369
if type == "model":
366370
if not isinstance(source_path, str) or source_path.strip() == "":
367371
raise ValueError("source_path must be a non-empty string when type='model'")
368-
src = os.path.abspath(source_path)
369-
if not os.path.exists(src):
372+
src = source_path
373+
# For local paths, resolve to absolute path; for remote paths (s3://, etc.), use as-is
374+
if not src.startswith(("s3://", "gs://", "abfs://", "gcs://", "http://", "https://")):
375+
src = os.path.abspath(src)
376+
if not storage.exists(src):
370377
raise FileNotFoundError(f"Model source does not exist: {src}")
371378

372379
# Get model-specific parameters from config
@@ -397,7 +404,7 @@ def save_artifact(
397404
if isinstance(name, str) and name.strip() != "":
398405
base_name = f"{job_id}_{name}"
399406
else:
400-
base_name = f"{job_id}_{os.path.basename(src)}"
407+
base_name = f"{job_id}_{posixpath.basename(src)}"
401408

402409
# Save to main workspace models directory for Model Zoo visibility
403410
models_dir = dirs.get_models_dir()
@@ -407,7 +414,7 @@ def save_artifact(
407414
storage.makedirs(models_dir, exist_ok=True)
408415

409416
# Copy file or directory using storage module
410-
if os.path.isdir(src):
417+
if storage.isdir(src):
411418
if storage.exists(dest):
412419
storage.rm_tree(dest)
413420
storage.copy_dir(src, dest)
@@ -429,7 +436,7 @@ def save_artifact(
429436
pipeline_tag = model_service.fetch_pipeline_tag(parent_model)
430437

431438
# Determine model_filename for single-file models
432-
model_filename = "" if storage.isdir(dest) else os.path.basename(dest)
439+
model_filename = "" if storage.isdir(dest) else posixpath.basename(dest)
433440

434441
# Prepare json_data with basic info
435442
json_data = {
@@ -508,8 +515,11 @@ def save_artifact(
508515
# Handle file path input (original behavior)
509516
if not isinstance(source_path, str) or source_path.strip() == "":
510517
raise ValueError("source_path must be a non-empty string")
511-
src = os.path.abspath(source_path)
512-
if not os.path.exists(src):
518+
src = source_path
519+
# For local paths, resolve to absolute path; for remote paths (s3://, etc.), use as-is
520+
if not src.startswith(("s3://", "gs://", "abfs://", "gcs://", "http://", "https://")):
521+
src = os.path.abspath(src)
522+
if not storage.exists(src):
513523
raise FileNotFoundError(f"Artifact source does not exist: {src}")
514524

515525
# Determine destination directory based on type
@@ -518,14 +528,14 @@ def save_artifact(
518528
else:
519529
dest_dir = dirs.get_job_artifacts_dir(job_id)
520530

521-
base_name = name if (isinstance(name, str) and name.strip() != "") else os.path.basename(src)
531+
base_name = name if (isinstance(name, str) and name.strip() != "") else posixpath.basename(src)
522532
dest = storage.join(dest_dir, base_name)
523533

524534
# Create parent directories
525535
storage.makedirs(dest_dir, exist_ok=True)
526536

527537
# Copy file or directory
528-
if os.path.isdir(src):
538+
if storage.isdir(src):
529539
if storage.exists(dest):
530540
storage.rm_tree(dest)
531541
storage.copy_dir(src, dest)
@@ -665,20 +675,23 @@ def save_checkpoint(self, source_path: str, name: Optional[str] = None) -> str:
665675
self._ensure_initialized()
666676
if not isinstance(source_path, str) or source_path.strip() == "":
667677
raise ValueError("source_path must be a non-empty string")
668-
src = os.path.abspath(source_path)
669-
if not os.path.exists(src):
678+
src = source_path
679+
# For local paths, resolve to absolute path; for remote paths (s3://, etc.), use as-is
680+
if not src.startswith(("s3://", "gs://", "abfs://", "gcs://", "http://", "https://")):
681+
src = os.path.abspath(src)
682+
if not storage.exists(src):
670683
raise FileNotFoundError(f"Checkpoint source does not exist: {src}")
671684

672685
job_id = self._job.id # type: ignore[union-attr]
673686
ckpts_dir = dirs.get_job_checkpoints_dir(job_id)
674-
base_name = name if (isinstance(name, str) and name.strip() != "") else os.path.basename(src)
687+
base_name = name if (isinstance(name, str) and name.strip() != "") else posixpath.basename(src)
675688
dest = storage.join(ckpts_dir, base_name)
676689

677690
# Create parent directories
678691
storage.makedirs(ckpts_dir, exist_ok=True)
679692

680693
# Copy file or directory
681-
if os.path.isdir(src):
694+
if storage.isdir(src):
682695
if storage.exists(dest):
683696
storage.rm_tree(dest)
684697
storage.copy_dir(src, dest)

0 commit comments

Comments
 (0)