
Commit cde3e0e

Merge pull request #660 from transformerlab/add/use-fsspec
Use fsspec in everything
2 parents: 6c14b9d + 5080d60

100 files changed: +1837 −1395 lines changed


api.py

Lines changed: 11 additions & 7 deletions
@@ -78,6 +78,7 @@
 )
 from transformerlab.shared.request_context import set_current_org_id
 from lab.dirs import set_organization_id as lab_set_org_id
+from lab import storage
 
 from dotenv import load_dotenv
 
@@ -91,7 +92,7 @@
 # to be overriden by the user.
 os.environ["_TFL_SOURCE_CODE_DIR"] = dirs.TFL_SOURCE_CODE_DIR
 # The temporary image directory for transformerlab (default; per-request overrides computed in routes)
-temp_image_dir = os.path.join(get_workspace_dir(), "temp", "images")
+temp_image_dir = storage.join(get_workspace_dir(), "temp", "images")
 os.environ["TLAB_TEMP_IMAGE_DIR"] = str(temp_image_dir)
 
 # Check if anything is stored in GPU_ORCHESTRATION_SERVER, and if so, print a message
@@ -243,7 +244,10 @@ async def validation_exception_handler(request, exc):
 
 def spawn_fastchat_controller_subprocess():
     global controller_process
-    logfile = open(os.path.join(dirs.FASTCHAT_LOGS_DIR, "controller.log"), "w")
+    controller_log_path = storage.join(dirs.FASTCHAT_LOGS_DIR, "controller.log")
+    # Note: subprocess requires a local file handle, so we use open() directly
+    # but construct the path using storage.join for workspace consistency
+    logfile = open(controller_log_path, "w")
     port = "21001"
 
     controller_process = subprocess.Popen(
@@ -254,7 +258,7 @@ def spawn_fastchat_controller_subprocess():
             "--port",
             port,
             "--log-file",
-            os.path.join(dirs.FASTCHAT_LOGS_DIR, "controller.log"),
+            controller_log_path,
         ],
         stdout=logfile,
         stderr=logfile,
@@ -379,7 +383,7 @@ async def server_worker_start(
 
     from lab.dirs import get_global_log_path
 
-    with open(get_global_log_path(), "a") as global_log:
+    with storage.open(get_global_log_path(), "a") as global_log:
         global_log.write(f"🏃 Loading Inference Server for {model_name} with {inference_params}\n")
 
     process = await shared.async_run_python_daemon_and_update_status(
@@ -392,7 +396,7 @@ async def server_worker_start(
     if exitcode == 99:
         from lab.dirs import get_global_log_path
 
-        with open(get_global_log_path(), "a") as global_log:
+        with storage.open(get_global_log_path(), "a") as global_log:
             global_log.write(
                 "GPU (CUDA) Out of Memory: Please try a smaller model or a different inference engine. Restarting the server may free up resources.\n"
             )
@@ -403,7 +407,7 @@ async def server_worker_start(
    if exitcode is not None and exitcode != 0:
        from lab.dirs import get_global_log_path
 
-        with open(get_global_log_path(), "a") as global_log:
+        with storage.open(get_global_log_path(), "a") as global_log:
            global_log.write(f"Error loading model: {model_name} with exit code {exitcode}\n")
        job = job_get(job_id)
        error_msg = None
@@ -415,7 +419,7 @@ async def server_worker_start(
        return {"status": "error", "message": error_msg}
    from lab.dirs import get_global_log_path
 
-    with open(get_global_log_path(), "a") as global_log:
+    with storage.open(get_global_log_path(), "a") as global_log:
        global_log.write(f"Model loaded successfully: {model_name}\n")
    return {"status": "success", "job_id": job_id}

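The lab.storage module itself is not shown in this commit, so the following is only an inferred sketch of what an fsspec-backed implementation of the helpers used above (join, open, exists, isdir, ls, makedirs, rm, copy_file) could look like. The fsspec calls are real API; the module layout and exact semantics are assumptions based on the call sites in this diff.

# Inferred sketch of an fsspec-backed lab.storage, based only on the call
# sites in this commit; the real module is not shown here and may differ.
import posixpath

import fsspec
from fsspec.core import url_to_fs


def join(*parts: str) -> str:
    # URLs and object-store keys use "/" separators, so a POSIX join covers
    # both local paths and remote paths like s3://bucket/key
    return posixpath.join(*parts)


def open(path: str, mode: str = "r"):  # mirrors the call-site name storage.open
    # fsspec.open returns a context manager usable like builtins.open
    return fsspec.open(path, mode)


def exists(path: str) -> bool:
    fs, p = url_to_fs(path)
    return fs.exists(p)


def isdir(path: str) -> bool:
    fs, p = url_to_fs(path)
    return fs.isdir(p)


def ls(path: str) -> list:
    fs, p = url_to_fs(path)
    return fs.ls(p)


def makedirs(path: str, exist_ok: bool = False) -> None:
    fs, p = url_to_fs(path)
    fs.makedirs(p, exist_ok=exist_ok)


def rm(path: str) -> None:
    fs, p = url_to_fs(path)
    fs.rm(p)


def copy_file(src: str, dst: str) -> None:
    # Assumes src and dst are on the same filesystem, as in the dataset copy
    fs, s = url_to_fs(src)
    _, d = url_to_fs(dst)
    fs.copy(s, d)

Because fsspec handles support the context-manager protocol, call sites like "with storage.open(...) as f:" needed only the prefix change, which is exactly the pattern in api.py above.
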
requirements-no-gpu-uv.txt

Lines changed: 1 addition & 1 deletion
@@ -556,7 +556,7 @@ tqdm==4.66.5
     # peft
     # sentence-transformers
     # transformers
-transformerlab==0.0.44
+transformerlab==0.0.45
     # via
     # -r requirements.in
     # transformerlab-inference

requirements-rocm-uv.txt

Lines changed: 1 addition & 1 deletion
@@ -559,7 +559,7 @@ tqdm==4.66.5
     # peft
     # sentence-transformers
     # transformers
-transformerlab==0.0.44
+transformerlab==0.0.45
     # via
     # -r requirements-rocm.in
     # transformerlab-inference

requirements-rocm.in

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ hf_xet
 macmon-python
 mcp[cli]
 transformerlab-inference==0.2.51
-transformerlab==0.0.44
+transformerlab==0.0.45
 diffusers==0.33.1
 pyrsmi
 controlnet_aux==0.0.10

requirements-uv.txt

Lines changed: 1 addition & 1 deletion
@@ -594,7 +594,7 @@ tqdm==4.66.5
     # peft
     # sentence-transformers
     # transformers
-transformerlab==0.0.44
+transformerlab==0.0.45
     # via
     # -r requirements.in
     # transformerlab-inference

requirements.in

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ markitdown[all]
 hf_xet
 macmon-python
 transformerlab-inference==0.2.51
-transformerlab==0.0.44
+transformerlab==0.0.45
 diffusers==0.33.1
 nvidia-ml-py
 mcp[cli]

test/api/test_diffusion.py

Lines changed: 30 additions & 30 deletions
@@ -252,8 +252,8 @@ def test_get_image_by_id_index_out_of_range(client):
     with (
         patch("transformerlab.routers.experiment.diffusion.find_image_by_id") as mock_find_image,
         patch("transformerlab.routers.experiment.diffusion.get_images_dir", return_value="/fake/images"),
-        patch("os.path.exists", return_value=True),
-        patch("os.path.isdir", return_value=True),
+        patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=True),
+        patch("transformerlab.routers.experiment.diffusion.storage.isdir", return_value=True),
     ):
         # Create mock image with folder-format
         mock_image = MagicMock()
@@ -272,9 +272,9 @@ def test_get_image_info_by_id_success(client):
     """Test getting image metadata by ID"""
     with (
         patch("transformerlab.routers.experiment.diffusion.find_image_by_id") as mock_find_image,
-        patch("os.path.exists", return_value=True),
-        patch("os.path.isdir", return_value=True),
-        patch("os.listdir", return_value=["0.png", "1.png", "2.png"]),
+        patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=True),
+        patch("transformerlab.routers.experiment.diffusion.storage.isdir", return_value=True),
+        patch("transformerlab.routers.experiment.diffusion.storage.ls", return_value=["0.png", "1.png", "2.png"]),
     ):
         # Create mock image
         mock_image = MagicMock()
@@ -295,9 +295,9 @@ def test_get_image_count_success(client):
     """Test getting image count for an image set"""
     with (
         patch("transformerlab.routers.experiment.diffusion.find_image_by_id") as mock_find_image,
-        patch("os.path.exists", return_value=True),
-        patch("os.path.isdir", return_value=True),
-        patch("os.listdir", return_value=["0.png", "1.png"]),
+        patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=True),
+        patch("transformerlab.routers.experiment.diffusion.storage.isdir", return_value=True),
+        patch("transformerlab.routers.experiment.diffusion.storage.ls", return_value=["0.png", "1.png"]),
     ):
         # Create mock image
         mock_image = MagicMock()
@@ -317,8 +317,8 @@ def test_delete_image_from_history_not_found(client):
     """Test deleting a non-existent image from history"""
     with (
         patch("transformerlab.routers.experiment.diffusion.get_history_file_path", return_value="/fake/history.json"),
-        patch("os.path.exists", return_value=True),
-        patch("builtins.open", mock_open(read_data='[{"id": "other-id", "image_path": "/fake/path.png"}]')),
+        patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=True),
+        patch("transformerlab.routers.experiment.diffusion.storage.open", mock_open(read_data='[{"id": "other-id", "image_path": "/fake/path.png"}]')),
     ):
         resp = client.delete("/experiment/test-exp-name/diffusion/history/non-existent-id")
         assert resp.status_code == 500
@@ -331,12 +331,12 @@ def test_create_dataset_from_history_success(client):
         patch("transformerlab.routers.experiment.diffusion.find_image_by_id") as mock_find_image,
         patch("transformerlab.routers.experiment.diffusion.Dataset.get") as mock_dataset_get,
         patch("transformerlab.routers.experiment.diffusion.create_local_dataset") as mock_create_dataset,
-        patch("os.makedirs"),
-        patch("os.path.exists", return_value=True),
-        patch("os.path.isdir", return_value=True),
-        patch("os.listdir", return_value=["0.png", "1.png"]),
-        patch("shutil.copy2"),
-        patch("builtins.open", mock_open()),
+        patch("transformerlab.routers.experiment.diffusion.storage.makedirs"),
+        patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=True),
+        patch("transformerlab.routers.experiment.diffusion.storage.isdir", return_value=True),
+        patch("transformerlab.routers.experiment.diffusion.storage.ls", return_value=["/fake/path/folder/0.png", "/fake/path/folder/1.png"]),
+        patch("transformerlab.routers.experiment.diffusion.storage.copy_file"),
+        patch("transformerlab.routers.experiment.diffusion.storage.open", mock_open()),
     ):
         # Mock Dataset.get to raise FileNotFoundError for non-existent dataset (new behavior)
         mock_dataset_get.side_effect = FileNotFoundError("Directory for Dataset with id 'test-dataset' not found")
@@ -610,9 +610,9 @@ def test_load_history_success():
     """Test loading history with valid data"""
     with (
         patch("transformerlab.routers.experiment.diffusion.get_history_file_path", return_value="/fake/history.json"),
-        patch("os.path.exists", return_value=True),
+        patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=True),
         patch(
-            "builtins.open",
+            "transformerlab.routers.experiment.diffusion.storage.open",
             mock_open(
                 read_data='[{"id": "test-id", "model": "test-model", "prompt": "test prompt", "adaptor": "", "adaptor_scale": 1.0, "num_inference_steps": 20, "guidance_scale": 7.5, "seed": 42, "image_path": "/fake/path.png", "timestamp": "2023-01-01T00:00:00", "upscaled": false, "upscale_factor": 1, "negative_prompt": "", "eta": 0.0, "clip_skip": 0, "guidance_rescale": 0.0, "height": 512, "width": 512, "generation_time": 5.0, "num_images": 1, "input_image_path": "", "strength": 0.8, "is_img2img": false, "mask_image_path": "", "is_inpainting": false}]'
             ),
@@ -665,8 +665,8 @@ def test_load_history_with_pagination():
 
     with (
         patch("transformerlab.routers.experiment.diffusion.get_history_file_path", return_value="/fake/history.json"),
-        patch("os.path.exists", return_value=True),
-        patch("builtins.open", mock_open(read_data=json.dumps(history_data))),
+        patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=True),
+        patch("transformerlab.routers.experiment.diffusion.storage.open", mock_open(read_data=json.dumps(history_data))),
     ):
         from transformerlab.routers.experiment.diffusion import load_history
 
@@ -683,7 +683,7 @@ def test_load_history_no_file():
     """Test loading history when history file doesn't exist"""
     with (
         patch("transformerlab.routers.experiment.diffusion.get_history_file_path", return_value="/fake/history.json"),
-        patch("os.path.exists", return_value=False),
+        patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=False),
     ):
         from transformerlab.routers.experiment.diffusion import load_history
 
@@ -697,8 +697,8 @@ def test_load_history_invalid_json():
     """Test loading history with corrupted JSON file"""
     with (
         patch("transformerlab.routers.experiment.diffusion.get_history_file_path", return_value="/fake/history.json"),
-        patch("os.path.exists", return_value=True),
-        patch("builtins.open", mock_open(read_data="invalid json")),
+        patch("lab.storage.exists", return_value=True),
+        patch("transformerlab.routers.experiment.diffusion.storage.open", mock_open(read_data="invalid json")),
     ):
         from transformerlab.routers.experiment.diffusion import load_history
 
@@ -769,8 +769,8 @@ def test_find_image_by_id_success():
 
     with (
         patch("transformerlab.routers.experiment.diffusion.get_history_file_path", return_value="/fake/history.json"),
-        patch("os.path.exists", return_value=True),
-        patch("builtins.open", mock_open(read_data=json.dumps(history_data))),
+        patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=True),
+        patch("transformerlab.routers.experiment.diffusion.storage.open", mock_open(read_data=json.dumps(history_data))),
    ):
         from transformerlab.routers.experiment.diffusion import find_image_by_id
 
@@ -816,8 +816,8 @@ def test_find_image_by_id_not_found(client):
 
     with (
         patch("transformerlab.routers.experiment.diffusion.get_history_file_path", return_value="/fake/history.json"),
-        patch("os.path.exists", return_value=True),
-        patch("builtins.open", mock_open(read_data=json.dumps(history_data))),
+        patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=True),
+        patch("transformerlab.routers.experiment.diffusion.storage.open", mock_open(read_data=json.dumps(history_data))),
     ):
         from transformerlab.routers.experiment.diffusion import find_image_by_id
 
@@ -830,7 +830,7 @@ def test_find_image_by_id_no_file():
     """Test finding an image by ID when history file doesn't exist"""
     with (
         patch("transformerlab.routers.experiment.diffusion.get_history_file_path", return_value="/fake/history.json"),
-        patch("os.path.exists", return_value=False),
+        patch("transformerlab.routers.experiment.diffusion.storage.exists", return_value=False),
     ):
         from transformerlab.routers.experiment.diffusion import find_image_by_id
 
@@ -843,8 +843,8 @@ def test_find_image_by_id_invalid_json():
     """Test finding an image by ID with corrupted JSON file"""
     with (
         patch("transformerlab.routers.experiment.diffusion.get_history_file_path", return_value="/fake/history.json"),
-        patch("os.path.exists", return_value=True),
-        patch("builtins.open", mock_open(read_data="invalid json")),
+        patch("lab.storage.exists", return_value=True),
+        patch("transformerlab.routers.experiment.diffusion.storage.open", mock_open(read_data="invalid json")),
     ):
         from transformerlab.routers.experiment.diffusion import find_image_by_id

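A note on the pattern above: unittest.mock.patch must target the name where the module under test looks it up, so once diffusion.py does "from lab import storage", patching os.path.exists or builtins.open no longer intercepts anything; the tests therefore patch transformerlab.routers.experiment.diffusion.storage.* instead. A self-contained sketch of the same idea, with hypothetical module names standing in for the real ones:

# Self-contained sketch of the patching rule (hypothetical module names, not
# TransformerLab code): build a fake "storage" module plus a consumer module
# registered as "pkg" so patch() can resolve the dotted target path.
import json
import sys
import types
from unittest.mock import mock_open, patch

storage = types.ModuleType("storage")
storage.exists = lambda path: False
storage.open = open

pkg = types.ModuleType("pkg")
pkg.storage = storage


def load_history(path):
    # Same shape as load_history in diffusion.py: existence check, then read
    if not storage.exists(path):
        return []
    with storage.open(path, "r") as f:
        return json.load(f)


pkg.load_history = load_history
sys.modules["pkg"] = pkg

# Patch where the name is looked up (pkg.storage.*), not os.path or builtins
with (
    patch("pkg.storage.exists", return_value=True),
    patch("pkg.storage.open", mock_open(read_data='[{"id": "test-id"}]')),
):
    assert pkg.load_history("/fake/history.json")[0]["id"] == "test-id"
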
test/api/test_experiment_export.py

Lines changed: 7 additions & 6 deletions
@@ -3,6 +3,7 @@
 from transformerlab.services import experiment_service
 from transformerlab.services.tasks_service import tasks_service
 from lab.dirs import get_workspace_dir
+from lab import storage
 
 
 async def test_export_experiment(client):
@@ -66,13 +67,13 @@ async def test_export_experiment(client):
     # The response should be a JSON file
     assert response.headers["content-type"] == "application/json"
 
-    WORKSPACE_DIR = get_workspace_dir()
+    workspace_dir = get_workspace_dir()
 
     # Read the exported file from workspace directory
-    export_file = os.path.join(WORKSPACE_DIR, f"{test_experiment_name}_export.json")
-    assert os.path.exists(export_file)
+    export_file = storage.join(workspace_dir, f"{test_experiment_name}_export.json")
+    assert storage.exists(export_file)
 
-    with open(export_file, "r") as f:
+    with storage.open(export_file, "r") as f:
         exported_data = json.load(f)
 
     # Verify the exported data structure
@@ -104,5 +105,5 @@ async def test_export_experiment(client):
 
     # Clean up
     experiment_service.experiment_delete(experiment_id)
-    if os.path.exists(export_file):
-        os.remove(export_file)
+    if storage.exists(export_file):
+        storage.rm(export_file)

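For reference, a minimal usage sketch of the helpers this test now exercises, assuming lab.storage mirrors the os/shutil behavior its call sites imply; the workspace path below is a made-up stand-in for get_workspace_dir():

# Usage sketch of the helpers exercised above. The workspace path is a made-up
# stand-in for get_workspace_dir(); behavior is assumed to mirror os/shutil.
from lab import storage

workspace_dir = "/tmp/workspace"
export_file = storage.join(workspace_dir, "alpha_export.json")

storage.makedirs(workspace_dir, exist_ok=True)
with storage.open(export_file, "w") as f:
    f.write('{"name": "alpha"}')

assert storage.exists(export_file)
with storage.open(export_file, "r") as f:
    print(f.read())

storage.rm(export_file)  # replaces os.remove; also valid for remote paths
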
transformerlab/db/constants.py

Lines changed: 0 additions & 1 deletion
@@ -1,7 +1,6 @@
 # --- Centralized Database Configuration ---
 from lab import HOME_DIR
 
-
 db = None  # This will hold the aiosqlite connection
 DATABASE_FILE_NAME = f"{HOME_DIR}/llmlab.sqlite3"
 DATABASE_URL = f"sqlite+aiosqlite:///{DATABASE_FILE_NAME}"

transformerlab/fastchat_openai_api.py

Lines changed: 28 additions & 18 deletions
@@ -45,7 +45,7 @@
     UsageInfo,
 )
 from pydantic import BaseModel as PydanticBaseModel
-from lab import Experiment
+from lab import Experiment, storage
 
 WORKER_API_TIMEOUT = 3600
 
@@ -233,16 +233,23 @@ def log_prompt(prompt):
     from lab.dirs import get_logs_dir
 
     logs_dir = get_logs_dir()
-    if os.path.exists(os.path.join(logs_dir, "prompt.log")):
-        if os.path.getsize(os.path.join(logs_dir, "prompt.log")) > MAX_LOG_SIZE_BEFORE_ROTATE:
-            with open(os.path.join(logs_dir, "prompt.log"), "r") as f:
+    prompt_log_path = storage.join(logs_dir, "prompt.log")
+    if storage.exists(prompt_log_path):
+        # Get file size - for remote storage, we may need to read the file to check size
+        try:
+            with storage.open(prompt_log_path, "r") as f:
                 lines = f.readlines()
-            with open(os.path.join(logs_dir, "prompt.log"), "w") as f:
-                f.writelines(lines[-1000:])
-            with open(os.path.join(logs_dir, f"prompt_{time.strftime('%Y%m%d%H%M%S')}.log"), "w") as f:
-                f.writelines(lines[:-1000])
-
-    with open(os.path.join(logs_dir, "prompt.log"), "a") as f:
+            file_size = sum(len(line.encode('utf-8')) for line in lines)
+            if file_size > MAX_LOG_SIZE_BEFORE_ROTATE:
+                with storage.open(prompt_log_path, "w") as f:
+                    f.writelines(lines[-1000:])
+                with storage.open(storage.join(logs_dir, f"prompt_{time.strftime('%Y%m%d%H%M%S')}.log"), "w") as f:
+                    f.writelines(lines[:-1000])
+        except Exception:
+            # If we can't read the file, just continue with appending
+            pass
+
+    with storage.open(prompt_log_path, "a") as f:
         log_entry = {}
         log_entry["date"] = time.strftime("%Y-%m-%d %H:%M:%S")
         log_entry["log"] = prompt
@@ -254,7 +261,10 @@ def log_prompt(prompt):
 async def get_prompt_log():
     from lab.dirs import get_logs_dir
 
-    return FileResponse(os.path.join(get_logs_dir(), "prompt.log"))
+    prompt_log_path = storage.join(get_logs_dir(), "prompt.log")
+    # FileResponse needs a local file path, so use the path string directly
+    # For remote storage, this would need special handling
+    return FileResponse(prompt_log_path)
 
 
 async def check_length(request, prompt, max_tokens):
@@ -526,8 +536,8 @@ async def create_audio_tts(request: AudioGenerationRequest):
     exp_obj = Experiment.get(request.experiment_id)
     experiment_dir = exp_obj.get_dir()
 
-    audio_dir = os.path.join(experiment_dir, "audio")
-    os.makedirs(audio_dir, exist_ok=True)
+    audio_dir = storage.join(experiment_dir, "audio")
+    storage.makedirs(audio_dir, exist_ok=True)
 
     gen_params = {
         "audio_dir": audio_dir,
@@ -560,16 +570,16 @@
 async def upload_audio_reference(experimentId: str, audio: UploadFile = File(...)):
     exp_obj = Experiment(experimentId)
     experiment_dir = exp_obj.get_dir()
-    uploaded_audio_dir = os.path.join(experiment_dir, "uploaded_audio")
-    os.makedirs(uploaded_audio_dir, exist_ok=True)
+    uploaded_audio_dir = storage.join(experiment_dir, "uploaded_audio")
+    storage.makedirs(uploaded_audio_dir, exist_ok=True)
 
     file_prefix = str(uuid.uuid4())
     _, ext = os.path.splitext(audio.filename)
-    file_path = os.path.join(uploaded_audio_dir, file_prefix + ext)
+    file_path = storage.join(uploaded_audio_dir, file_prefix + ext)
 
     # Save the uploaded file
-    with open(file_path, "wb") as f:
-        content = await audio.read()
+    content = await audio.read()
+    with storage.open(file_path, "wb") as f:
        f.write(content)
 
     return JSONResponse({"audioPath": file_path})

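One caveat worth noting in log_prompt above: the rewrite determines the log size by reading the whole file, since a generic storage handle has no cheap stat. fsspec filesystems do expose size metadata, so a possible refinement, purely an assumption and not part of this commit, could ask the filesystem directly (MAX_LOG_SIZE_BEFORE_ROTATE stands in for the module's constant):

# Hypothetical refinement, not part of this commit: ask the filesystem for the
# log size via fsspec metadata instead of reading the entire file into memory.
import time

import fsspec
from fsspec.core import url_to_fs

MAX_LOG_SIZE_BEFORE_ROTATE = 1024 * 1024  # assumed value for illustration


def rotate_if_needed(prompt_log_path: str, keep_lines: int = 1000) -> None:
    fs, path = url_to_fs(prompt_log_path)
    if not fs.exists(path):
        return
    # fs.size() is a metadata lookup on local and object-store backends alike
    if fs.size(path) <= MAX_LOG_SIZE_BEFORE_ROTATE:
        return
    with fsspec.open(prompt_log_path, "r") as f:
        lines = f.readlines()
    archive = prompt_log_path.replace(
        "prompt.log", f"prompt_{time.strftime('%Y%m%d%H%M%S')}.log"
    )
    with fsspec.open(archive, "w") as f:
        f.writelines(lines[:-keep_lines])
    with fsspec.open(prompt_log_path, "w") as f:
        f.writelines(lines[-keep_lines:])

The get_prompt_log change has a similar caveat, called out in its own comment: FileResponse needs a local path, so serving the log from remote storage would require a streamed response instead.
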