@@ -28,7 +28,7 @@ def __init__(self) -> None:
2828 self ._job : Optional [Job ] = None
2929
3030 # ------------- lifecycle -------------
31- def init (self , experiment_id : str = "alpha" ) -> None :
31+ def init (self , experiment_id : str = "alpha" , config : Optional [ Dict [ str , Any ]] = None ) -> None :
3232 """
3333 Initialize a job under the given experiment.
3434 If _TFL_JOB_ID environment variable is set, uses that existing job.
@@ -59,6 +59,10 @@ def init(self, experiment_id: str = "alpha") -> None:
5959 # Check for wandb integration and capture URL if available
6060 self ._detect_and_capture_wandb_url ()
6161
62+ # Set config if provided
63+ if config is not None :
64+ self .set_config (config )
65+
6266 def set_config (self , config : Dict [str , Any ]) -> None :
6367 """
6468 Attach configuration to the current job.
@@ -88,6 +92,73 @@ def update_progress(self, progress: int) -> None:
8892 # Check for wandb URL on every progress update
8993 self ._check_and_capture_wandb_url ()
9094
95+ # ------------- checkpoint resume support -------------
96+ def get_checkpoint_to_resume (self ) -> Optional [str ]:
97+ """
98+ Get the checkpoint path to resume training from.
99+
100+ This method checks for checkpoint resume information stored in the job data
101+ when resuming training from a checkpoint.
102+
103+ Returns:
104+ Optional[str]: The full path to the checkpoint to resume from, or None if no
105+ checkpoint resume is requested.
106+ """
107+ if not self ._job :
108+ return None
109+
110+ job_data = self ._job .get_job_data ()
111+ if not job_data :
112+ return None
113+
114+ parent_job_id = job_data .get ('parent_job_id' )
115+ checkpoint_name = job_data .get ('resumed_from_checkpoint' )
116+
117+ if not parent_job_id or not checkpoint_name :
118+ return None
119+
120+ # Build the checkpoint path from parent job's checkpoints directory
121+ checkpoint_path = self .get_parent_job_checkpoint_path (parent_job_id , checkpoint_name )
122+
123+ # Verify the checkpoint exists
124+ if checkpoint_path and os .path .exists (checkpoint_path ):
125+ return checkpoint_path
126+
127+ return None
128+
129+ def get_parent_job_checkpoint_path (self , parent_job_id : str , checkpoint_name : str ) -> Optional [str ]:
130+ """
131+ Get the full path to a checkpoint from a parent job.
132+
133+ This is a helper method that constructs the path to a specific checkpoint
134+ from a parent job's checkpoints directory.
135+
136+ Args:
137+ parent_job_id (str): The ID of the parent job that created the checkpoint
138+ checkpoint_name (str): The name of the checkpoint file or directory
139+
140+ Returns:
141+ Optional[str]: The full path to the checkpoint, or None if it doesn't exist
142+ """
143+ try :
144+ checkpoints_dir = dirs .get_job_checkpoints_dir (parent_job_id )
145+ checkpoint_path = os .path .join (checkpoints_dir , checkpoint_name )
146+
147+ # Security check: ensure the checkpoint path is within the checkpoints directory
148+ checkpoint_path_normalized = os .path .normpath (checkpoint_path )
149+ checkpoints_dir_normalized = os .path .normpath (checkpoints_dir )
150+
151+ if not checkpoint_path_normalized .startswith (checkpoints_dir_normalized + os .sep ):
152+ return None
153+
154+ if os .path .exists (checkpoint_path_normalized ):
155+ return checkpoint_path_normalized
156+
157+ return None
158+ except Exception as e :
159+ print (f"Error getting parent job checkpoint path: { str (e )} " )
160+ return None
161+
91162 # ------------- completion -------------
92163 def finish (
93164 self ,
@@ -509,8 +580,8 @@ def save_dataset(self, df, dataset_id: str, additional_metadata: Optional[Dict[s
509580 try :
510581 if hasattr (df , "to_pandas" ) and callable (getattr (df , "to_pandas" )):
511582 df = df .to_pandas ()
512- except Exception :
513- pass
583+ except Exception as e :
584+ print ( f"Warning: Failed to convert dataset to pandas DataFrame: { str ( e ) } " )
514585
515586 # Prepare dataset directory
516587 dataset_id_safe = dataset_id .strip ()
@@ -571,16 +642,17 @@ def save_dataset(self, df, dataset_id: str, additional_metadata: Optional[Dict[s
571642 )
572643 except Exception as e :
573644 # Do not fail the save if metadata write fails; log to job data
645+ print (f"Warning: Failed to create dataset metadata: { str (e )} " )
574646 try :
575647 self ._job .update_job_data_field ("dataset_metadata_error" , str (e )) # type: ignore[union-attr]
576- except Exception :
577- pass
648+ except Exception as e2 :
649+ print ( f"Warning: Failed to log dataset metadata error: { str ( e2 ) } " )
578650
579651 # Track dataset on the job for provenance
580652 try :
581653 self ._job .update_job_data_field ("dataset_id" , dataset_id_safe ) # type: ignore[union-attr]
582- except Exception :
583- pass
654+ except Exception as e :
655+ print ( f"Warning: Failed to track dataset in job_data: { str ( e ) } " )
584656
585657 self .log (f"Dataset saved to '{ output_path } ' and registered as generated dataset '{ dataset_id_safe } '" )
586658 return output_path
@@ -624,8 +696,8 @@ def save_checkpoint(self, source_path: str, name: Optional[str] = None) -> str:
624696 ckpt_list .append (dest )
625697 self ._job .update_job_data_field ("checkpoints" , ckpt_list )
626698 self ._job .update_job_data_field ("latest_checkpoint" , dest )
627- except Exception :
628- pass
699+ except Exception as e :
700+ print ( f"Warning: Failed to track checkpoint in job_data: { str ( e ) } " )
629701
630702 return dest
631703
0 commit comments