From ad321d989520fc738c1fb772d43ca2288483e803 Mon Sep 17 00:00:00 2001
From: RealTapeL <167164141+RealTapeL@users.noreply.github.com>
Date: Wed, 18 Jun 2025 15:18:17 +0800
Subject: [PATCH] Update run.py

---
 run.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 50 insertions(+)

diff --git a/run.py b/run.py
index 793a53ac2..f62440cf3 100644
--- a/run.py
+++ b/run.py
@@ -11,7 +11,57 @@
 import random
 import numpy as np
+import logging
+import sys
+from datetime import datetime
 
+# Log to a timestamped file under log/ as well as to the console
+log_folder = "log"
+
+if not os.path.exists(log_folder):
+    os.makedirs(log_folder)
+
+log_file = os.path.join(log_folder, f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
+
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.FileHandler(log_file),
+        logging.StreamHandler()
+    ]
+)
+
+logger = logging.getLogger(__name__)
+
+class LoggerWriter:
+    def __init__(self, logger, level):
+        self.logger = logger
+        self.level = level
+
+    def write(self, message):
+        if message.rstrip() != "":
+            self.logger.log(self.level, message.rstrip())
+
+    def flush(self):
+        pass
+
+
+# Redirect stdout/stderr so print output and tracebacks are captured by the logger
+sys.stdout = LoggerWriter(logger, logging.INFO)
+sys.stderr = LoggerWriter(logger, logging.ERROR)
+
+
+def train_model():
+    try:
+        logger.info("Starting model training...")
+        for epoch in range(10):
+            logger.info(f"Epoch {epoch + 1}/10: Training...")
+        logger.info("Model training completed successfully.")
+    except Exception as e:
+        logger.error("An error occurred during training:", exc_info=True)
+
 if __name__ == '__main__':
+    train_model()
     fix_seed = 2021
     random.seed(fix_seed)
     torch.manual_seed(fix_seed)
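
A minimal standalone sketch of the stdout-redirection technique the patch relies on, handy for checking the LoggerWriter behaviour outside run.py; the logger name "sketch" and the printed message are illustrative only, not part of the patch:

    import logging
    import sys

    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    logger = logging.getLogger("sketch")

    class LoggerWriter:
        # File-like wrapper that forwards each non-empty write to a logger
        def __init__(self, logger, level):
            self.logger = logger
            self.level = level

        def write(self, message):
            if message.rstrip() != "":
                self.logger.log(self.level, message.rstrip())

        def flush(self):
            pass

    sys.stdout = LoggerWriter(logger, logging.INFO)  # print() now goes through the logger
    print("hello from print")                        # emitted as an INFO record, not raw stream output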