Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions agents/dqn_agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import argparse
import sys
import time
import os
from stable_baselines3 import DQN
Expand Down Expand Up @@ -105,13 +107,42 @@ def train_dqn(total_timesteps=100_000, save_path=None, continue_training=False):
print(f"Training completed in {elapsed:.1f}s")
return model

def get_args():
"""
arguments parser in main function
python -m agents/*_agent.py -c -v -m -t 1200000
"""
parser = argparse.ArgumentParser()
parser.description='please enter optional parameters: train, visualize, continue, timestamp ...'
# add_argument:
# default: total_timesteps(20_000_000), do_train(true), do_visualize(true), continue_training(false)
parser.add_argument("-t", "--timesteps", help="total_timesteps for training", dest="total_timesteps", type=int, default=50_000_000)
parser.add_argument("-m", "--train", help="train mode", dest="do_train", action="store_false")
parser.add_argument("-v", "--visualize", help="visualize result", dest="do_visualize", action="store_false")
parser.add_argument("-c", "--continue", help="continue training or not", dest="continue_training", action="store_true")

# parser result
args = parser.parse_args()

return args

if __name__ == "__main__":
# Hyperparameters
total_timesteps = 50_000_000
continue_training = False
do_train = True
do_visualize = True
"""
Main function for agent training and visualizing
"""
# args parsing
args = get_args()
total_timesteps = args.total_timesteps
continue_training = args.continue_training
do_train = args.do_train
do_visualize = args.do_visualize

print(f'[INFO] Args:\n\t{total_timesteps=}\n\t{continue_training=}\n\t{do_train=}\n\t{do_visualize=}')
# user confirmation
user_confirm = input("[Note] Are you sure? (y/n) ")
if user_confirm.lower() == "n":
print("[WARNING] Please input the args again. Exiting !")
sys.exit(1)

if do_train:
trained = train_dqn(
Expand Down
44 changes: 39 additions & 5 deletions agents/dqn_masked_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import argparse
import os
import sys
import time
import numpy as np

Expand Down Expand Up @@ -133,12 +135,44 @@ def train_masked_dqn(total_timesteps: int = 500_000, continue_training: bool = F
return model


if __name__ == "__main__":
total_timesteps = 20_000_000
continue_training = False
do_train = True
do_visualize = True
def get_args():
"""
arguments parser in main function
python -m agents/dqn_*_agent.py -c -v -m -t 1200000
"""
parser = argparse.ArgumentParser()
parser.description='please enter optional parameters: train, visualize, continue, timestamp ...'
# add_argument:
# default: total_timesteps(20_000_000), do_train(true), do_visualize(true), continue_training(false)
parser.add_argument("-t", "--timesteps", help="total_timesteps for training", dest="total_timesteps", type=int, default=20_000_000)
parser.add_argument("-m", "--train", help="train mode", dest="do_train", action="store_false")
parser.add_argument("-v", "--visualize", help="visualize result", dest="do_visualize", action="store_false")
parser.add_argument("-c", "--continue", help="continue training or not", dest="continue_training", action="store_true")

# parser result
args = parser.parse_args()

return args

if __name__ == "__main__":
"""
Main function for agent training and visualizing
"""
# args parsing
args = get_args()
total_timesteps = args.total_timesteps
continue_training = args.continue_training
do_train = args.do_train
do_visualize = args.do_visualize

print(f'[INFO] Args:\n\t{total_timesteps=}\n\t{continue_training=}\n\t{do_train=}\n\t{do_visualize=}')
# user confirmation
user_confirm = input("[Note] Are you sure? (y/n) ")
if user_confirm.lower() == "n":
print("[WARNING] Please input the args again. Exiting !")
sys.exit(1)

# start to execute main process
if do_train:
train_masked_dqn(total_timesteps, continue_training)

Expand Down
39 changes: 38 additions & 1 deletion agents/ppo_agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import argparse
import os
import sys
from typing import Callable, Optional
Expand Down Expand Up @@ -95,13 +96,49 @@ def train_ppo(
print(f"[ppo] Training done: {final_path}.zip")
return model

def get_args():
"""
arguments parser in main function
python -m agents/dqn_*_agent.py -c -v -m -t 1200000
"""
parser = argparse.ArgumentParser()
parser.description='please enter optional parameters: train, visualize, continue, timestamp ...'
# add_argument:
# default: total_timesteps(20_000_000), do_train(true), do_visualize(true), continue_training(false)
parser.add_argument("-t", "--timesteps", help="total_timesteps for training", dest="total_timesteps", type=int, default=50_000_000)
parser.add_argument("-m", "--train", help="train mode", dest="do_train", action="store_false")
parser.add_argument("-v", "--visualize", help="visualize result", dest="do_visualize", action="store_false")
parser.add_argument("-c", "--continue", help="continue training or not", dest="continue_training", action="store_true")
parser.add_argument("-e", "--env", help="num of env", dest="num_envs", type=int, default=8)

# parser result
args = parser.parse_args()

return args

if __name__ == "__main__":
"""
Main function for agent training and visualizing
"""
# args parsing
args = get_args()
total_timesteps = args.total_timesteps
continue_training = args.continue_training
do_train = args.do_train
do_visualize = args.do_visualize
num_envs = args.num_envs

print(f'[INFO] Args:\n\t{total_timesteps=}\n\t{continue_training=}\n\t{do_train=}\n\t{do_visualize=}\n\t{num_envs=}')
# user confirmation
user_confirm = input("[Note] Are you sure? (y/n) ")
if user_confirm.lower() == "n":
print("[WARNING] Please input the args again. Exiting !")
sys.exit(1)

# Configuration
num_envs = 8
total_timesteps = 50_000_000
continue_training = False

do_train = True
do_visualize = True

Expand Down
44 changes: 37 additions & 7 deletions agents/ppo_masked_agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import argparse
import os
import sys
from typing import Callable, Optional
Expand Down Expand Up @@ -108,15 +109,44 @@ def train_masked_ppo(
print(f"[masked ppo] Training done: {final_path}.zip")
return model

def get_args():
"""
arguments parser in main function
python -m agents/*_agent.py -c -v -m -t 1200000
"""
parser = argparse.ArgumentParser()
parser.description='please enter optional parameters: train, visualize, continue, timestamp ...'
# add_argument:
# default: total_timesteps(20_000_000), do_train(true), do_visualize(true), continue_training(false)
parser.add_argument("-t", "--timesteps", help="total_timesteps for training", dest="total_timesteps", type=int, default=50_000_000)
parser.add_argument("-m", "--train", help="train mode", dest="do_train", action="store_false")
parser.add_argument("-v", "--visualize", help="visualize result", dest="do_visualize", action="store_false")
parser.add_argument("-c", "--continue", help="continue training or not", dest="continue_training", action="store_true")
parser.add_argument("-e", "--env", help="num of env", dest="num_envs", type=int, default=8)

# parser result
args = parser.parse_args()

return args

if __name__ == "__main__":
# Configuration
num_envs = 8
total_timesteps = 50_000_000
continue_training = False

do_train = True
do_visualize = True
"""
Main function for agent training and visualizing
"""
# args parsing
args = get_args()
total_timesteps = args.total_timesteps
continue_training = args.continue_training
do_train = args.do_train
do_visualize = args.do_visualize
num_envs = args.num_envs

print(f'[INFO] Args:\n\t{total_timesteps=}\n\t{continue_training=}\n\t{do_train=}\n\t{do_visualize=}\n\t{num_envs=}')
# user confirmation
user_confirm = input("[Note] Are you sure? (y/n) ")
if user_confirm.lower() == "n":
print("[WARNING] Please input the args again. Exiting !")
sys.exit(1)

if do_train:
pretrained = os.path.join(MODELS_DIR, "final_masked_ppo_model.zip")
Expand Down