
Commit 9dce9e6 (merge commit, 2 parents: 08675d5 + 3aa9f05)

Merge branch 'main' of https://github.com/transformerlab/transformerlab-sdk into add/use-fsspec

File tree: 2 files changed (+102, -19 lines)

scripts/examples/trl_train_script.py

Lines changed: 21 additions & 10 deletions
@@ -5,6 +5,7 @@
 """
 
 import os
+import argparse
 from datetime import datetime
 from time import sleep
 from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
@@ -122,6 +123,11 @@ def train_with_trl(quick_test=True):
     lab.init()
     lab.set_config(training_config)
 
+    # Check if we should resume from a checkpoint
+    checkpoint = lab.get_checkpoint_to_resume()
+    if checkpoint:
+        lab.log(f"📁 Resuming training from checkpoint: {checkpoint}")
+
     # Log start time
     start_time = datetime.now()
     mode = "Quick test" if quick_test else "Full training"
@@ -162,17 +168,17 @@ def train_with_trl(quick_test=True):
     lab.log("Loading model and tokenizer...")
     try:
         from transformers import AutoTokenizer, AutoModelForCausalLM
-
+
         model_name = training_config["model_name"]
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForCausalLM.from_pretrained(model_name)
-
+
         # Add pad token if it doesn't exist
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
 
         lab.log(f"Loaded model: {model_name}")
-
+
     except ImportError:
         lab.log("⚠️ Transformers not available, skipping real training")
         lab.finish("Training skipped - transformers not available")
@@ -207,6 +213,8 @@ def train_with_trl(quick_test=True):
         remove_unused_columns=False,
         push_to_hub=False,
         dataset_text_field="text", # Move dataset_text_field to SFTConfig
+        resume_from_checkpoint=checkpoint if checkpoint else None,
+        bf16=False, # Disable bf16 for compatibility with older GPUs
         # Enable automatic checkpoint saving
         save_total_limit=3, # Keep only the last 3 checkpoints to save disk space
         save_strategy="steps", # Save checkpoints every save_steps
@@ -440,15 +448,18 @@ def train_with_trl(quick_test=True):
 
 
 if __name__ == "__main__":
-    import sys
 
-    # Check if user wants full training or quick test
-    quick_test = False # Default to quick test
-    if len(sys.argv) > 1 and sys.argv[1] == "--quick-training":
-        quick_test = True
+    parser = argparse.ArgumentParser(description="Train a model with automatic checkpoint resume support.")
+    parser.add_argument("--quick-training", action="store_true", help="Run in quick test mode")
+
+    args = parser.parse_args()
+
+    quick_test = args.quick_training
+
+    if quick_test:
         print("🚀 Running quick test mode...")
     else:
-        print("🚀 Running full training mode (use --quick-training for quick test)...")
-
+        print("🚀 Running full training mode...")
+
     result = train_with_trl(quick_test=quick_test)
     print("Training result:", result)

src/lab/lab_facade.py

Lines changed: 81 additions & 9 deletions
@@ -28,7 +28,7 @@ def __init__(self) -> None:
         self._job: Optional[Job] = None
 
     # ------------- lifecycle -------------
-    def init(self, experiment_id: str = "alpha") -> None:
+    def init(self, experiment_id: str = "alpha", config: Optional[Dict[str, Any]] = None) -> None:
         """
         Initialize a job under the given experiment.
         If _TFL_JOB_ID environment variable is set, uses that existing job.
@@ -59,6 +59,10 @@ def init(self, experiment_id: str = "alpha") -> None:
         # Check for wandb integration and capture URL if available
         self._detect_and_capture_wandb_url()
 
+        # Set config if provided
+        if config is not None:
+            self.set_config(config)
+
     def set_config(self, config: Dict[str, Any]) -> None:
         """
         Attach configuration to the current job.
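
The widened init() signature is a small convenience: a config passed to init() is forwarded to set_config(), so callers no longer need two separate calls. A one-line usage sketch with placeholder values:

# Equivalent to lab.init("alpha") followed by lab.set_config({...}); values are placeholders.
lab.init(experiment_id="alpha", config={"model_name": "gpt2", "learning_rate": 2e-5})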
@@ -88,6 +92,73 @@ def update_progress(self, progress: int) -> None:
         # Check for wandb URL on every progress update
         self._check_and_capture_wandb_url()
 
+    # ------------- checkpoint resume support -------------
+    def get_checkpoint_to_resume(self) -> Optional[str]:
+        """
+        Get the checkpoint path to resume training from.
+
+        This method checks for checkpoint resume information stored in the job data
+        when resuming training from a checkpoint.
+
+        Returns:
+            Optional[str]: The full path to the checkpoint to resume from, or None if no
+                checkpoint resume is requested.
+        """
+        if not self._job:
+            return None
+
+        job_data = self._job.get_job_data()
+        if not job_data:
+            return None
+
+        parent_job_id = job_data.get('parent_job_id')
+        checkpoint_name = job_data.get('resumed_from_checkpoint')
+
+        if not parent_job_id or not checkpoint_name:
+            return None
+
+        # Build the checkpoint path from parent job's checkpoints directory
+        checkpoint_path = self.get_parent_job_checkpoint_path(parent_job_id, checkpoint_name)
+
+        # Verify the checkpoint exists
+        if checkpoint_path and os.path.exists(checkpoint_path):
+            return checkpoint_path
+
+        return None
+
+    def get_parent_job_checkpoint_path(self, parent_job_id: str, checkpoint_name: str) -> Optional[str]:
+        """
+        Get the full path to a checkpoint from a parent job.
+
+        This is a helper method that constructs the path to a specific checkpoint
+        from a parent job's checkpoints directory.
+
+        Args:
+            parent_job_id (str): The ID of the parent job that created the checkpoint
+            checkpoint_name (str): The name of the checkpoint file or directory
+
+        Returns:
+            Optional[str]: The full path to the checkpoint, or None if it doesn't exist
+        """
+        try:
+            checkpoints_dir = dirs.get_job_checkpoints_dir(parent_job_id)
+            checkpoint_path = os.path.join(checkpoints_dir, checkpoint_name)
+
+            # Security check: ensure the checkpoint path is within the checkpoints directory
+            checkpoint_path_normalized = os.path.normpath(checkpoint_path)
+            checkpoints_dir_normalized = os.path.normpath(checkpoints_dir)
+
+            if not checkpoint_path_normalized.startswith(checkpoints_dir_normalized + os.sep):
+                return None
+
+            if os.path.exists(checkpoint_path_normalized):
+                return checkpoint_path_normalized
+
+            return None
+        except Exception as e:
+            print(f"Error getting parent job checkpoint path: {str(e)}")
+            return None
+
     # ------------- completion -------------
     def finish(
         self,
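
The helper's security check is a standard containment guard against path traversal: normalize both the checkpoints directory and the joined candidate path, then require the candidate to still sit under the directory. A self-contained sketch of the same idea, with made-up directory names and POSIX-style paths assumed:

import os

def is_within_directory(base_dir: str, name: str) -> bool:
    """Return True only if base_dir/name still resolves to a path under base_dir."""
    base = os.path.normpath(base_dir)
    candidate = os.path.normpath(os.path.join(base_dir, name))
    return candidate.startswith(base + os.sep)

print(is_within_directory("/jobs/42/checkpoints", "checkpoint-500"))        # True
print(is_within_directory("/jobs/42/checkpoints", "../other/secrets.bin"))  # False: '..' escapes the directory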
@@ -509,8 +580,8 @@ def save_dataset(self, df, dataset_id: str, additional_metadata: Optional[Dict[s
         try:
             if hasattr(df, "to_pandas") and callable(getattr(df, "to_pandas")):
                 df = df.to_pandas()
-        except Exception:
-            pass
+        except Exception as e:
+            print(f"Warning: Failed to convert dataset to pandas DataFrame: {str(e)}")
 
         # Prepare dataset directory
         dataset_id_safe = dataset_id.strip()
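
The to_pandas duck-typing above means save_dataset accepts either a pandas DataFrame or anything exposing a to_pandas() method (for example a Hugging Face datasets.Dataset), and a failed conversion is now logged instead of silently ignored. A minimal sketch of that conversion step in isolation:

import pandas as pd

def as_pandas(df):
    """Convert Dataset-like objects via to_pandas(); pass DataFrames through unchanged."""
    if hasattr(df, "to_pandas") and callable(getattr(df, "to_pandas")):
        try:
            return df.to_pandas()
        except Exception as e:
            print(f"Warning: Failed to convert dataset to pandas DataFrame: {e}")
    return df

print(as_pandas(pd.DataFrame({"text": ["hello"]})).shape)  # (1, 1): already a DataFrame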
@@ -571,16 +642,17 @@ def save_dataset(self, df, dataset_id: str, additional_metadata: Optional[Dict[s
             )
         except Exception as e:
             # Do not fail the save if metadata write fails; log to job data
+            print(f"Warning: Failed to create dataset metadata: {str(e)}")
             try:
                 self._job.update_job_data_field("dataset_metadata_error", str(e)) # type: ignore[union-attr]
-            except Exception:
-                pass
+            except Exception as e2:
+                print(f"Warning: Failed to log dataset metadata error: {str(e2)}")
 
         # Track dataset on the job for provenance
         try:
             self._job.update_job_data_field("dataset_id", dataset_id_safe) # type: ignore[union-attr]
-        except Exception:
-            pass
+        except Exception as e:
+            print(f"Warning: Failed to track dataset in job_data: {str(e)}")
 
         self.log(f"Dataset saved to '{output_path}' and registered as generated dataset '{dataset_id_safe}'")
         return output_path
@@ -624,8 +696,8 @@ def save_checkpoint(self, source_path: str, name: Optional[str] = None) -> str:
             ckpt_list.append(dest)
             self._job.update_job_data_field("checkpoints", ckpt_list)
             self._job.update_job_data_field("latest_checkpoint", dest)
-        except Exception:
-            pass
+        except Exception as e:
+            print(f"Warning: Failed to track checkpoint in job_data: {str(e)}")
 
         return dest
 
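Since save_checkpoint() now records each checkpoint under the job's "checkpoints" and "latest_checkpoint" fields, a natural call site in a training script is a TrainerCallback that forwards every Hugging Face save to the facade. This is an illustrative sketch, not part of this commit; only the save_checkpoint(source_path, name=None) signature shown above is taken from the diff:

import os
from transformers import TrainerCallback

class LabCheckpointCallback(TrainerCallback):
    """Illustrative callback: register each HF checkpoint with the lab facade."""

    def __init__(self, lab):
        self.lab = lab

    def on_save(self, args, state, control, **kwargs):
        # Trainer saves checkpoints as <output_dir>/checkpoint-<global_step>
        ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        if os.path.isdir(ckpt_dir):
            dest = self.lab.save_checkpoint(ckpt_dir, name=os.path.basename(ckpt_dir))
            self.lab.log(f"Registered checkpoint at {dest}")
        return control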