From 119ae47b742517ec6ecdf90dd8d7123ff41ea2c5 Mon Sep 17 00:00:00 2001 From: wangqiwen Date: Sat, 26 Apr 2025 15:42:09 +0800 Subject: [PATCH] feat(command): add args parsing for all RL agent --- agents/dqn_agent.py | 41 ++++++++++++++++++++++++++++++----- agents/dqn_masked_agent.py | 44 +++++++++++++++++++++++++++++++++----- agents/ppo_agent.py | 39 ++++++++++++++++++++++++++++++++- agents/ppo_masked_agent.py | 44 ++++++++++++++++++++++++++++++++------ 4 files changed, 150 insertions(+), 18 deletions(-) diff --git a/agents/dqn_agent.py b/agents/dqn_agent.py index e7bc228..e9a16c2 100644 --- a/agents/dqn_agent.py +++ b/agents/dqn_agent.py @@ -1,3 +1,5 @@ +import argparse +import sys import time import os from stable_baselines3 import DQN @@ -105,13 +107,42 @@ def train_dqn(total_timesteps=100_000, save_path=None, continue_training=False): print(f"Training completed in {elapsed:.1f}s") return model +def get_args(): + """ + arguments parser in main function + python -m agents/*_agent.py -c -v -m -t 1200000 + """ + parser = argparse.ArgumentParser() + parser.description='please enter optional parameters: train, visualize, continue, timestamp ...' + # add_argument: + # default: total_timesteps(20_000_000), do_train(true), do_visualize(true), continue_training(false) + parser.add_argument("-t", "--timesteps", help="total_timesteps for training", dest="total_timesteps", type=int, default=50_000_000) + parser.add_argument("-m", "--train", help="train mode", dest="do_train", action="store_false") + parser.add_argument("-v", "--visualize", help="visualize result", dest="do_visualize", action="store_false") + parser.add_argument("-c", "--continue", help="continue training or not", dest="continue_training", action="store_true") + + # parser result + args = parser.parse_args() + + return args if __name__ == "__main__": - # Hyperparameters - total_timesteps = 50_000_000 - continue_training = False - do_train = True - do_visualize = True + """ + Main function for agent training and visualizing + """ + # args parsing + args = get_args() + total_timesteps = args.total_timesteps + continue_training = args.continue_training + do_train = args.do_train + do_visualize = args.do_visualize + + print(f'[INFO] Args:\n\t{total_timesteps=}\n\t{continue_training=}\n\t{do_train=}\n\t{do_visualize=}') + # user confirmation + user_confirm = input("[Note] Are you sure? (y/n) ") + if user_confirm.lower() == "n": + print("[WARNING] Please input the args again. Exiting !") + sys.exit(1) if do_train: trained = train_dqn( diff --git a/agents/dqn_masked_agent.py b/agents/dqn_masked_agent.py index 8b88084..d74d7c5 100644 --- a/agents/dqn_masked_agent.py +++ b/agents/dqn_masked_agent.py @@ -1,4 +1,6 @@ +import argparse import os +import sys import time import numpy as np @@ -133,12 +135,44 @@ def train_masked_dqn(total_timesteps: int = 500_000, continue_training: bool = F return model -if __name__ == "__main__": - total_timesteps = 20_000_000 - continue_training = False - do_train = True - do_visualize = True +def get_args(): + """ + arguments parser in main function + python -m agents/dqn_*_agent.py -c -v -m -t 1200000 + """ + parser = argparse.ArgumentParser() + parser.description='please enter optional parameters: train, visualize, continue, timestamp ...' + # add_argument: + # default: total_timesteps(20_000_000), do_train(true), do_visualize(true), continue_training(false) + parser.add_argument("-t", "--timesteps", help="total_timesteps for training", dest="total_timesteps", type=int, default=20_000_000) + parser.add_argument("-m", "--train", help="train mode", dest="do_train", action="store_false") + parser.add_argument("-v", "--visualize", help="visualize result", dest="do_visualize", action="store_false") + parser.add_argument("-c", "--continue", help="continue training or not", dest="continue_training", action="store_true") + + # parser result + args = parser.parse_args() + return args + +if __name__ == "__main__": + """ + Main function for agent training and visualizing + """ + # args parsing + args = get_args() + total_timesteps = args.total_timesteps + continue_training = args.continue_training + do_train = args.do_train + do_visualize = args.do_visualize + + print(f'[INFO] Args:\n\t{total_timesteps=}\n\t{continue_training=}\n\t{do_train=}\n\t{do_visualize=}') + # user confirmation + user_confirm = input("[Note] Are you sure? (y/n) ") + if user_confirm.lower() == "n": + print("[WARNING] Please input the args again. Exiting !") + sys.exit(1) + + # start to execute main process if do_train: train_masked_dqn(total_timesteps, continue_training) diff --git a/agents/ppo_agent.py b/agents/ppo_agent.py index f88fa76..98dfeed 100644 --- a/agents/ppo_agent.py +++ b/agents/ppo_agent.py @@ -1,3 +1,4 @@ +import argparse import os import sys from typing import Callable, Optional @@ -95,13 +96,49 @@ def train_ppo( print(f"[ppo] Training done: {final_path}.zip") return model +def get_args(): + """ + arguments parser in main function + python -m agents/dqn_*_agent.py -c -v -m -t 1200000 + """ + parser = argparse.ArgumentParser() + parser.description='please enter optional parameters: train, visualize, continue, timestamp ...' + # add_argument: + # default: total_timesteps(20_000_000), do_train(true), do_visualize(true), continue_training(false) + parser.add_argument("-t", "--timesteps", help="total_timesteps for training", dest="total_timesteps", type=int, default=50_000_000) + parser.add_argument("-m", "--train", help="train mode", dest="do_train", action="store_false") + parser.add_argument("-v", "--visualize", help="visualize result", dest="do_visualize", action="store_false") + parser.add_argument("-c", "--continue", help="continue training or not", dest="continue_training", action="store_true") + parser.add_argument("-e", "--env", help="num of env", dest="num_envs", type=int, default=8) + + # parser result + args = parser.parse_args() + + return args if __name__ == "__main__": + """ + Main function for agent training and visualizing + """ + # args parsing + args = get_args() + total_timesteps = args.total_timesteps + continue_training = args.continue_training + do_train = args.do_train + do_visualize = args.do_visualize + num_envs = args.num_envs + + print(f'[INFO] Args:\n\t{total_timesteps=}\n\t{continue_training=}\n\t{do_train=}\n\t{do_visualize=}\n\t{num_envs=}') + # user confirmation + user_confirm = input("[Note] Are you sure? (y/n) ") + if user_confirm.lower() == "n": + print("[WARNING] Please input the args again. Exiting !") + sys.exit(1) + # Configuration num_envs = 8 total_timesteps = 50_000_000 continue_training = False - do_train = True do_visualize = True diff --git a/agents/ppo_masked_agent.py b/agents/ppo_masked_agent.py index 4b442c7..0f0268b 100644 --- a/agents/ppo_masked_agent.py +++ b/agents/ppo_masked_agent.py @@ -1,3 +1,4 @@ +import argparse import os import sys from typing import Callable, Optional @@ -108,15 +109,44 @@ def train_masked_ppo( print(f"[masked ppo] Training done: {final_path}.zip") return model +def get_args(): + """ + arguments parser in main function + python -m agents/*_agent.py -c -v -m -t 1200000 + """ + parser = argparse.ArgumentParser() + parser.description='please enter optional parameters: train, visualize, continue, timestamp ...' + # add_argument: + # default: total_timesteps(20_000_000), do_train(true), do_visualize(true), continue_training(false) + parser.add_argument("-t", "--timesteps", help="total_timesteps for training", dest="total_timesteps", type=int, default=50_000_000) + parser.add_argument("-m", "--train", help="train mode", dest="do_train", action="store_false") + parser.add_argument("-v", "--visualize", help="visualize result", dest="do_visualize", action="store_false") + parser.add_argument("-c", "--continue", help="continue training or not", dest="continue_training", action="store_true") + parser.add_argument("-e", "--env", help="num of env", dest="num_envs", type=int, default=8) + + # parser result + args = parser.parse_args() + + return args if __name__ == "__main__": - # Configuration - num_envs = 8 - total_timesteps = 50_000_000 - continue_training = False - - do_train = True - do_visualize = True + """ + Main function for agent training and visualizing + """ + # args parsing + args = get_args() + total_timesteps = args.total_timesteps + continue_training = args.continue_training + do_train = args.do_train + do_visualize = args.do_visualize + num_envs = args.num_envs + + print(f'[INFO] Args:\n\t{total_timesteps=}\n\t{continue_training=}\n\t{do_train=}\n\t{do_visualize=}\n\t{num_envs=}') + # user confirmation + user_confirm = input("[Note] Are you sure? (y/n) ") + if user_confirm.lower() == "n": + print("[WARNING] Please input the args again. Exiting !") + sys.exit(1) if do_train: pretrained = os.path.join(MODELS_DIR, "final_masked_ppo_model.zip")