22import logging
33import os
44from collections import defaultdict
5+ from pathlib import Path
56
67import numpy as np
78from flatland .envs .step_utils .states import TrainState
@@ -64,8 +65,7 @@ def run_scenario(self, scenario_id: str, submission_id: str):
6465 self .exec (generate_policy_args_one_malfunction , scenario_id , submission_id , f"{ submission_id } /{ self .test_id } /{ scenario_id } /with_malfunction" )
6566
6667 # no malfunction
67- trajectory_no_malfunction = Trajectory (data_dir = data_dir_no_malfunction , ep_id = scenario_id )
68- trajectory_no_malfunction .load ()
68+ trajectory_no_malfunction = Trajectory .load_existing (data_dir = Path (data_dir_no_malfunction ), ep_id = scenario_id )
6969 num_agents = trajectory_no_malfunction .trains_rewards_dones_infos ["agent_id" ].max () + 1
7070 for _ , r in trajectory_no_malfunction .trains_rewards_dones_infos .iterrows ():
7171 assert r ["info" ]["malfunction" ] == 0
@@ -79,8 +79,7 @@ def run_scenario(self, scenario_id: str, submission_id: str):
7979 num_betroffen1 = np .sum (betroffen1 )
8080 logger .info (f"num_betroffen1 { num_betroffen1 } " )
8181
82- trajectory_with_malfunction = Trajectory (data_dir = data_dir_with_malfunction , ep_id = scenario_id )
83- trajectory_with_malfunction .load ()
82+ trajectory_with_malfunction = Trajectory .load_existing (data_dir = Path (data_dir_with_malfunction ), ep_id = scenario_id )
8483 malfunction_agents = defaultdict (list )
8584 for _ , r in trajectory_with_malfunction .trains_rewards_dones_infos .iterrows ():
8685 if r ["info" ]["malfunction" ] > 0 :
0 commit comments