33import random
44from collections import OrderedDict
55from collections .abc import Iterable
6- from typing import Any , Callable , List , Optional , Tuple , Union
6+ from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Union
77
88import numpy as np
99import scipy .spatial
1313
1414from adaptive .learner .base_learner import BaseLearner , uses_nth_neighbors
1515from adaptive .learner .triangulation import (
16+ Point ,
17+ Simplex ,
1618 Triangulation ,
1719 circumsphere ,
1820 fast_det ,
@@ -40,7 +42,7 @@ def volume(simplex: List[Tuple[float, float]], ys: None = None,) -> float:
4042 return vol
4143
4244
43- def orientation (simplex ):
45+ def orientation (simplex : np . ndarray ):
4446 matrix = np .subtract (simplex [:- 1 ], simplex [- 1 ])
4547 # See https://www.jstor.org/stable/2315353
4648 sign , _logdet = np .linalg .slogdet (matrix )
@@ -339,12 +341,14 @@ def __init__(
339341
340342 self .function = func
341343 self ._tri = None
342- self ._losses = dict ()
344+ self ._losses : Dict [ Simplex , float ] = dict ()
343345
344- self ._pending_to_simplex = dict () # vertex → simplex
346+ self ._pending_to_simplex : Dict [ Point , Simplex ] = dict () # vertex → simplex
345347
346348 # triangulation of the pending points inside a specific simplex
347- self ._subtriangulations = dict () # simplex → triangulation
349+ self ._subtriangulations : Dict [
350+ Simplex , Triangulation
351+ ] = dict () # simplex → triangulation
348352
349353 # scale to unit hypercube
350354 # for the input
@@ -456,7 +460,7 @@ def tell(self, point: Tuple[float, ...], value: Union[float, np.ndarray],) -> No
456460 to_delete , to_add = tri .add_point (point , simplex , transform = self ._transform )
457461 self ._update_losses (to_delete , to_add )
458462
459- def _simplex_exists (self , simplex : Any ) -> bool : # XXX: specify simplex: Any
463+ def _simplex_exists (self , simplex : Simplex ) -> bool :
460464 simplex = tuple (sorted (simplex ))
461465 return simplex in self .tri .simplices
462466
@@ -498,7 +502,7 @@ def tell_pending(self, point: Tuple[float, ...], *, simplex=None,) -> None:
498502 self ._update_subsimplex_losses (simpl , to_add )
499503
500504 def _try_adding_pending_point_to_simplex (
501- self , point : Tuple [ float , ...], simplex : Any , # XXX: specify simplex: Any
505+ self , point : Point , simplex : Simplex ,
502506 ) -> Any :
503507 # try to insert it
504508 if not self .tri .point_in_simplex (point , simplex ):
@@ -512,8 +516,8 @@ def _try_adding_pending_point_to_simplex(
512516 return self ._subtriangulations [simplex ].add_point (point )
513517
514518 def _update_subsimplex_losses (
515- self , simplex : Any , new_subsimplices : Any
516- ) -> None : # XXX: specify simplex: Any
519+ self , simplex : Simplex , new_subsimplices : Set [ Simplex ]
520+ ) -> None :
517521 loss = self ._losses [simplex ]
518522
519523 loss_density = loss / self .tri .volume (simplex )
@@ -534,7 +538,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Any:
534538 else :
535539 return self ._ask_and_tell_pending (n )
536540
537- def _ask_bound_point (self ,) -> Tuple [Tuple [ float , ...] , float ]:
541+ def _ask_bound_point (self ,) -> Tuple [Point , float ]:
538542 # get the next bound point that is still available
539543 new_point = next (
540544 p
@@ -544,7 +548,7 @@ def _ask_bound_point(self,) -> Tuple[Tuple[float, ...], float]:
544548 self .tell_pending (new_point )
545549 return new_point , np .inf
546550
547- def _ask_point_without_known_simplices (self ,) -> Tuple [Tuple [ float , ...] , float ]:
551+ def _ask_point_without_known_simplices (self ,) -> Tuple [Point , float ]:
548552 assert not self ._bounds_available
549553 # pick a random point inside the bounds
550554 # XXX: change this into picking a point based on volume loss
@@ -585,7 +589,7 @@ def _pop_highest_existing_simplex(self) -> Any:
585589 " be a simplex available if LearnerND.tri() is not None."
586590 )
587591
588- def _ask_best_point (self ,) -> Tuple [Tuple [ float , ...] , float ]:
592+ def _ask_best_point (self ,) -> Tuple [Point , float ]:
589593 assert self .tri is not None
590594
591595 loss , simplex , subsimplex = self ._pop_highest_existing_simplex ()
@@ -612,7 +616,7 @@ def _bounds_available(self) -> bool:
612616 for p in self ._bounds_points
613617 )
614618
615- def _ask (self ,) -> Tuple [Tuple [ float , ...] , float ]:
619+ def _ask (self ,) -> Tuple [Point , float ]:
616620 if self ._bounds_available :
617621 return self ._ask_bound_point () # O(1)
618622
@@ -624,7 +628,7 @@ def _ask(self,) -> Tuple[Tuple[float, ...], float]:
624628
625629 return self ._ask_best_point () # O(log N)
626630
627- def _compute_loss (self , simplex : Any ) -> float : # XXX: specify simplex: Any
631+ def _compute_loss (self , simplex : Simplex ) -> float :
628632 # get the loss
629633 vertices = self .tri .get_vertices (simplex )
630634 values = [self .data [tuple (v )] for v in vertices ]
@@ -663,7 +667,7 @@ def _compute_loss(self, simplex: Any) -> float: # XXX: specify simplex: Any
663667 )
664668 )
665669
666- def _update_losses (self , to_delete : set , to_add : set ) -> None :
670+ def _update_losses (self , to_delete : Set [ Simplex ] , to_add : Set [ Simplex ] ) -> None :
667671 # XXX: add the points outside the triangulation to this as well
668672 pending_points_unbound = set ()
669673
@@ -733,13 +737,11 @@ def _recompute_all_losses(self) -> None:
733737 )
734738
735739 @property
736- def _scale (self ) -> Union [ float , np . int64 ] :
740+ def _scale (self ) -> float :
737741 # get the output scale
738742 return self ._max_value - self ._min_value
739743
740- def _update_range (
741- self , new_output : Union [List [int ], float , float , np .ndarray ]
742- ) -> bool :
744+ def _update_range (self , new_output : Union [List [int ], float , np .ndarray ]) -> bool :
743745 if self ._min_value is None or self ._max_value is None :
744746 # this is the first point, nothing to do, just set the range
745747 self ._min_value = np .min (new_output )
@@ -790,7 +792,7 @@ def remove_unfinished(self) -> None:
790792 # Plotting related stuff #
791793 ##########################
792794
793- def plot (self , n = None , tri_alpha = 0 ):
795+ def plot (self , n : Optional [ int ] = None , tri_alpha : float = 0 ):
794796 """Plot the function we want to learn, only works in 2D.
795797
796798 Parameters
@@ -851,7 +853,7 @@ def plot(self, n=None, tri_alpha=0):
851853
852854 return im .opts (style = im_opts ) * tris .opts (style = tri_opts , ** no_hover )
853855
854- def plot_slice (self , cut_mapping , n = None ):
856+ def plot_slice (self , cut_mapping : Dict [ int , float ], n : Optional [ int ] = None ):
855857 """Plot a 1D or 2D interpolated slice of a N-dimensional function.
856858
857859 Parameters
@@ -921,7 +923,7 @@ def plot_slice(self, cut_mapping, n=None):
921923 else :
922924 raise ValueError ("Only 1 or 2-dimensional plots can be generated." )
923925
924- def plot_3D (self , with_triangulation = False ):
926+ def plot_3D (self , with_triangulation : bool = False ):
925927 """Plot the learner's data in 3D using plotly.
926928
927929 Does *not* work with the
@@ -1010,7 +1012,7 @@ def _set_data(self, data: OrderedDict) -> None:
10101012 if data :
10111013 self .tell_many (* zip (* data .items ()))
10121014
1013- def _get_iso (self , level = 0.0 , which = "surface" ):
1015+ def _get_iso (self , level : float = 0.0 , which : str = "surface" ):
10141016 if which == "surface" :
10151017 if self .ndim != 3 or self .vdim != 1 :
10161018 raise Exception (
@@ -1081,7 +1083,9 @@ def _get_vertex_index(a, b):
10811083
10821084 return vertices , faces_or_lines
10831085
1084- def plot_isoline (self , level = 0.0 , n = None , tri_alpha = 0 ):
1086+ def plot_isoline (
1087+ self , level : float = 0.0 , n : Optional [int ] = None , tri_alpha : float = 0
1088+ ):
10851089 """Plot the isoline at a specific level, only works in 2D.
10861090
10871091 Parameters
@@ -1121,7 +1125,7 @@ def plot_isoline(self, level=0.0, n=None, tri_alpha=0):
11211125 contour = contour .opts (style = contour_opts )
11221126 return plot * contour
11231127
1124- def plot_isosurface (self , level = 0.0 , hull_opacity = 0.2 ):
1128+ def plot_isosurface (self , level : float = 0.0 , hull_opacity : float = 0.2 ):
11251129 """Plots a linearly interpolated isosurface.
11261130
11271131 This is the 3D analog of an isoline. Does *not* work with the
@@ -1159,7 +1163,7 @@ def plot_isosurface(self, level=0.0, hull_opacity=0.2):
11591163 hull_mesh = self ._get_hull_mesh (opacity = hull_opacity )
11601164 return plotly .offline .iplot ([isosurface , hull_mesh ])
11611165
1162- def _get_hull_mesh (self , opacity = 0.2 ):
1166+ def _get_hull_mesh (self , opacity : float = 0.2 ):
11631167 plotly = ensure_plotly ()
11641168 hull = scipy .spatial .ConvexHull (self ._bounds_points )
11651169
0 commit comments