1+ from __future__ import annotations
2+
13import abc
24from contextlib import suppress
5+ from typing import Any , Callable
36
47import cloudpickle
58
69from adaptive .utils import _RequireAttrsABCMeta , load , save
710
811
9- def uses_nth_neighbors (n : int ):
12+ def uses_nth_neighbors (n : int ) -> Callable [[ int ], Callable [[ BaseLearner ], float ]] :
1013 """Decorator to specify how many neighboring intervals the loss function uses.
1114
1215 Wraps loss functions to indicate that they expect intervals together
@@ -53,7 +56,9 @@ def uses_nth_neighbors(n: int):
5356 ... return loss
5457 """
5558
56- def _wrapped (loss_per_interval ):
59+ def _wrapped (
60+ loss_per_interval : Callable [[BaseLearner ], float ]
61+ ) -> Callable [[BaseLearner ], float ]:
5762 loss_per_interval .nth_neighbors = n
5863 return loss_per_interval
5964
@@ -82,10 +87,15 @@ class BaseLearner(metaclass=_RequireAttrsABCMeta):
8287 """
8388
8489 data : dict
85- npoints : int
8690 pending_points : set
91+ function : Callable
92+
93+ @property
94+ @abc .abstractmethod
95+ def npoints (self ) -> int :
96+ """Number of learned points."""
8797
88- def tell (self , x , y ) :
98+ def tell (self , x : Any , y : Any ) -> None :
8999 """Tell the learner about a single value.
90100
91101 Parameters
@@ -95,7 +105,7 @@ def tell(self, x, y):
95105 """
96106 self .tell_many ([x ], [y ])
97107
98- def tell_many (self , xs , ys ) :
108+ def tell_many (self , xs : Any , ys : Any ) -> None :
99109 """Tell the learner about some values.
100110
101111 Parameters
@@ -107,16 +117,16 @@ def tell_many(self, xs, ys):
107117 self .tell (x , y )
108118
109119 @abc .abstractmethod
110- def tell_pending (self , x ) :
120+ def tell_pending (self , x : Any ) -> None :
111121 """Tell the learner that 'x' has been requested such
112122 that it's not suggested again."""
113123
114124 @abc .abstractmethod
115- def remove_unfinished (self ):
125+ def remove_unfinished (self ) -> None :
116126 """Remove uncomputed data from the learner."""
117127
118128 @abc .abstractmethod
119- def loss (self , real = True ):
129+ def loss (self , real : bool = True ) -> float :
120130 """Return the loss for the current state of the learner.
121131
122132 Parameters
@@ -128,7 +138,7 @@ def loss(self, real=True):
128138 """
129139
130140 @abc .abstractmethod
131- def ask (self , n , tell_pending = True ):
141+ def ask (self , n : int , tell_pending : bool = True ):
132142 """Choose the next 'n' points to evaluate.
133143
134144 Parameters
@@ -142,11 +152,11 @@ def ask(self, n, tell_pending=True):
142152 """
143153
144154 @abc .abstractmethod
145- def _get_data (self ):
155+ def _get_data (self ) -> Any :
146156 pass
147157
148158 @abc .abstractmethod
149- def _set_data (self ):
159+ def _set_data (self , data : Any ):
150160 pass
151161
152162 @abc .abstractmethod
@@ -164,7 +174,7 @@ def copy_from(self, other):
164174 """
165175 self ._set_data (other ._get_data ())
166176
167- def save (self , fname , compress = True ):
177+ def save (self , fname : str , compress : bool = True ) -> None :
168178 """Save the data of the learner into a pickle file.
169179
170180 Parameters
@@ -178,7 +188,7 @@ def save(self, fname, compress=True):
178188 data = self ._get_data ()
179189 save (fname , data , compress )
180190
181- def load (self , fname , compress = True ):
191+ def load (self , fname : str , compress : bool = True ) -> None :
182192 """Load the data of a learner from a pickle file.
183193
184194 Parameters
@@ -193,8 +203,8 @@ def load(self, fname, compress=True):
193203 data = load (fname , compress )
194204 self ._set_data (data )
195205
196- def __getstate__ (self ):
206+ def __getstate__ (self ) -> bytes :
197207 return cloudpickle .dumps (self .__dict__ )
198208
199- def __setstate__ (self , state ) :
209+ def __setstate__ (self , state : bytes ) -> None :
200210 self .__dict__ = cloudpickle .loads (state )
0 commit comments