Skip to content

Commit 9733de3

Browse files
committed
pre-commit
1 parent 6cbb430 commit 9733de3

File tree

1 file changed

+30
-13
lines changed

1 file changed

+30
-13
lines changed

pomdp_py/problems/rocksample/rocksample_problem.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,14 @@ def __init__(self, position, rocktypes, terminal=False, removed_rocks=None):
9292
self.removed_rocks = removed_rocks
9393

9494
def __hash__(self):
95-
return hash((self.position, self.rocktypes, self.terminal, tuple(sorted(self.removed_rocks))))
95+
return hash(
96+
(
97+
self.position,
98+
self.rocktypes,
99+
self.terminal,
100+
tuple(sorted(self.removed_rocks)),
101+
)
102+
)
96103

97104
def __eq__(self, other):
98105
if isinstance(other, State):
@@ -118,7 +125,7 @@ def __repr__(self):
118125
return "State(%s | %s | %s)" % (
119126
str(self.position),
120127
str(rocks_status),
121-
str(self.terminal)
128+
str(self.terminal),
122129
)
123130

124131

@@ -339,7 +346,9 @@ def sample(self, state, action, next_state, normalized=False, **kwargs):
339346
# Bad rock
340347
return -10
341348
else:
342-
return -100 # Large penalty for sampling at non-rock position (defensive programming)
349+
return (
350+
-100
351+
) # Large penalty for sampling at non-rock position (defensive programming)
343352

344353
elif isinstance(action, MoveAction):
345354
if self._in_exit_area(next_state.position):
@@ -578,7 +587,7 @@ def calculate_std(values):
578587
return 0.0
579588
mean = sum(values) / len(values)
580589
variance = sum((x - mean) ** 2 for x in values) / (len(values) - 1)
581-
return variance ** 0.5
590+
return variance**0.5
582591

583592

584593
def create_instance(n, k, **kwargs):
@@ -623,11 +632,13 @@ def benchmark(verbose=False):
623632
exploration_const=exploration_const,
624633
rollout_policy=rocksample.agent.policy_model,
625634
num_visits_init=1,
626-
show_progress=verbose
635+
show_progress=verbose,
627636
)
628637

629638
# Run the test planner
630-
tt, ttd = test_planner(rocksample, pomcp, nsteps=200, discount=0.95, verbose=verbose)
639+
tt, ttd = test_planner(
640+
rocksample, pomcp, nsteps=200, discount=0.95, verbose=verbose
641+
)
631642

632643
total_rewards.append(tt)
633644
total_discounted_rewards.append(ttd)
@@ -636,20 +647,26 @@ def benchmark(verbose=False):
636647

637648
# Calculate averages
638649
avg_total_reward = sum(total_rewards) / len(total_rewards)
639-
avg_discounted_reward = sum(total_discounted_rewards) / len(total_discounted_rewards)
650+
avg_discounted_reward = sum(total_discounted_rewards) / len(
651+
total_discounted_rewards
652+
)
640653

641-
print("\n" + "="*50)
654+
print("\n" + "=" * 50)
642655
print(f"FINAL RESULTS ({k_runs} runs)")
643-
print("="*50)
656+
print("=" * 50)
644657
print(f"Average total reward: {avg_total_reward:.3f}")
645658
print(f"Average discounted reward: {avg_discounted_reward:.3f}")
646659
print(f"Standard deviation of total reward: {calculate_std(total_rewards):.3f}")
647-
print(f"Standard deviation of discounted reward: {calculate_std(total_discounted_rewards):.3f}")
660+
print(
661+
"Standard deviation of discounted reward:"
662+
f" {calculate_std(total_discounted_rewards):.3f}"
663+
)
648664
print(f"Min total reward: {min(total_rewards)}")
649665
print(f"Max total reward: {max(total_rewards)}")
650666
print(f"Min discounted reward: {min(total_discounted_rewards):.3f}")
651667
print(f"Max discounted reward: {max(total_discounted_rewards):.3f}")
652-
print("="*50)
668+
print("=" * 50)
669+
653670

654671
def main(argv=None):
655672
parser = argparse.ArgumentParser(description="RockSample Problem Runner")
@@ -661,7 +678,7 @@ def main(argv=None):
661678
parser.add_argument(
662679
"--verbose",
663680
action="store_true",
664-
help="Enable verbose output during the benchmark."
681+
help="Enable verbose output during the benchmark.",
665682
)
666683
args = parser.parse_args(argv)
667684

@@ -677,7 +694,7 @@ def main(argv=None):
677694
exploration_const=10,
678695
rollout_policy=rocksample.agent.policy_model,
679696
num_visits_init=1,
680-
show_progress=True
697+
show_progress=True,
681698
)
682699
test_planner(rocksample, pomcp, nsteps=200, discount=0.95, verbose=True)
683700

0 commit comments

Comments
 (0)