Skip to content

Commit 0479b7e

Browse files
Fix for LocalTrainer.
1 parent 9699a88 commit 0479b7e

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

src/fmcore/framework/_trainer/LocalTrainer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from bears import FileMetadata
1212
from bears.util import String, Timer, safe_validate_arguments
13+
from pydantic import model_validator
1314

1415
from fmcore.framework._algorithm import Algorithm
1516
from fmcore.framework._dataset import Dataset, Datasets, DataSplit
@@ -24,6 +25,12 @@ class LocalTrainer(Trainer):
2425
def initialize(self, **kwargs):
2526
pass
2627

28+
@model_validator(mode="before")
29+
@classmethod
30+
def local_trainer_params(cls, params: Dict) -> Dict:
31+
params: Dict = cls._set_common_trainer_params(params)
32+
return params
33+
2734
@staticmethod
2835
def local_logger(text: str, verbosity: int, tracker: Tracker):
2936
pid: int = os.getpid()

0 commit comments

Comments
 (0)