11import collections .abc
22import itertools
33import math
4- import numbers
54from copy import deepcopy
6- from typing import (
7- Any ,
8- Callable ,
9- Dict ,
10- Iterable ,
11- List ,
12- Literal ,
13- Optional ,
14- Sequence ,
15- Tuple ,
16- Union ,
17- )
5+ from typing import Any , Callable , Dict , List , Optional , Sequence , Set , Tuple , Union
186
197import cloudpickle
208import numpy as np
2513from adaptive .learner .learnerND import volume
2614from adaptive .learner .triangulation import simplex_volume_in_embedding
2715from adaptive .notebook_integration import ensure_holoviews
28- from adaptive .types import Float
16+ from adaptive .types import Float , Int , Real
2917from adaptive .utils import cache_latest
3018
31- Point = Tuple [Float , Float ]
19+ # -- types --
20+
21+ # Commonly used types
22+ Interval = Union [Tuple [float , float ], Tuple [float , float , int ]]
23+ NeighborsType = Dict [float , List [Union [float , None ]]]
24+
25+ # Types for loss_per_interval functions
26+ NoneFloat = Union [Float , None ]
27+ NoneArray = Union [np .ndarray , None ]
28+ XsType0 = Tuple [Float , Float ]
29+ YsType0 = Union [Tuple [Float , Float ], Tuple [np .ndarray , np .ndarray ]]
30+ XsType1 = Tuple [NoneFloat , NoneFloat , NoneFloat , NoneFloat ]
31+ YsType1 = Union [
32+ Tuple [NoneFloat , NoneFloat , NoneFloat , NoneFloat ],
33+ Tuple [NoneArray , NoneArray , NoneArray , NoneArray ],
34+ ]
35+ XsTypeN = Tuple [NoneFloat , ...]
36+ YsTypeN = Union [Tuple [NoneFloat , ...], Tuple [NoneArray , ...]]
37+
38+
39+ __all__ = [
40+ "uniform_loss" ,
41+ "default_loss" ,
42+ "abs_min_log_loss" ,
43+ "triangle_loss" ,
44+ "resolution_loss_function" ,
45+ "curvature_loss_function" ,
46+ "Learner1D" ,
47+ ]
3248
3349
3450@uses_nth_neighbors (0 )
35- def uniform_loss (xs : Point , ys : Any ) -> Float :
51+ def uniform_loss (xs : XsType0 , ys : YsType0 ) -> Float :
3652 """Loss function that samples the domain uniformly.
3753
3854 Works with `~adaptive.Learner1D` only.
@@ -52,10 +68,7 @@ def uniform_loss(xs: Point, ys: Any) -> Float:
5268
5369
5470@uses_nth_neighbors (0 )
55- def default_loss (
56- xs : Point ,
57- ys : Union [Tuple [Iterable [Float ], Iterable [Float ]], Point ],
58- ) -> float :
71+ def default_loss (xs : XsType0 , ys : YsType0 ) -> Float :
5972 """Calculate loss on a single interval.
6073
6174 Currently returns the rescaled length of the interval. If one of the
@@ -64,28 +77,23 @@ def default_loss(
6477 """
6578 dx = xs [1 ] - xs [0 ]
6679 if isinstance (ys [0 ], collections .abc .Iterable ):
67- dy_vec = [abs (a - b ) for a , b in zip (* ys )]
80+ dy_vec = np . array ( [abs (a - b ) for a , b in zip (* ys )])
6881 return np .hypot (dx , dy_vec ).max ()
6982 else :
7083 dy = ys [1 ] - ys [0 ]
7184 return np .hypot (dx , dy )
7285
7386
7487@uses_nth_neighbors (0 )
75- def abs_min_log_loss (xs , ys ) :
88+ def abs_min_log_loss (xs : XsType0 , ys : YsType0 ) -> Float :
7689 """Calculate loss of a single interval that prioritizes the absolute minimum."""
77- ys = [ np .log (np .abs (y ).min ()) for y in ys ]
90+ ys = tuple ( np .log (np .abs (y ).min ()) for y in ys )
7891 return default_loss (xs , ys )
7992
8093
8194@uses_nth_neighbors (1 )
82- def triangle_loss (
83- xs : Sequence [Optional [Float ]],
84- ys : Union [
85- Iterable [Optional [Float ]],
86- Iterable [Union [Iterable [Float ], None ]],
87- ],
88- ) -> float :
95+ def triangle_loss (xs : XsType1 , ys : YsType1 ) -> Float :
96+ assert len (xs ) == 4
8997 xs = [x for x in xs if x is not None ]
9098 ys = [y for y in ys if y is not None ]
9199
@@ -102,7 +110,9 @@ def triangle_loss(
102110 return sum (vol (pts [i : i + 3 ]) for i in range (N )) / N
103111
104112
105- def resolution_loss_function (min_length = 0 , max_length = 1 ):
113+ def resolution_loss_function (
114+ min_length : Real = 0 , max_length : Real = 1
115+ ) -> Callable [[XsType0 , YsType0 ], Float ]:
106116 """Loss function that is similar to the `default_loss` function, but you
107117 can set the maximum and minimum size of an interval.
108118
@@ -125,7 +135,7 @@ def resolution_loss_function(min_length=0, max_length=1):
125135 """
126136
127137 @uses_nth_neighbors (0 )
128- def resolution_loss (xs , ys ) :
138+ def resolution_loss (xs : XsType0 , ys : YsType0 ) -> Float :
129139 loss = uniform_loss (xs , ys )
130140 if loss < min_length :
131141 # Return zero such that this interval won't be chosen again
@@ -140,11 +150,11 @@ def resolution_loss(xs, ys):
140150
141151
142152def curvature_loss_function (
143- area_factor : float = 1 , euclid_factor : float = 0.02 , horizontal_factor : float = 0.02
144- ) -> Callable :
153+ area_factor : Real = 1 , euclid_factor : Real = 0.02 , horizontal_factor : Real = 0.02
154+ ) -> Callable [[ XsType1 , YsType1 ], Float ] :
145155 # XXX: add a doc-string
146156 @uses_nth_neighbors (1 )
147- def curvature_loss (xs , ys ) :
157+ def curvature_loss (xs : XsType1 , ys : YsType1 ) -> Float :
148158 xs_middle = xs [1 :3 ]
149159 ys_middle = ys [1 :3 ]
150160
@@ -160,7 +170,7 @@ def curvature_loss(xs, ys):
160170 return curvature_loss
161171
162172
163- def linspace (x_left : float , x_right : float , n : int ) -> List [float ]:
173+ def linspace (x_left : Real , x_right : Real , n : Int ) -> List [Float ]:
164174 """This is equivalent to
165175 'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
166176 but it is 15-30 times faster for small 'n'."""
@@ -172,7 +182,7 @@ def linspace(x_left: float, x_right: float, n: int) -> List[float]:
172182 return [x_left + step * i for i in range (1 , n )]
173183
174184
175- def _get_neighbors_from_list (xs : np .ndarray ) -> SortedDict :
185+ def _get_neighbors_from_array (xs : np .ndarray ) -> NeighborsType :
176186 xs = np .sort (xs )
177187 xs_left = np .roll (xs , 1 ).tolist ()
178188 xs_right = np .roll (xs , - 1 ).tolist ()
@@ -182,7 +192,9 @@ def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
182192 return SortedDict (neighbors )
183193
184194
185- def _get_intervals (x : float , neighbors : SortedDict , nth_neighbors : int ) -> Any :
195+ def _get_intervals (
196+ x : float , neighbors : NeighborsType , nth_neighbors : int
197+ ) -> List [Tuple [float , float ]]:
186198 nn = nth_neighbors
187199 i = neighbors .index (x )
188200 start = max (0 , i - nn - 1 )
@@ -237,10 +249,10 @@ class Learner1D(BaseLearner):
237249
238250 def __init__ (
239251 self ,
240- function : Callable ,
241- bounds : Tuple [float , float ],
242- loss_per_interval : Optional [Callable ] = None ,
243- ) -> None :
252+ function : Callable [[ Real ], Union [ Float , np . ndarray ]] ,
253+ bounds : Tuple [Real , Real ],
254+ loss_per_interval : Optional [Callable [[ XsTypeN , YsTypeN ], Float ] ] = None ,
255+ ):
244256 self .function = function # type: ignore
245257
246258 if hasattr (loss_per_interval , "nth_neighbors" ):
@@ -255,13 +267,13 @@ def __init__(
255267 # the learners behavior in the tests.
256268 self ._recompute_losses_factor = 2
257269
258- self .data = {}
259- self .pending_points = set ()
270+ self .data : Dict [ Real , Real ] = {}
271+ self .pending_points : Set [ Real ] = set ()
260272
261273 # A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
262274 # properties.
263- self .neighbors = SortedDict ()
264- self .neighbors_combined = SortedDict ()
275+ self .neighbors : NeighborsType = SortedDict ()
276+ self .neighbors_combined : NeighborsType = SortedDict ()
265277
266278 # Bounding box [[minx, maxx], [miny, maxy]].
267279 self ._bbox = [list (bounds ), [np .inf , - np .inf ]]
@@ -319,14 +331,14 @@ def loss(self, real: bool = True) -> float:
319331 max_interval , max_loss = losses .peekitem (0 )
320332 return max_loss
321333
322- def _scale_x (self , x : Optional [float ]) -> Optional [float ]:
334+ def _scale_x (self , x : Optional [Float ]) -> Optional [Float ]:
323335 if x is None :
324336 return None
325337 return x / self ._scale [0 ]
326338
327339 def _scale_y (
328- self , y : Optional [ Union [Float , np .ndarray ] ]
329- ) -> Optional [ Union [Float , np .ndarray ] ]:
340+ self , y : Union [Float , np .ndarray , None ]
341+ ) -> Union [Float , np .ndarray , None ]:
330342 if y is None :
331343 return None
332344 y_scale = self ._scale [1 ] or 1
@@ -418,7 +430,7 @@ def _update_losses(self, x: float, real: bool = True) -> None:
418430 self .losses_combined [x , b ] = float ("inf" )
419431
420432 @staticmethod
421- def _find_neighbors (x : float , neighbors : SortedDict ) -> Any :
433+ def _find_neighbors (x : float , neighbors : NeighborsType ) -> Any :
422434 if x in neighbors :
423435 return neighbors [x ]
424436 pos = neighbors .bisect_left (x )
@@ -427,7 +439,7 @@ def _find_neighbors(x: float, neighbors: SortedDict) -> Any:
427439 x_right = keys [pos ] if pos != len (neighbors ) else None
428440 return x_left , x_right
429441
430- def _update_neighbors (self , x : float , neighbors : SortedDict ) -> None :
442+ def _update_neighbors (self , x : float , neighbors : NeighborsType ) -> None :
431443 if x not in neighbors : # The point is new
432444 x_left , x_right = self ._find_neighbors (x , neighbors )
433445 neighbors [x ] = [x_left , x_right ]
@@ -461,9 +473,7 @@ def _update_scale(self, x: float, y: Union[Float, np.ndarray]) -> None:
461473 self ._bbox [1 ][1 ] = max (self ._bbox [1 ][1 ], y )
462474 self ._scale [1 ] = self ._bbox [1 ][1 ] - self ._bbox [1 ][0 ]
463475
464- def tell (
465- self , x : float , y : Union [Float , Sequence [numbers .Number ], np .ndarray ]
466- ) -> None :
476+ def tell (self , x : float , y : Union [Float , Sequence [Float ], np .ndarray ]) -> None :
467477 if x in self .data :
468478 # The point is already evaluated before
469479 return
@@ -506,7 +516,17 @@ def tell_pending(self, x: float) -> None:
506516 self ._update_neighbors (x , self .neighbors_combined )
507517 self ._update_losses (x , real = False )
508518
509- def tell_many (self , xs : Sequence [float ], ys : Sequence [Any ], * , force = False ) -> None :
519+ def tell_many (
520+ self ,
521+ xs : Sequence [Float ],
522+ ys : Union [
523+ Sequence [Float ],
524+ Sequence [Sequence [Float ]],
525+ Sequence [np .ndarray ],
526+ ],
527+ * ,
528+ force : bool = False
529+ ) -> None :
510530 if not force and not (len (xs ) > 0.5 * len (self .data ) and len (xs ) > 2 ):
511531 # Only run this more efficient method if there are
512532 # at least 2 points and the amount of points added are
@@ -526,8 +546,8 @@ def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> N
526546 points_combined = np .hstack ([points_pending , points ])
527547
528548 # Generate neighbors
529- self .neighbors = _get_neighbors_from_list (points )
530- self .neighbors_combined = _get_neighbors_from_list (points_combined )
549+ self .neighbors = _get_neighbors_from_array (points )
550+ self .neighbors_combined = _get_neighbors_from_array (points_combined )
531551
532552 # Update scale
533553 self ._bbox [0 ] = [points_combined .min (), points_combined .max ()]
@@ -574,7 +594,7 @@ def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> N
574594 # have an inf loss.
575595 self ._update_interpolated_loss_in_interval (* ival )
576596
577- def ask (self , n : int , tell_pending : bool = True ) -> Any :
597+ def ask (self , n : int , tell_pending : bool = True ) -> Tuple [ List [ float ], List [ float ]] :
578598 """Return 'n' points that are expected to maximally reduce the loss."""
579599 points , loss_improvements = self ._ask_points_without_adding (n )
580600
@@ -584,7 +604,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Any:
584604
585605 return points , loss_improvements
586606
587- def _ask_points_without_adding (self , n : int ) -> Any :
607+ def _ask_points_without_adding (self , n : int ) -> Tuple [ List [ float ], List [ float ]] :
588608 """Return 'n' points that are expected to maximally reduce the loss.
589609 Without altering the state of the learner"""
590610 # Find out how to divide the n points over the intervals
@@ -648,7 +668,7 @@ def _ask_points_without_adding(self, n: int) -> Any:
648668 quals [(* xs , n + 1 )] = loss_qual * n / (n + 1 )
649669
650670 points = list (
651- itertools .chain .from_iterable (linspace (a , b , n ) for (( a , b ) , n ) in quals )
671+ itertools .chain .from_iterable (linspace (* ival , n ) for (* ival , n ) in quals )
652672 )
653673
654674 loss_improvements = list (
@@ -663,11 +683,13 @@ def _ask_points_without_adding(self, n: int) -> Any:
663683
664684 return points , loss_improvements
665685
666- def _loss (self , mapping : ItemSortedDict , ival : Any ) -> Any :
686+ def _loss (
687+ self , mapping : Dict [Interval , float ], ival : Interval
688+ ) -> Tuple [float , Interval ]:
667689 loss = mapping [ival ]
668690 return finite_loss (ival , loss , self ._scale [0 ])
669691
670- def plot (self , * , scatter_or_line : Literal [ "scatter" , "line" ] = "scatter" ):
692+ def plot (self , * , scatter_or_line : str = "scatter" ):
671693 """Returns a plot of the evaluated data.
672694
673695 Parameters
@@ -734,7 +756,7 @@ def __setstate__(self, state):
734756 self .losses_combined .update (losses_combined )
735757
736758
737- def loss_manager (x_scale : float ) -> ItemSortedDict :
759+ def loss_manager (x_scale : float ) -> Dict [ Interval , float ] :
738760 def sort_key (ival , loss ):
739761 loss , ival = finite_loss (ival , loss , x_scale )
740762 return - loss , ival
@@ -743,8 +765,8 @@ def sort_key(ival, loss):
743765 return sorted_dict
744766
745767
746- def finite_loss (ival : Any , loss : float , x_scale : float ) -> Any :
747- """Get the socalled finite_loss of an interval in order to be able to
768+ def finite_loss (ival : Interval , loss : float , x_scale : float ) -> Tuple [ float , Interval ] :
769+ """Get the so-called finite_loss of an interval in order to be able to
748770 sort intervals that have infinite loss."""
749771 # If the loss is infinite we return the
750772 # distance between the two points.
0 commit comments