Skip to content

Commit aa76f21

Browse files
committed
Added tests. Added note on 'how_to_slice_plot' about backends.
1 parent b7c1a97 commit aa76f21

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

docs/source/how_to/how_to_slice_plot.ipynb

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,17 @@
8383
"fig.show()"
8484
]
8585
},
86+
{
87+
"cell_type": "markdown",
88+
"metadata": {},
89+
"source": [
90+
":::{note}\n",
91+
"\n",
92+
"For details on using other plotting backends, see [How to change the plotting backend](how_to_change_plotting_backend.ipynb).\n",
93+
"\n",
94+
":::"
95+
]
96+
},
8697
{
8798
"cell_type": "markdown",
8899
"metadata": {},

tests/optimagic/visualization/test_slice_plot.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33

44
from optimagic import mark
55
from optimagic.parameters.bounds import Bounds
6-
from optimagic.visualization.slice_plot import slice_plot
6+
from optimagic.visualization.plotting_utilities import LineData, MarkerData
7+
from optimagic.visualization.slice_plot import (
8+
_extract_slice_plot_lines_and_labels,
9+
_get_plot_data,
10+
_get_processed_func_and_func_eval,
11+
slice_plot,
12+
)
713

814

915
@pytest.fixture()
@@ -40,6 +46,7 @@ def sphere(params):
4046
{"share_x": True},
4147
{"share_y": False},
4248
{"return_dict": True},
49+
{"backend": "matplotlib"},
4350
]
4451
parametrization = [
4552
(func, kwargs) for func in [sphere_loglike, sphere] for kwargs in KWARGS
@@ -53,3 +60,49 @@ def test_slice_plot(fixed_inputs, func, kwargs):
5360
**fixed_inputs,
5461
**kwargs,
5562
)
63+
64+
65+
def test_extract_slice_plot_lines(fixed_inputs):
66+
params, bounds = fixed_inputs["params"], fixed_inputs["bounds"]
67+
68+
func, func_eval = _get_processed_func_and_func_eval(
69+
sphere, func_kwargs=None, params=params
70+
)
71+
72+
plot_data, internal_params = _get_plot_data(
73+
func=func,
74+
params=params,
75+
bounds=bounds,
76+
func_eval=func_eval,
77+
selector=None,
78+
n_gridpoints=10,
79+
batch_evaluator="joblib",
80+
n_cores=1,
81+
)
82+
83+
lines_list, marker_list, xlabels, ylabels = _extract_slice_plot_lines_and_labels(
84+
plot_data=plot_data,
85+
internal_params=internal_params,
86+
func_eval=func_eval,
87+
param_names={"alpha": "Alpha"},
88+
color=None,
89+
)
90+
91+
assert isinstance(lines_list, list) and len(lines_list) == len(params)
92+
assert all(
93+
isinstance(subplot_lines, list)
94+
and len(subplot_lines) == 1
95+
and isinstance(subplot_lines[0], LineData)
96+
for subplot_lines in lines_list
97+
)
98+
99+
assert isinstance(marker_list, list) and len(marker_list) == len(params)
100+
assert all(isinstance(marker, MarkerData) for marker in marker_list)
101+
for i, k in enumerate(params):
102+
assert marker_list[i].x == params[k]
103+
104+
assert isinstance(xlabels, list)
105+
assert xlabels == ["Alpha", "beta", "gamma", "delta"]
106+
107+
assert isinstance(ylabels, list)
108+
assert all(ylabel == "Function Value" for ylabel in ylabels)

0 commit comments

Comments
 (0)