@@ -92,7 +92,14 @@ def __init__(self, position, rocktypes, terminal=False, removed_rocks=None):
9292 self .removed_rocks = removed_rocks
9393
9494 def __hash__ (self ):
95- return hash ((self .position , self .rocktypes , self .terminal , tuple (sorted (self .removed_rocks ))))
95+ return hash (
96+ (
97+ self .position ,
98+ self .rocktypes ,
99+ self .terminal ,
100+ tuple (sorted (self .removed_rocks )),
101+ )
102+ )
96103
97104 def __eq__ (self , other ):
98105 if isinstance (other , State ):
@@ -118,7 +125,7 @@ def __repr__(self):
118125 return "State(%s | %s | %s)" % (
119126 str (self .position ),
120127 str (rocks_status ),
121- str (self .terminal )
128+ str (self .terminal ),
122129 )
123130
124131
@@ -339,7 +346,9 @@ def sample(self, state, action, next_state, normalized=False, **kwargs):
339346 # Bad rock
340347 return - 10
341348 else :
342- return - 100 # Large penalty for sampling at non-rock position (defensive programming)
349+ return (
350+ - 100
351+ ) # Large penalty for sampling at non-rock position (defensive programming)
343352
344353 elif isinstance (action , MoveAction ):
345354 if self ._in_exit_area (next_state .position ):
@@ -578,7 +587,7 @@ def calculate_std(values):
578587 return 0.0
579588 mean = sum (values ) / len (values )
580589 variance = sum ((x - mean ) ** 2 for x in values ) / (len (values ) - 1 )
581- return variance ** 0.5
590+ return variance ** 0.5
582591
583592
584593def create_instance (n , k , ** kwargs ):
@@ -623,11 +632,13 @@ def benchmark(verbose=False):
623632 exploration_const = exploration_const ,
624633 rollout_policy = rocksample .agent .policy_model ,
625634 num_visits_init = 1 ,
626- show_progress = verbose
635+ show_progress = verbose ,
627636 )
628637
629638 # Run the test planner
630- tt , ttd = test_planner (rocksample , pomcp , nsteps = 200 , discount = 0.95 , verbose = verbose )
639+ tt , ttd = test_planner (
640+ rocksample , pomcp , nsteps = 200 , discount = 0.95 , verbose = verbose
641+ )
631642
632643 total_rewards .append (tt )
633644 total_discounted_rewards .append (ttd )
@@ -636,20 +647,26 @@ def benchmark(verbose=False):
636647
637648 # Calculate averages
638649 avg_total_reward = sum (total_rewards ) / len (total_rewards )
639- avg_discounted_reward = sum (total_discounted_rewards ) / len (total_discounted_rewards )
650+ avg_discounted_reward = sum (total_discounted_rewards ) / len (
651+ total_discounted_rewards
652+ )
640653
641- print ("\n " + "=" * 50 )
654+ print ("\n " + "=" * 50 )
642655 print (f"FINAL RESULTS ({ k_runs } runs)" )
643- print ("=" * 50 )
656+ print ("=" * 50 )
644657 print (f"Average total reward: { avg_total_reward :.3f} " )
645658 print (f"Average discounted reward: { avg_discounted_reward :.3f} " )
646659 print (f"Standard deviation of total reward: { calculate_std (total_rewards ):.3f} " )
647- print (f"Standard deviation of discounted reward: { calculate_std (total_discounted_rewards ):.3f} " )
660+ print (
661+ "Standard deviation of discounted reward:"
662+ f" { calculate_std (total_discounted_rewards ):.3f} "
663+ )
648664 print (f"Min total reward: { min (total_rewards )} " )
649665 print (f"Max total reward: { max (total_rewards )} " )
650666 print (f"Min discounted reward: { min (total_discounted_rewards ):.3f} " )
651667 print (f"Max discounted reward: { max (total_discounted_rewards ):.3f} " )
652- print ("=" * 50 )
668+ print ("=" * 50 )
669+
653670
654671def main (argv = None ):
655672 parser = argparse .ArgumentParser (description = "RockSample Problem Runner" )
@@ -661,7 +678,7 @@ def main(argv=None):
661678 parser .add_argument (
662679 "--verbose" ,
663680 action = "store_true" ,
664- help = "Enable verbose output during the benchmark."
681+ help = "Enable verbose output during the benchmark." ,
665682 )
666683 args = parser .parse_args (argv )
667684
@@ -677,7 +694,7 @@ def main(argv=None):
677694 exploration_const = 10 ,
678695 rollout_policy = rocksample .agent .policy_model ,
679696 num_visits_init = 1 ,
680- show_progress = True
697+ show_progress = True ,
681698 )
682699 test_planner (rocksample , pomcp , nsteps = 200 , discount = 0.95 , verbose = True )
683700
0 commit comments