|
6 | 6 | from copy import copy, deepcopy |
7 | 7 | from numbers import Integral as Int |
8 | 8 | from numbers import Real |
9 | | -from typing import Any, Callable, Dict, List, Sequence, Tuple, Union |
| 9 | +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Sequence, Tuple, Union |
10 | 10 |
|
11 | 11 | import cloudpickle |
12 | 12 | import numpy as np |
|
24 | 24 | partial_function_from_dataframe, |
25 | 25 | ) |
26 | 26 |
|
| 27 | +if TYPE_CHECKING: |
| 28 | + import holoviews |
| 29 | + |
27 | 30 | try: |
28 | 31 | from typing import TypeAlias |
29 | 32 | except ImportError: |
30 | 33 | # Remove this when we drop support for Python 3.9 |
31 | 34 | from typing_extensions import TypeAlias |
32 | 35 |
|
| 36 | +try: |
| 37 | + from typing import Literal |
| 38 | +except ImportError: |
| 39 | + # Remove this when we drop support for Python 3.7 |
| 40 | + from typing_extensions import Literal |
| 41 | + |
| 42 | + |
33 | 43 | try: |
34 | 44 | import pandas |
35 | 45 |
|
@@ -145,7 +155,7 @@ def resolution_loss_function( |
145 | 155 |
|
146 | 156 | Returns |
147 | 157 | ------- |
148 | | - loss_function : callable |
| 158 | + loss_function |
149 | 159 |
|
150 | 160 | Examples |
151 | 161 | -------- |
@@ -230,12 +240,12 @@ class Learner1D(BaseLearner): |
230 | 240 |
|
231 | 241 | Parameters |
232 | 242 | ---------- |
233 | | - function : callable |
| 243 | + function |
234 | 244 | The function to learn. Must take a single real parameter and |
235 | 245 | return a real number or 1D array. |
236 | | - bounds : pair of reals |
| 246 | + bounds |
237 | 247 | The bounds of the interval on which to learn 'function'. |
238 | | - loss_per_interval: callable, optional |
| 248 | + loss_per_interval |
239 | 249 | A function that returns the loss for a single interval of the domain. |
240 | 250 | If not provided, then a default is used, which uses the scaled distance |
241 | 251 | in the x-y plane as the loss. See the notes for more details. |
@@ -356,15 +366,15 @@ def to_dataframe( |
356 | 366 |
|
357 | 367 | Parameters |
358 | 368 | ---------- |
359 | | - with_default_function_args : bool, optional |
| 369 | + with_default_function_args |
360 | 370 | Include the ``learner.function``'s default arguments as a |
361 | 371 | column, by default True |
362 | | - function_prefix : str, optional |
| 372 | + function_prefix |
363 | 373 | Prefix to the ``learner.function``'s default arguments' names, |
364 | 374 | by default "function." |
365 | | - x_name : str, optional |
| 375 | + x_name |
366 | 376 | Name of the input value, by default "x" |
367 | | - y_name : str, optional |
| 377 | + y_name |
368 | 378 | Name of the output value, by default "y" |
369 | 379 |
|
370 | 380 | Returns |
@@ -403,16 +413,16 @@ def load_dataframe( |
403 | 413 |
|
404 | 414 | Parameters |
405 | 415 | ---------- |
406 | | - df : pandas.DataFrame |
| 416 | + df |
407 | 417 | The data to load. |
408 | | - with_default_function_args : bool, optional |
| 418 | + with_default_function_args |
409 | 419 | The ``with_default_function_args`` used in ``to_dataframe()``, |
410 | 420 | by default True |
411 | | - function_prefix : str, optional |
| 421 | + function_prefix |
412 | 422 | The ``function_prefix`` used in ``to_dataframe``, by default "function." |
413 | | - x_name : str, optional |
| 423 | + x_name |
414 | 424 | The ``x_name`` used in ``to_dataframe``, by default "x" |
415 | | - y_name : str, optional |
| 425 | + y_name |
416 | 426 | The ``y_name`` used in ``to_dataframe``, by default "y" |
417 | 427 | """ |
418 | 428 | self.tell_many(df[x_name].values, df[y_name].values) |
@@ -795,17 +805,19 @@ def _loss( |
795 | 805 | loss = mapping[ival] |
796 | 806 | return finite_loss(ival, loss, self._scale[0]) |
797 | 807 |
|
798 | | - def plot(self, *, scatter_or_line: str = "scatter"): |
| 808 | + def plot( |
| 809 | + self, *, scatter_or_line: Literal["scatter", "line"] = "scatter" |
| 810 | + ) -> holoviews.Overlay: |
799 | 811 | """Returns a plot of the evaluated data. |
800 | 812 |
|
801 | 813 | Parameters |
802 | 814 | ---------- |
803 | | - scatter_or_line : str, default: "scatter" |
| 815 | + scatter_or_line |
804 | 816 | Plot as a scatter plot ("scatter") or a line plot ("line"). |
805 | 817 |
|
806 | 818 | Returns |
807 | 819 | ------- |
808 | | - plot : `holoviews.Overlay` |
| 820 | + plot |
809 | 821 | Plot of the evaluated data. |
810 | 822 | """ |
811 | 823 | if scatter_or_line not in ("scatter", "line"): |
|
0 commit comments