1818
1919import numpy as np
2020from bears import FileMetadata
21- from bears .constants import REMOTE_STORAGES
21+ from bears .constants import REMOTE_STORAGES , Parallelize
2222from bears .core .frame import ScalableDataFrame
2323from bears .util import (
2424 ActorProxy ,
2929 String ,
3030 Timeout ,
3131 Timer ,
32+ accumulate ,
33+ dispatch_executor ,
3234 get_default ,
3335 get_result ,
3436 is_done ,
3537 safe_validate_arguments ,
3638 set_param_from_alias ,
39+ stop_executor ,
3740 wait ,
3841)
3942from bears .util .concurrency ._processes import actor
@@ -149,7 +152,7 @@ class MultiProcessEvaluator(Evaluator):
149152
150153 nested_evaluator_name : Optional [str ] = None
151154 num_models : Optional [conint (ge = 1 )] = None
152- mp_context : Literal ["spawn" , "fork" , "forkserver" ] = "fork "
155+ mp_context : Literal ["spawn" , "fork" , "forkserver" ] = "spawn "
153156 model : Optional [List [Any ]] = None ## Stores the actor proxies
154157 progress_update_frequency : confloat (ge = 0.0 ) = 15.0
155158 ## By default, do not cache the model:
@@ -200,19 +203,29 @@ def _load_model(
200203 unit = "actors" ,
201204 )
202205 nested_evaluator_params : Dict = self ._create_nested_evaluator_params (** kwargs )
203- actors : List [Any ] = []
204206
207+ ## TODO: fix the spawn creation logic to be faster. Currently, it is super slow
208+ ## so we have to use a threadpool to create them.
209+ actor_creation_executor = dispatch_executor (
210+ parallelize = Parallelize .threads ,
211+ max_workers = min (num_actors , 20 ),
212+ )
213+ actors : List [Any ] = []
205214 for actor_i in range (num_actors ):
206215 actors .append (
207- ProcessAlgorithmEvaluator .remote (
216+ actor_creation_executor .submit (
217+ ProcessAlgorithmEvaluator .remote ,
208218 evaluator = nested_evaluator_params ,
209219 actor = (actor_i , num_actors ),
210220 verbosity = self .verbosity ,
211221 mp_context = self .mp_context ,
212222 )
213223 )
214- actors_progress_bar .update (1 )
215224 time .sleep (0.100 )
225+ for actor_i , actor_future in enumerate (actors ):
226+ actors [actor_i ] = actor_future .result ()
227+ actors_progress_bar .update (1 )
228+ stop_executor (actor_creation_executor )
216229 if len (actors ) != num_actors :
217230 msg : str = f"Creation of { num_actors - len (actors )} actors failed"
218231 actors_progress_bar .failed (msg )
@@ -228,13 +241,24 @@ def cleanup_model(self):
228241
def _kill_actors(self):
    """Stop every actor proxy held in ``self.model`` and drop the references.

    ``self.model`` is reset to ``None`` before any actor is stopped, so the
    evaluator no longer points at actors that are being torn down. Each stop
    is submitted to a short-lived thread pool capped at 20 workers (see the
    TODO below for why). A garbage-collection pass always runs on the way
    out, including when stopping an actor raises.
    """

    def _stop_actor(proxy: ActorProxy):
        ## Cancel anything still pending on the actor, then drop this
        ## reference and collect immediately.
        proxy.stop(cancel_futures=True)
        del proxy
        gc.collect()

    try:
        if self.model is not None:
            actors: List[ActorProxy] = self.model
            self.model = None
            ## With no actors, min(len(actors), 20) is 0, i.e. a pool with
            ## zero workers would be requested; skip the pool entirely.
            ## NOTE(review): presumably dispatch_executor forwards max_workers
            ## to a thread pool that rejects 0 -- confirm against bears.util.
            if len(actors) > 0:
                ## TODO: fix the spawn stop logic to be faster. Currently, it is super slow
                ## so we have to use a threadpool to stop them.
                actor_stop_executor = dispatch_executor(
                    parallelize=Parallelize.threads,
                    max_workers=min(len(actors), 20),
                )
                try:
                    ## NOTE(review): accumulate(...) is expected to wait for all
                    ## submitted stops to finish -- confirm against bears.util.
                    accumulate(
                        [
                            actor_stop_executor.submit(_stop_actor, actor_proxy)
                            for actor_proxy in actors
                        ]
                    )
                finally:
                    ## Release the pool even when stopping an actor raised.
                    stop_executor(actor_stop_executor)
            del actors
    finally:
        gc.collect()
0 commit comments