33import time
44import warnings
55from functools import partial
6- from typing import Any , ClassVar , Optional , TypeVar , Union
6+ from typing import Any , ClassVar , TypeVar
77
88import numpy as np
99import torch as th
@@ -57,21 +57,21 @@ class ARS(BaseAlgorithm):
5757
5858 def __init__ (
5959 self ,
60- policy : Union [ str , type [ARSPolicy ] ],
61- env : Union [ GymEnv , str ] ,
60+ policy : str | type [ARSPolicy ],
61+ env : GymEnv | str ,
6262 n_delta : int = 8 ,
63- n_top : Optional [ int ] = None ,
64- learning_rate : Union [ float , Schedule ] = 0.02 ,
65- delta_std : Union [ float , Schedule ] = 0.05 ,
63+ n_top : int | None = None ,
64+ learning_rate : float | Schedule = 0.02 ,
65+ delta_std : float | Schedule = 0.05 ,
6666 zero_policy : bool = True ,
6767 alive_bonus_offset : float = 0 ,
6868 n_eval_episodes : int = 1 ,
69- policy_kwargs : Optional [ dict [str , Any ]] = None ,
69+ policy_kwargs : dict [str , Any ] | None = None ,
7070 stats_window_size : int = 100 ,
71- tensorboard_log : Optional [ str ] = None ,
72- seed : Optional [ int ] = None ,
71+ tensorboard_log : str | None = None ,
72+ seed : int | None = None ,
7373 verbose : int = 0 ,
74- device : Union [ th .device , str ] = "cpu" ,
74+ device : th .device | str = "cpu" ,
7575 _init_setup_model : bool = True ,
7676 ):
7777 super ().__init__ (
@@ -137,7 +137,7 @@ def _mimic_monitor_wrapper(self, episode_rewards: np.ndarray, episode_lengths: n
137137 # Mimic Monitor Wrapper
138138 infos = [
139139 {"episode" : {"r" : episode_reward , "l" : episode_length }}
140- for episode_reward , episode_length in zip (episode_rewards , episode_lengths )
140+ for episode_reward , episode_length in zip (episode_rewards , episode_lengths , strict = True )
141141 ]
142142
143143 self ._update_info_buffer (infos )
@@ -163,7 +163,7 @@ def _trigger_callback(
163163 callback .on_step ()
164164
165165 def evaluate_candidates (
166- self , candidate_weights : th .Tensor , callback : BaseCallback , async_eval : Optional [ AsyncEval ]
166+ self , candidate_weights : th .Tensor , callback : BaseCallback , async_eval : AsyncEval | None
167167 ) -> th .Tensor :
168168 """
169169 Evaluate each candidate.
@@ -257,7 +257,7 @@ def dump_logs(self) -> None:
257257 self .logger .record ("time/total_timesteps" , self .num_timesteps , exclude = "tensorboard" )
258258 self .logger .dump (step = self .num_timesteps )
259259
260- def _do_one_update (self , callback : BaseCallback , async_eval : Optional [ AsyncEval ] ) -> None :
260+ def _do_one_update (self , callback : BaseCallback , async_eval : AsyncEval | None ) -> None :
261261 """
262262 Sample new candidates, evaluate them and then update current policy.
263263
@@ -312,7 +312,7 @@ def learn(
312312 log_interval : int = 1 ,
313313 tb_log_name : str = "ARS" ,
314314 reset_num_timesteps : bool = True ,
315- async_eval : Optional [ AsyncEval ] = None ,
315+ async_eval : AsyncEval | None = None ,
316316 progress_bar : bool = False ,
317317 ) -> SelfARS :
318318 """
@@ -353,9 +353,9 @@ def learn(
353353
354354 def set_parameters (
355355 self ,
356- load_path_or_dict : Union [ str , dict [str , dict ] ],
356+ load_path_or_dict : str | dict [str , dict ],
357357 exact_match : bool = True ,
358- device : Union [ th .device , str ] = "auto" ,
358+ device : th .device | str = "auto" ,
359359 ) -> None :
360360 # Patched set_parameters() to handle ARS linear policy saved with sb3-contrib < 1.7.0
361361 params = None
0 commit comments