diff --git a/src/corner/__init__.py b/src/corner/__init__.py index 2e8b2c4..2cda223 100644 --- a/src/corner/__init__.py +++ b/src/corner/__init__.py @@ -1,7 +1,20 @@ # -*- coding: utf-8 -*- -__all__ = ["corner", "hist2d", "quantile", "overplot_lines", "overplot_points"] +__all__ = [ + "corner", + "hist2d", + "quantile", + "overplot_lines", + "overplot_spans", + "overplot_points", +] -from corner.core import hist2d, overplot_lines, overplot_points, quantile +from corner.core import ( + hist2d, + overplot_lines, + overplot_points, + overplot_spans, + quantile, +) from corner.corner import corner from corner.version import version as __version__ diff --git a/src/corner/arviz_corner.py b/src/corner/arviz_corner.py index 089e992..b620739 100644 --- a/src/corner/arviz_corner.py +++ b/src/corner/arviz_corner.py @@ -58,6 +58,7 @@ def arviz_corner( title_fmt=".2f", title_kwargs=None, truths=None, + truth_uncertainties=None, truth_color="#4682b4", scale_hist=False, quantiles=None, @@ -68,6 +69,7 @@ def arviz_corner( use_math_text=False, reverse=False, labelpad=0.0, + truth_uncertainties_kwargs=None, hist_kwargs=None, # Arviz parameters group="posterior", @@ -126,6 +128,10 @@ def arviz_corner( truths = np.concatenate( [np.asarray(truths[k]).flatten() for k in var_names] ) + if isinstance(truth_uncertainties, Mapping): + truth_uncertainties = np.concatenate( + [np.asarray(truth_uncertainties[k]).flatten() for k in var_names] + ) if isinstance(titles, Mapping): titles = np.concatenate( [np.asarray(titles[k]).flatten() for k in var_names] @@ -150,6 +156,7 @@ def arviz_corner( title_fmt=title_fmt, title_kwargs=title_kwargs, truths=truths, + truth_uncertainties=truth_uncertainties, truth_color=truth_color, scale_hist=scale_hist, quantiles=quantiles, @@ -160,6 +167,7 @@ def arviz_corner( use_math_text=use_math_text, reverse=reverse, labelpad=labelpad, + truth_uncertainties_kwargs=truth_uncertainties_kwargs, hist_kwargs=hist_kwargs, **hist2d_kwargs, ) diff --git a/src/corner/core.py b/src/corner/core.py index 1d8d5f0..937a04d 100644 --- a/src/corner/core.py +++ b/src/corner/core.py @@ -46,6 +46,7 @@ def corner_impl( title_fmt=".2f", title_kwargs=None, truths=None, + truth_uncertainties=None, truth_color="#4682b4", scale_hist=False, quantiles=None, @@ -57,6 +58,7 @@ def corner_impl( use_math_text=False, reverse=False, labelpad=0.0, + truth_uncertainties_kwargs=None, hist_kwargs=None, **hist2d_kwargs, ): @@ -463,6 +465,24 @@ def corner_impl( color=truth_color, ) + if truth_uncertainties is not None: + lower_list, upper_list = _parse_truth_uncertainties( + truths, truth_uncertainties + ) + if upper_list is not None and lower_list is not None: + if truth_uncertainties_kwargs is None: + # Use default settings. + truth_uncertainties_kwargs = dict( + alpha=0.15, fc=truth_color, ec=truth_color, zorder=0 + ) + overplot_spans( + fig, + lower_list, + upper_list, + reverse=reverse, + **truth_uncertainties_kwargs, + ) + return fig @@ -853,6 +873,67 @@ def overplot_lines(fig, xs, reverse=False, **kwargs): axes[k2, k1].axhline(xs[k2], **kwargs) +def overplot_spans(fig, x1s, x2s, reverse=False, **kwargs): + """ + Overplot spans on a figure generated by ``corner.corner`` + + Parameters + ---------- + fig : Figure + The figure generated by a call to :func:`corner.corner`. + + x1s : array_like[ndim] + The start value of each span that will be plotted. This must have ``ndim`` + entries, where ``ndim`` is compatible with the :func:`corner.corner` + call that originally generated the figure. The entries can optionally + be ``None`` to omit the line in that axis. + + x2s : array_like[ndim] + The end value of each span that will be plotted. This must have ``ndim`` + entries, where ``ndim`` is compatible with the :func:`corner.corner` + call that originally generated the figure. The entries can optionally + be ``None`` to omit the line in that axis. + + reverse: bool + A boolean flag that should be set to 'True' if the corner plot itself + was plotted with 'reverse=True'. + + **kwargs + Any remaining keyword arguments are passed to the ``ax.axhspan`` + method. + + """ + K = len(x1s) + if K != len(x2s): + raise ValueError("`x1s` and `x2s` arrays must be the same length.") + + axes, _ = _get_fig_axes(fig, K) + if reverse: + for k1 in range(K): + if x1s[k1] is not None: + axes[K - k1 - 1, K - k1 - 1].axvspan( + x1s[k1], x2s[k1], **kwargs + ) + for k2 in range(k1 + 1, K): + if x1s[k1] is not None: + axes[K - k2 - 1, K - k1 - 1].axvspan( + x1s[k1], x2s[k1], **kwargs + ) + if x1s[k2] is not None: + axes[K - k2 - 1, K - k1 - 1].axhspan( + x1s[k2], x2s[k2], **kwargs + ) + else: + for k1 in range(K): + if x1s[k1] is not None: + axes[k1, k1].axvspan(x1s[k1], x2s[k1], **kwargs) + for k2 in range(k1 + 1, K): + if x1s[k1] is not None: + axes[k2, k1].axvspan(x1s[k1], x2s[k1], **kwargs) + if x1s[k2] is not None: + axes[k2, k1].axhspan(x1s[k2], x2s[k2], **kwargs) + + def overplot_points(fig, xs, reverse=False, **kwargs): """ Overplot points on a figure generated by ``corner.corner`` @@ -892,6 +973,41 @@ def overplot_points(fig, xs, reverse=False, **kwargs): axes[k2, k1].plot(xs[k1], xs[k2], **kwargs) +def _parse_truth_uncertainties(truths, truth_uncertainties): + + if truth_uncertainties is None or truths is None: + return None, None + + lowers = list() + uppers = list() + for i, current_uncert in enumerate(truth_uncertainties): + lower_uncert = None + upper_uncert = None + if current_uncert is None or truths[i] is None: + # Skip + lower_uncert = None + upper_uncert = None + elif isinstance(current_uncert, (float, np.floating)): + # Single uncertainty provided. + lower_uncert = truths[i] - current_uncert + upper_uncert = truths[i] + current_uncert + elif len(current_uncert) == 1: + # Still a single uncertainty provided but its a in a iterable. + lower_uncert = truths[i] - current_uncert[0] + upper_uncert = truths[i] + current_uncert[0] + elif len(current_uncert) == 2: + lower_uncert = truths[i] - current_uncert[0] + upper_uncert = truths[i] + current_uncert[1] + else: + raise ValueError( + f"Unexpected number of truth uncertainties provided at index {i}." + ) + lowers.append(lower_uncert) + uppers.append(upper_uncert) + + return lowers, uppers + + def _parse_input(xs): xs = np.atleast_1d(xs) if len(xs.shape) == 1: diff --git a/src/corner/corner.py b/src/corner/corner.py index e4f54ad..01f06a1 100644 --- a/src/corner/corner.py +++ b/src/corner/corner.py @@ -34,6 +34,7 @@ def corner( title_fmt=".2f", title_kwargs=None, truths=None, + truth_uncertainties=None, truth_color="#4682b4", scale_hist=False, quantiles=None, @@ -44,6 +45,7 @@ def corner( use_math_text=False, reverse=False, labelpad=0.0, + truth_uncertainties_kwargs=None, hist_kwargs=None, # Arviz parameters group="posterior", @@ -171,6 +173,13 @@ def corner( A list of reference values to indicate on the plots. Individual values can be omitted by using ``None``. + truth_uncertainties : iterable (ndim, udim = 1 or 2) + A list of uncertainties corresponding to `truths`. + If udim is 1 then that uncertainty will be used for both the + lower and upper bounds. If udim is 2 then the first value will be used + as the lower bound and the second as the upper. Individual + values can be omitted by using ``None``. + truth_color : str A ``matplotlib`` style color for the ``truths`` makers. @@ -211,6 +220,10 @@ def corner( axes yet, or ``ndim * ndim`` axes already present. If not set, the plot will be drawn on a newly created figure. + truth_uncertainties_kwargs : dict + Any extra keyword arguments to send to the axvspan used to create truth + uncertainty bands. + hist_kwargs : dict Any extra keyword arguments to send to the 1-D histogram plots. @@ -263,6 +276,7 @@ def corner( title_fmt=title_fmt, title_kwargs=title_kwargs, truths=truths, + truth_uncertainties=truth_uncertainties, truth_color=truth_color, scale_hist=scale_hist, quantiles=quantiles, @@ -273,6 +287,7 @@ def corner( use_math_text=use_math_text, reverse=reverse, labelpad=labelpad, + truth_uncertainties_kwargs=truth_uncertainties_kwargs, hist_kwargs=hist_kwargs, **hist2d_kwargs, ) @@ -295,6 +310,7 @@ def corner( title_fmt=title_fmt, title_kwargs=title_kwargs, truths=truths, + truth_uncertainties=truth_uncertainties, truth_color=truth_color, scale_hist=scale_hist, quantiles=quantiles, @@ -305,6 +321,7 @@ def corner( use_math_text=use_math_text, reverse=reverse, labelpad=labelpad, + truth_uncertainties_kwargs=truth_uncertainties_kwargs, hist_kwargs=hist_kwargs, group=group, var_names=var_names,