44from contextlib import suppress
55from functools import partial
66from operator import itemgetter
7+ from typing import Any , Callable , Dict , List , Set , Tuple , Union
78
89import numpy as np
910
1213from adaptive .utils import cache_latest , named_product , restore
1314
1415
15- def dispatch (child_functions , arg ) :
16+ def dispatch (child_functions : List [ Callable ] , arg : Any ) -> Union [ Any ] :
1617 index , x = arg
1718 return child_functions [index ](x )
1819
@@ -68,7 +69,9 @@ class BalancingLearner(BaseLearner):
6869 behave in an undefined way. Change the `strategy` in that case.
6970 """
7071
71- def __init__ (self , learners , * , cdims = None , strategy = "loss_improvements" ):
72+ def __init__ (
73+ self , learners : List [BaseLearner ], * , cdims = None , strategy = "loss_improvements"
74+ ) -> None :
7275 self .learners = learners
7376
7477 # Naively we would make 'function' a method, but this causes problems
@@ -89,21 +92,21 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
8992 self .strategy = strategy
9093
9194 @property
92- def data (self ):
95+ def data (self ) -> Dict [ Tuple [ int , Any ], Any ] :
9396 data = {}
9497 for i , l in enumerate (self .learners ):
9598 data .update ({(i , p ): v for p , v in l .data .items ()})
9699 return data
97100
98101 @property
99- def pending_points (self ):
102+ def pending_points (self ) -> Set [ Tuple [ int , Any ]] :
100103 pending_points = set ()
101104 for i , l in enumerate (self .learners ):
102105 pending_points .update ({(i , p ) for p in l .pending_points })
103106 return pending_points
104107
105108 @property
106- def npoints (self ):
109+ def npoints (self ) -> int :
107110 return sum (l .npoints for l in self .learners )
108111
109112 @property
@@ -135,7 +138,9 @@ def strategy(self, strategy):
135138 ' strategy="npoints", or strategy="cycle" is implemented.'
136139 )
137140
138- def _ask_and_tell_based_on_loss_improvements (self , n ):
141+ def _ask_and_tell_based_on_loss_improvements (
142+ self , n : int
143+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
139144 selected = [] # tuples ((learner_index, point), loss_improvement)
140145 total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
141146 for _ in range (n ):
@@ -158,7 +163,9 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
158163 points , loss_improvements = map (list , zip (* selected ))
159164 return points , loss_improvements
160165
161- def _ask_and_tell_based_on_loss (self , n ):
166+ def _ask_and_tell_based_on_loss (
167+ self , n : int
168+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
162169 selected = [] # tuples ((learner_index, point), loss_improvement)
163170 total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
164171 for _ in range (n ):
@@ -179,7 +186,9 @@ def _ask_and_tell_based_on_loss(self, n):
179186 points , loss_improvements = map (list , zip (* selected ))
180187 return points , loss_improvements
181188
182- def _ask_and_tell_based_on_npoints (self , n ):
189+ def _ask_and_tell_based_on_npoints (
190+ self , n : int
191+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
183192 selected = [] # tuples ((learner_index, point), loss_improvement)
184193 total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
185194 for _ in range (n ):
@@ -195,7 +204,9 @@ def _ask_and_tell_based_on_npoints(self, n):
195204 points , loss_improvements = map (list , zip (* selected ))
196205 return points , loss_improvements
197206
198- def _ask_and_tell_based_on_cycle (self , n ):
207+ def _ask_and_tell_based_on_cycle (
208+ self , n : int
209+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
199210 points , loss_improvements = [], []
200211 for _ in range (n ):
201212 index = next (self ._cycle )
@@ -206,7 +217,9 @@ def _ask_and_tell_based_on_cycle(self, n):
206217
207218 return points , loss_improvements
208219
209- def ask (self , n , tell_pending = True ):
220+ def ask (
221+ self , n : int , tell_pending : bool = True
222+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
210223 """Chose points for learners."""
211224 if n == 0 :
212225 return [], []
@@ -217,20 +230,20 @@ def ask(self, n, tell_pending=True):
217230 else :
218231 return self ._ask_and_tell (n )
219232
220- def tell (self , x , y ) :
233+ def tell (self , x : Tuple [ int , Any ], y : Any ) -> None :
221234 index , x = x
222235 self ._ask_cache .pop (index , None )
223236 self ._loss .pop (index , None )
224237 self ._pending_loss .pop (index , None )
225238 self .learners [index ].tell (x , y )
226239
227- def tell_pending (self , x ) :
240+ def tell_pending (self , x : Tuple [ int , Any ]) -> None :
228241 index , x = x
229242 self ._ask_cache .pop (index , None )
230243 self ._loss .pop (index , None )
231244 self .learners [index ].tell_pending (x )
232245
233- def _losses (self , real = True ):
246+ def _losses (self , real : bool = True ) -> List [ float ] :
234247 losses = []
235248 loss_dict = self ._loss if real else self ._pending_loss
236249
@@ -242,7 +255,7 @@ def _losses(self, real=True):
242255 return losses
243256
244257 @cache_latest
245- def loss (self , real = True ):
258+ def loss (self , real : bool = True ) -> Union [ float ] :
246259 losses = self ._losses (real )
247260 return max (losses )
248261
@@ -325,7 +338,9 @@ def remove_unfinished(self):
325338 learner .remove_unfinished ()
326339
327340 @classmethod
328- def from_product (cls , f , learner_type , learner_kwargs , combos ):
341+ def from_product (
342+ cls , f , learner_type , learner_kwargs , combos
343+ ) -> "BalancingLearner" :
329344 """Create a `BalancingLearner` with learners of all combinations of
330345 named variables’ values. The `cdims` will be set correctly, so calling
331346 `learner.plot` will be a `holoviews.core.HoloMap` with the correct labels.
@@ -372,7 +387,7 @@ def from_product(cls, f, learner_type, learner_kwargs, combos):
372387 learners .append (learner )
373388 return cls (learners , cdims = arguments )
374389
375- def save (self , fname , compress = True ):
390+ def save (self , fname : Callable , compress : bool = True ) -> None :
376391 """Save the data of the child learners into pickle files
377392 in a directory.
378393
@@ -410,7 +425,7 @@ def save(self, fname, compress=True):
410425 for l in self .learners :
411426 l .save (fname (l ), compress = compress )
412427
413- def load (self , fname , compress = True ):
428+ def load (self , fname : Callable , compress : bool = True ) -> None :
414429 """Load the data of the child learners from pickle files
415430 in a directory.
416431
0 commit comments