1+ import collections .abc
12import itertools
23import math
3- from collections . abc import Iterable
4+ import numbers
45from copy import deepcopy
6+ from typing import (
7+ Any ,
8+ Callable ,
9+ Dict ,
10+ Iterable ,
11+ List ,
12+ Literal ,
13+ Optional ,
14+ Sequence ,
15+ Tuple ,
16+ Union ,
17+ )
518
619import cloudpickle
720import numpy as np
8- import sortedcollections
9- import sortedcontainers
21+ from sortedcollections . recipes import ItemSortedDict
22+ from sortedcontainers . sorteddict import SortedDict
1023
1124from adaptive .learner .base_learner import BaseLearner , uses_nth_neighbors
1225from adaptive .learner .learnerND import volume
1326from adaptive .learner .triangulation import simplex_volume_in_embedding
1427from adaptive .notebook_integration import ensure_holoviews
28+ from adaptive .types import Float
1529from adaptive .utils import cache_latest
1630
31+ Point = Tuple [Float , Float ]
32+
1733
1834@uses_nth_neighbors (0 )
19- def uniform_loss (xs , ys ) :
35+ def uniform_loss (xs : Point , ys : Any ) -> Float :
2036 """Loss function that samples the domain uniformly.
2137
2238 Works with `~adaptive.Learner1D` only.
@@ -36,17 +52,20 @@ def uniform_loss(xs, ys):
3652
3753
3854@uses_nth_neighbors (0 )
39- def default_loss (xs , ys ):
55+ def default_loss (
56+ xs : Point ,
57+ ys : Union [Tuple [Iterable [Float ], Iterable [Float ]], Point ],
58+ ) -> float :
4059 """Calculate loss on a single interval.
4160
4261 Currently returns the rescaled length of the interval. If one of the
4362 y-values is missing, returns 0 (so the intervals with missing data are
4463 never touched. This behavior should be improved later.
4564 """
4665 dx = xs [1 ] - xs [0 ]
47- if isinstance (ys [0 ], Iterable ):
48- dy = [abs (a - b ) for a , b in zip (* ys )]
49- return np .hypot (dx , dy ).max ()
66+ if isinstance (ys [0 ], collections . abc . Iterable ):
67+ dy_vec = [abs (a - b ) for a , b in zip (* ys )]
68+ return np .hypot (dx , dy_vec ).max ()
5069 else :
5170 dy = ys [1 ] - ys [0 ]
5271 return np .hypot (dx , dy )
@@ -60,15 +79,21 @@ def abs_min_log_loss(xs, ys):
6079
6180
6281@uses_nth_neighbors (1 )
63- def triangle_loss (xs , ys ):
82+ def triangle_loss (
83+ xs : Sequence [Optional [Float ]],
84+ ys : Union [
85+ Iterable [Optional [Float ]],
86+ Iterable [Union [Iterable [Float ], None ]],
87+ ],
88+ ) -> float :
6489 xs = [x for x in xs if x is not None ]
6590 ys = [y for y in ys if y is not None ]
6691
6792 if len (xs ) == 2 : # we do not have enough points for a triangle
6893 return xs [1 ] - xs [0 ]
6994
7095 N = len (xs ) - 2 # number of constructed triangles
71- if isinstance (ys [0 ], Iterable ):
96+ if isinstance (ys [0 ], collections . abc . Iterable ):
7297 pts = [(x , * y ) for x , y in zip (xs , ys )]
7398 vol = simplex_volume_in_embedding
7499 else :
@@ -114,7 +139,9 @@ def resolution_loss(xs, ys):
114139 return resolution_loss
115140
116141
117- def curvature_loss_function (area_factor = 1 , euclid_factor = 0.02 , horizontal_factor = 0.02 ):
142+ def curvature_loss_function (
143+ area_factor : float = 1 , euclid_factor : float = 0.02 , horizontal_factor : float = 0.02
144+ ) -> Callable :
118145 # XXX: add a doc-string
119146 @uses_nth_neighbors (1 )
120147 def curvature_loss (xs , ys ):
@@ -133,7 +160,7 @@ def curvature_loss(xs, ys):
133160 return curvature_loss
134161
135162
136- def linspace (x_left , x_right , n ) :
163+ def linspace (x_left : float , x_right : float , n : int ) -> List [ float ] :
137164 """This is equivalent to
138165 'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
139166 but it is 15-30 times faster for small 'n'."""
@@ -145,17 +172,17 @@ def linspace(x_left, x_right, n):
145172 return [x_left + step * i for i in range (1 , n )]
146173
147174
148- def _get_neighbors_from_list (xs ) :
175+ def _get_neighbors_from_list (xs : np . ndarray ) -> SortedDict :
149176 xs = np .sort (xs )
150177 xs_left = np .roll (xs , 1 ).tolist ()
151178 xs_right = np .roll (xs , - 1 ).tolist ()
152179 xs_left [0 ] = None
153180 xs_right [- 1 ] = None
154181 neighbors = {x : [x_L , x_R ] for x , x_L , x_R in zip (xs , xs_left , xs_right )}
155- return sortedcontainers . SortedDict (neighbors )
182+ return SortedDict (neighbors )
156183
157184
158- def _get_intervals (x , neighbors , nth_neighbors ) :
185+ def _get_intervals (x : float , neighbors : SortedDict , nth_neighbors : int ) -> Any :
159186 nn = nth_neighbors
160187 i = neighbors .index (x )
161188 start = max (0 , i - nn - 1 )
@@ -208,8 +235,13 @@ class Learner1D(BaseLearner):
208235 decorator for more information.
209236 """
210237
211- def __init__ (self , function , bounds , loss_per_interval = None ):
212- self .function = function
238+ def __init__ (
239+ self ,
240+ function : Callable ,
241+ bounds : Tuple [float , float ],
242+ loss_per_interval : Optional [Callable ] = None ,
243+ ) -> None :
244+ self .function = function # type: ignore
213245
214246 if hasattr (loss_per_interval , "nth_neighbors" ):
215247 self .nth_neighbors = loss_per_interval .nth_neighbors
@@ -228,8 +260,8 @@ def __init__(self, function, bounds, loss_per_interval=None):
228260
229261 # A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
230262 # properties.
231- self .neighbors = sortedcontainers . SortedDict ()
232- self .neighbors_combined = sortedcontainers . SortedDict ()
263+ self .neighbors = SortedDict ()
264+ self .neighbors_combined = SortedDict ()
233265
234266 # Bounding box [[minx, maxx], [miny, maxy]].
235267 self ._bbox = [list (bounds ), [np .inf , - np .inf ]]
@@ -247,10 +279,10 @@ def __init__(self, function, bounds, loss_per_interval=None):
247279
248280 self .bounds = list (bounds )
249281
250- self ._vdim = None
282+ self ._vdim : Optional [ int ] = None
251283
252284 @property
253- def vdim (self ):
285+ def vdim (self ) -> int :
254286 """Length of the output of ``learner.function``.
255287 If the output is unsized (when it's a scalar)
256288 then `vdim = 1`.
@@ -275,35 +307,37 @@ def to_numpy(self):
275307 return np .array ([(x , * np .atleast_1d (y )) for x , y in sorted (self .data .items ())])
276308
277309 @property
278- def npoints (self ):
310+ def npoints (self ) -> int :
279311 """Number of evaluated points."""
280312 return len (self .data )
281313
282314 @cache_latest
283- def loss (self , real = True ):
315+ def loss (self , real : bool = True ) -> float :
284316 losses = self .losses if real else self .losses_combined
285317 if not losses :
286318 return np .inf
287319 max_interval , max_loss = losses .peekitem (0 )
288320 return max_loss
289321
290- def _scale_x (self , x ) :
322+ def _scale_x (self , x : Optional [ float ]) -> Optional [ float ] :
291323 if x is None :
292324 return None
293325 return x / self ._scale [0 ]
294326
295- def _scale_y (self , y ):
327+ def _scale_y (
328+ self , y : Optional [Union [Float , np .ndarray ]]
329+ ) -> Optional [Union [Float , np .ndarray ]]:
296330 if y is None :
297331 return None
298332 y_scale = self ._scale [1 ] or 1
299333 return y / y_scale
300334
301- def _get_point_by_index (self , ind ) :
335+ def _get_point_by_index (self , ind : int ) -> Optional [ float ] :
302336 if ind < 0 or ind >= len (self .neighbors ):
303337 return None
304338 return self .neighbors .keys ()[ind ]
305339
306- def _get_loss_in_interval (self , x_left , x_right ) :
340+ def _get_loss_in_interval (self , x_left : float , x_right : float ) -> float :
307341 assert x_left is not None and x_right is not None
308342
309343 if x_right - x_left < self ._dx_eps :
@@ -323,7 +357,9 @@ def _get_loss_in_interval(self, x_left, x_right):
323357 # we need to compute the loss for this interval
324358 return self .loss_per_interval (xs_scaled , ys_scaled )
325359
326- def _update_interpolated_loss_in_interval (self , x_left , x_right ):
360+ def _update_interpolated_loss_in_interval (
361+ self , x_left : float , x_right : float
362+ ) -> None :
327363 if x_left is None or x_right is None :
328364 return
329365
@@ -339,7 +375,7 @@ def _update_interpolated_loss_in_interval(self, x_left, x_right):
339375 self .losses_combined [a , b ] = (b - a ) * loss / dx
340376 a = b
341377
342- def _update_losses (self , x , real = True ):
378+ def _update_losses (self , x : float , real : bool = True ) -> None :
343379 """Update all losses that depend on x"""
344380 # When we add a new point x, we should update the losses
345381 # (x_left, x_right) are the "real" neighbors of 'x'.
@@ -382,7 +418,7 @@ def _update_losses(self, x, real=True):
382418 self .losses_combined [x , b ] = float ("inf" )
383419
384420 @staticmethod
385- def _find_neighbors (x , neighbors ) :
421+ def _find_neighbors (x : float , neighbors : SortedDict ) -> Any :
386422 if x in neighbors :
387423 return neighbors [x ]
388424 pos = neighbors .bisect_left (x )
@@ -391,14 +427,14 @@ def _find_neighbors(x, neighbors):
391427 x_right = keys [pos ] if pos != len (neighbors ) else None
392428 return x_left , x_right
393429
394- def _update_neighbors (self , x , neighbors ) :
430+ def _update_neighbors (self , x : float , neighbors : SortedDict ) -> None :
395431 if x not in neighbors : # The point is new
396432 x_left , x_right = self ._find_neighbors (x , neighbors )
397433 neighbors [x ] = [x_left , x_right ]
398434 neighbors .get (x_left , [None , None ])[1 ] = x
399435 neighbors .get (x_right , [None , None ])[0 ] = x
400436
401- def _update_scale (self , x , y ) :
437+ def _update_scale (self , x : float , y : Union [ Float , np . ndarray ]) -> None :
402438 """Update the scale with which the x and y-values are scaled.
403439
404440 For a learner where the function returns a single scalar the scale
@@ -425,7 +461,9 @@ def _update_scale(self, x, y):
425461 self ._bbox [1 ][1 ] = max (self ._bbox [1 ][1 ], y )
426462 self ._scale [1 ] = self ._bbox [1 ][1 ] - self ._bbox [1 ][0 ]
427463
428- def tell (self , x , y ):
464+ def tell (
465+ self , x : float , y : Union [Float , Sequence [numbers .Number ], np .ndarray ]
466+ ) -> None :
429467 if x in self .data :
430468 # The point is already evaluated before
431469 return
@@ -460,15 +498,15 @@ def tell(self, x, y):
460498
461499 self ._oldscale = deepcopy (self ._scale )
462500
463- def tell_pending (self , x ) :
501+ def tell_pending (self , x : float ) -> None :
464502 if x in self .data :
465503 # The point is already evaluated before
466504 return
467505 self .pending_points .add (x )
468506 self ._update_neighbors (x , self .neighbors_combined )
469507 self ._update_losses (x , real = False )
470508
471- def tell_many (self , xs , ys , * , force = False ):
509+ def tell_many (self , xs : Sequence [ float ] , ys : Sequence [ Any ] , * , force = False ) -> None :
472510 if not force and not (len (xs ) > 0.5 * len (self .data ) and len (xs ) > 2 ):
473511 # Only run this more efficient method if there are
474512 # at least 2 points and the amount of points added are
@@ -536,7 +574,7 @@ def tell_many(self, xs, ys, *, force=False):
536574 # have an inf loss.
537575 self ._update_interpolated_loss_in_interval (* ival )
538576
539- def ask (self , n , tell_pending = True ):
577+ def ask (self , n : int , tell_pending : bool = True ) -> Any :
540578 """Return 'n' points that are expected to maximally reduce the loss."""
541579 points , loss_improvements = self ._ask_points_without_adding (n )
542580
@@ -546,7 +584,7 @@ def ask(self, n, tell_pending=True):
546584
547585 return points , loss_improvements
548586
549- def _ask_points_without_adding (self , n ) :
587+ def _ask_points_without_adding (self , n : int ) -> Any :
550588 """Return 'n' points that are expected to maximally reduce the loss.
551589 Without altering the state of the learner"""
552590 # Find out how to divide the n points over the intervals
@@ -573,7 +611,8 @@ def _ask_points_without_adding(self, n):
573611 # Add bound intervals to quals if bounds were missing.
574612 if len (self .data ) + len (self .pending_points ) == 0 :
575613 # We don't have any points, so return a linspace with 'n' points.
576- return np .linspace (* self .bounds , n ).tolist (), [np .inf ] * n
614+ a , b = self .bounds
615+ return np .linspace (a , b , n ).tolist (), [np .inf ] * n
577616
578617 quals = loss_manager (self ._scale [0 ])
579618 if len (missing_bounds ) > 0 :
@@ -609,7 +648,7 @@ def _ask_points_without_adding(self, n):
609648 quals [(* xs , n + 1 )] = loss_qual * n / (n + 1 )
610649
611650 points = list (
612- itertools .chain .from_iterable (linspace (* ival , n ) for (* ival , n ) in quals )
651+ itertools .chain .from_iterable (linspace (a , b , n ) for (( a , b ) , n ) in quals )
613652 )
614653
615654 loss_improvements = list (
@@ -624,11 +663,11 @@ def _ask_points_without_adding(self, n):
624663
625664 return points , loss_improvements
626665
627- def _loss (self , mapping , ival ) :
666+ def _loss (self , mapping : ItemSortedDict , ival : Any ) -> Any :
628667 loss = mapping [ival ]
629668 return finite_loss (ival , loss , self ._scale [0 ])
630669
631- def plot (self , * , scatter_or_line = "scatter" ):
670+ def plot (self , * , scatter_or_line : Literal [ "scatter" , "line" ] = "scatter" ):
632671 """Returns a plot of the evaluated data.
633672
634673 Parameters
@@ -663,17 +702,18 @@ def plot(self, *, scatter_or_line="scatter"):
663702
664703 return p .redim (x = dict (range = plot_bounds ))
665704
666- def remove_unfinished (self ):
705+ def remove_unfinished (self ) -> None :
667706 self .pending_points = set ()
668707 self .losses_combined = deepcopy (self .losses )
669708 self .neighbors_combined = deepcopy (self .neighbors )
670709
671- def _get_data (self ):
710+ def _get_data (self ) -> Dict [ float , float ] :
672711 return self .data
673712
674- def _set_data (self , data ) :
713+ def _set_data (self , data : Dict [ float , float ]) -> None :
675714 if data :
676- self .tell_many (* zip (* data .items ()))
715+ xs , ys = zip (* data .items ())
716+ self .tell_many (xs , ys )
677717
678718 def __getstate__ (self ):
679719 return (
@@ -694,16 +734,16 @@ def __setstate__(self, state):
694734 self .losses_combined .update (losses_combined )
695735
696736
697- def loss_manager (x_scale ) :
737+ def loss_manager (x_scale : float ) -> ItemSortedDict :
698738 def sort_key (ival , loss ):
699739 loss , ival = finite_loss (ival , loss , x_scale )
700740 return - loss , ival
701741
702- sorted_dict = sortedcollections . ItemSortedDict (sort_key )
742+ sorted_dict = ItemSortedDict (sort_key )
703743 return sorted_dict
704744
705745
706- def finite_loss (ival , loss , x_scale ) :
746+ def finite_loss (ival : Any , loss : float , x_scale : float ) -> Any :
707747 """Get the socalled finite_loss of an interval in order to be able to
708748 sort intervals that have infinite loss."""
709749 # If the loss is infinite we return the
0 commit comments