Description
Describe the bug
When attempting to include reference_values in plot_pair, the keys of reference_values are compared against the labeled var_names, i.e. after they have been transformed by the MapLabeller passed via labeller, if any. Since plot_pair flattens var_names and further alters the formatting via labeller.make_label_vert, the user has no way of knowing that the keys of reference_values must be updated to match the formatted, flattened var_names produced within plot_pair.
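To make the mismatch concrete, here is a minimal sketch in plain Python. It does not call ArviZ: flat_labels is a hypothetical stand-in for the flattening and labelling step, and it assumes one label per coordinate in the "name\nindex" form shown under Expected behavior. Plain lists are used in place of NumPy arrays.

```python
# Hypothetical stand-in for the flatten + label step; NOT ArviZ's actual code.
def flat_labels(var_name, n_coords, var_name_map):
    """Return one label per coordinate, formatted as '<mapped name>\\n<index>'."""
    shown = var_name_map.get(var_name, var_name)
    return [f"{shown}\n{i}" for i in range(n_coords)]


var_name_map = {"x": "x (cm)", "y": "y (cm)"}
reference_values = {"x": [-0.5, 0.5], "y": [-0.5, 0.5]}

# Labels the plot would use internally (two coordinates per variable here).
plotted = [
    label
    for name in reference_values
    for label in flat_labels(name, 2, var_name_map)
]

# The user's keys ("x", "y") never match the labelled names,
# so every labelled name is reported as missing a reference value.
missing = sorted(set(plotted) - set(reference_values))
print(missing)
```

Under these assumptions, none of the user-supplied keys overlap with the labelled names, which is consistent with the warning shown in the reproduction below.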
To Reproduce
import arviz as az
import numpy as np
from arviz.labels import MapLabeller

datadict = {
    "x": np.random.randn(1, 100, 2),
    "y": np.random.randn(1, 100, 2),
}
trace = az.convert_to_inference_data(datadict)

reference_values = {
    "x": np.array([-0.5, 0.5]),
    "y": np.array([-0.5, 0.5]),
}

var_name_map = {
    "x": "x (cm)",
    "y": "y (cm)",
}
labeller = MapLabeller(var_name_map)

az.plot_pair(trace, reference_values=reference_values, labeller=labeller)
The resulting plot does not include the reference value markers, and the call additionally emits the warning
UserWarning: Argument reference_values does not include reference value for: x (cm) 1, y (cm) 1, x (cm) 0, y (cm) 0
Expected behavior
The plot should contain the reference values without the user needing to know how plot_pair alters the label formatting (i.e., changing "x" to "x (cm)\n0" and "x (cm)\n1").
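One way the expected behavior could work is to translate the user's keys into the labelled form before they are compared. The sketch below is only an illustration of that idea and is not the change proposed in the PR: expand_reference_values is a hypothetical helper, it handles only one-dimensional variables, and it assumes the "name\nindex" label format described above.

```python
# Hypothetical helper illustrating the idea; not part of ArviZ.
def expand_reference_values(reference_values, var_name_map):
    """Re-key {'x': [a, b]} as {'<mapped x>\\n0': a, '<mapped x>\\n1': b}.

    Assumes each value is a 1-D sequence with one entry per coordinate.
    """
    expanded = {}
    for var_name, values in reference_values.items():
        shown = var_name_map.get(var_name, var_name)
        for index, value in enumerate(values):
            expanded[f"{shown}\n{index}"] = value
    return expanded


var_name_map = {"x": "x (cm)", "y": "y (cm)"}
reference_values = {"x": [-0.5, 0.5], "y": [-0.5, 0.5]}

expanded = expand_reference_values(reference_values, var_name_map)
for label, value in expanded.items():
    print(repr(label), value)
```

With a translation like this, the user keeps writing keys in terms of the original variable names ("x", "y"), and the formatted labels remain an internal detail.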
Additional context
arviz 0.22.0
PR incoming...