Source code for bambi.interpret.plotting

from typing import Union

import arviz as az
from arviz.plots.backends.matplotlib import create_axes_grid
from arviz.plots.plot_utils import default_grid
import numpy as np
import pandas as pd
from pandas.api.types import is_categorical_dtype, is_numeric_dtype, is_string_dtype

from bambi.models import Model
from bambi.interpret.effects import comparisons, slopes, predictions
from bambi.interpret.plot_types import plot_categoric, plot_numeric
from bambi.interpret.utils import get_covariates, ConditionalInfo
from bambi.utils import get_aliased_name, listify


def _plot_differences(
    model: Model,
    conditional_info: ConditionalInfo,
    summary_df: pd.DataFrame,
    average_by: Union[str, list] = None,
    transforms=None,
    legend: bool = True,
    ax=None,
    fig_kwargs=None,
    subplot_kwargs=None,
):
    """
    Common function used for both 'plot_comparisons' and 'plot_slopes'.
    """

    if (subplot_kwargs and not average_by) or (subplot_kwargs and average_by):
        for key, value in subplot_kwargs.items():
            conditional_info.covariates.update({key: value})
        covariates = get_covariates(conditional_info.covariates)
    elif average_by and not subplot_kwargs:
        if not isinstance(average_by, list):
            average_by = listify(average_by)
        covariate_kinds = ("main", "group", "panel")
        average_by = dict(zip(covariate_kinds, average_by))
        covariates = get_covariates(average_by)
    else:
        covariates = get_covariates(conditional_info.covariates)

    if transforms is None:
        transforms = {}

    response_name = get_aliased_name(model.response_component.response_term)

    if ax is None:
        fig_kwargs = {} if fig_kwargs is None else fig_kwargs
        panels_n = len(np.unique(summary_df[covariates.panel])) if covariates.panel else 1
        rows, cols = default_grid(panels_n)
        fig, axes = create_axes_grid(panels_n, rows, cols, backend_kwargs=fig_kwargs)
        axes = np.atleast_1d(axes)
    else:
        axes = np.atleast_1d(ax)
        if isinstance(axes[0], np.ndarray):
            fig = axes[0][0].get_figure()
        else:
            fig = axes[0].get_figure()

    if is_numeric_dtype(summary_df[covariates.main]):
        # main condition variable can be numeric but at the same time only
        # a few values, so it is treated as categoric
        if np.unique(summary_df[covariates.main]).shape[0] <= 5:
            axes = plot_categoric(covariates, summary_df, legend, axes)
        else:
            axes = plot_numeric(covariates, summary_df, transforms, legend, axes)
    elif is_categorical_dtype(summary_df[covariates.main]) or is_string_dtype(
        summary_df[covariates.main]
    ):
        axes = plot_categoric(covariates, summary_df, legend, axes)
    else:
        raise TypeError("Main covariate must be numeric or categoric.")

    response_name = get_aliased_name(model.response_component.response_term)
    for ax in axes.ravel():  # pylint: disable = redefined-argument-from-local
        ax.set(xlabel=covariates.main, ylabel=response_name)

    return fig, axes


[docs] def plot_predictions( model: Model, idata: az.InferenceData, covariates: Union[str, list], target: str = "mean", pps: bool = False, use_hdi: bool = True, prob=None, transforms=None, legend: bool = True, ax=None, fig_kwargs=None, subplot_kwargs=None, ): """Plot Conditional Adjusted Predictions Parameters ---------- model : bambi.Model The model for which we want to plot the predictions. idata : arviz.InferenceData The InferenceData object that contains the samples from the posterior distribution of the model. covariates : list or dict A sequence of between one and three names of variables in the model. target : str Which model parameter to plot. Defaults to 'mean'. Passing a parameter into target only works when pps is False as the target may not be available in the posterior predictive distribution. pps: bool, optional Whether to plot the posterior predictive samples. Defaults to ``False``. use_hdi : bool, optional Whether to compute the highest density interval (defaults to True) or the quantiles. prob : float, optional The probability for the credibility intervals. Must be between 0 and 1. Defaults to 0.94. Changing the global variable ``az.rcParam["stats.hdi_prob"]`` affects this default. legend : bool, optional Whether to automatically include a legend in the plot. Defaults to ``True``. transforms : dict, optional Transformations that are applied to each of the variables being plotted. The keys are the name of the variables, and the values are functions to be applied. Defaults to ``None``. ax : matplotlib.axes._subplots.AxesSubplot, optional A matplotlib axes object or a sequence of them. If None, this function instantiates a new axes object. Defaults to ``None``. fig_kwargs : optional Keyword arguments passed to the matplotlib figure function as a dict. For example, ``fig_kwargs=dict(figsize=(11, 8)), sharey=True`` would make the figure 11 inches wide by 8 inches high and would share the y-axis values. subplot_kwargs : optional Keyword arguments used to determine the covariates used for the horizontal, group, and panel axes. For example, ``subplot_kwargs=dict(main="x", group="y", panel="z")`` would plot the horizontal axis as ``x``, the color (hue) as ``y``, and the panel axis as ``z``. Returns ------- matplotlib.figure.Figure, matplotlib.axes._subplots.AxesSubplot A tuple with the figure and the axes. Raises ------ ValueError When ``level`` is not within 0 and 1. When the main covariate is not numeric or categoric. TypeError When ``covariates`` is not a string or a list of strings. """ covariate_kinds = ("main", "group", "panel") if isinstance(covariates, dict): raise TypeError("covariates must be a string or a list of strings.") if not isinstance(covariates, dict): covariates = listify(covariates) covariates = dict(zip(covariate_kinds, covariates)) else: assert covariate_kinds[0] in covariates assert set(covariates).issubset(set(covariate_kinds)) assert 1 <= len(covariates) <= 3 if transforms is None: transforms = {} cap_data = predictions( model, idata, covariates, target=target, pps=pps, use_hdi=use_hdi, prob=prob, transforms=transforms, ) response_name = get_aliased_name(model.response_component.response_term) covariates = get_covariates(covariates) if subplot_kwargs: for key, value in subplot_kwargs.items(): setattr(covariates, key, value) if ax is None: fig_kwargs = {} if fig_kwargs is None else fig_kwargs panels_n = len(np.unique(cap_data[covariates.panel])) if covariates.panel else 1 rows, cols = default_grid(panels_n) fig, axes = create_axes_grid(panels_n, rows, cols, backend_kwargs=fig_kwargs) axes = np.atleast_1d(axes) else: axes = np.atleast_1d(ax) if isinstance(axes[0], np.ndarray): fig = axes[0][0].get_figure() else: fig = axes[0].get_figure() if is_numeric_dtype(cap_data[covariates.main]): axes = plot_numeric(covariates, cap_data, transforms, legend, axes) elif is_categorical_dtype(cap_data[covariates.main]) or is_string_dtype( cap_data[covariates.main] ): axes = plot_categoric(covariates, cap_data, legend, axes) else: raise ValueError("Main covariate must be numeric or categoric.") ylabel = response_name if target == "mean" else target for ax in axes.ravel(): # pylint: disable = redefined-argument-from-local ax.set(xlabel=covariates.main, ylabel=ylabel) return fig, axes
[docs] def plot_comparisons( model: Model, idata: az.InferenceData, contrast: Union[str, dict, list], conditional: Union[str, dict, list, None] = None, average_by: Union[str, list] = None, comparison_type: str = "diff", use_hdi: bool = True, prob=None, legend: bool = True, transforms=None, ax=None, fig_kwargs=None, subplot_kwargs=None, ): """Plot Conditional Adjusted Comparisons Parameters ---------- model : bambi.Model The model for which we want to plot the predictions. idata : arviz.InferenceData The InferenceData object that contains the samples from the posterior distribution of the model. contrast : str, dict, list The predictor name whose contrast we would like to compare. conditional : str, dict, list The covariates we would like to condition on. average_by: str, list, optional The covariates we would like to average by. The passed covariate(s) will marginalize over the other covariates in the model. Defaults to ``None``. comparison_type : str, optional The type of comparison to plot. Defaults to 'diff'. use_hdi : bool, optional Whether to compute the highest density interval (defaults to True) or the quantiles. prob : float, optional The probability for the credibility intervals. Must be between 0 and 1. Defaults to 0.94. Changing the global variable ``az.rcParam["stats.hdi_prob"]`` affects this default. legend : bool, optional Whether to automatically include a legend in the plot. Defaults to ``True``. transforms : dict, optional Transformations that are applied to each of the variables being plotted. The keys are the name of the variables, and the values are functions to be applied. Defaults to ``None``. ax : matplotlib.axes._subplots.AxesSubplot, optional A matplotlib axes object or a sequence of them. If None, this function instantiates a new axes object. Defaults to ``None``. fig_kwargs : optional Keyword arguments passed to the matplotlib figure function as a dict. For example, ``fig_kwargs=dict(figsize=(11, 8)), sharey=True`` would make the figure 11 inches wide by 8 inches high and would share the y-axis values. subplot_kwargs : optional Keyword arguments used to determine the covariates used for the horizontal, group, and panel axes. For example, ``subplot_kwargs=dict(main="x", group="y", panel="z")`` would plot the horizontal axis as ``x``, the color (hue) as ``y``, and the panel axis as ``z``. Returns ------- matplotlib.figure.Figure, matplotlib.axes._subplots.AxesSubplot A tuple with the figure and the axes. Raises ------ ValueError If ``conditional`` and ``average_by`` are both ``None``. If length of ``conditional`` is greater than 3 and ``average_by`` is ``None``. Warning If length of ``contrast`` is greater than 2. """ contrast_name = contrast if isinstance(contrast, dict): contrast_name, contrast_levels = next(iter(contrast.items())) if len(contrast_levels) > 2 and average_by is None: raise ValueError( "When plotting with more than 2 values for 'contrast', you must " "pass a covariate to 'average_by'. " f"{contrast_name} has {len(contrast_levels)} values." ) if not isinstance(contrast, dict): if is_categorical_dtype(model.data[contrast_name]) or is_string_dtype( model.data[contrast_name] ): contrast_levels = len(model.data[contrast_name].unique()) if contrast_levels > 2 and average_by is None: raise ValueError( "When plotting with more than 2 values for 'contrast', you must " f"pass a covariate to 'average_by'. {contrast_name} has " f"{contrast_levels} values." ) if conditional is None and average_by is None: raise ValueError("Must specify at least one of 'conditional' or 'average_by'.") if conditional is not None: if not isinstance(conditional, str): if len(conditional) > 3 and average_by is None: raise ValueError( "Must specify a covariate to 'average_by' when number of covariates" "passed to 'conditional' is greater than 3." ) if average_by is True: raise ValueError( "Plotting when 'average_by = True' is not possible as 'True' marginalizes " "over all covariates resulting in a single comparison estimate. " "Please specify a covariate(s) to 'average_by'." ) conditional_info = ConditionalInfo(model, conditional) contrast_summary = comparisons( model=model, idata=idata, contrast=contrast, conditional=conditional, average_by=average_by, comparison_type=comparison_type, use_hdi=use_hdi, prob=prob, transforms=transforms, ) return _plot_differences( model=model, conditional_info=conditional_info, summary_df=contrast_summary, average_by=average_by, transforms=transforms, legend=legend, ax=ax, fig_kwargs=fig_kwargs, subplot_kwargs=subplot_kwargs, )
[docs] def plot_slopes( model: Model, idata: az.InferenceData, wrt: Union[str, dict], conditional: Union[str, dict, list, None] = None, average_by: Union[str, list] = None, eps: float = 1e-4, slope: str = "dydx", use_hdi: bool = True, prob=None, transforms=None, legend: bool = True, ax=None, fig_kwargs=None, subplot_kwargs=None, ): """Plot Conditional Adjusted Slopes Parameters ---------- model : bambi.Model The model for which we want to plot the predictions. idata : arviz.InferenceData The InferenceData object that contains the samples from the posterior distribution of the model. wrt : str, dict The slope of the regression with respect to (wrt) this predictor will be computed. If 'wrt' is numeric, the derivative is computed, else if string or categorical, 'comparisons' is called to compute difference in group means. conditional : str, dict, list The covariates we would like to condition on. average_by: str, list, bool, optional The covariates we would like to average by. The passed covariate(s) will marginalize over the other covariates in the model. If True, it averages over all covariates in the model to obtain the average estimate. Defaults to ``None``. eps : float, optional To compute the slope, 'wrt' is evaluated at wrt +/- 'eps'. The rate of change is then computed as the difference between the two values divided by 'eps'. Defaults to 1e-4. slope: str, optional The type of slope to compute. Defaults to 'dydx'. 'dydx' represents a unit increase in 'wrt' is associated with an n-unit change in the response. 'eyex' represents a percentage increase in 'wrt' is associated with an n-percent change in the response. 'eydx' represents a unit increase in 'wrt' is associated with an n-percent change in the response. 'dyex' represents a percent change in 'wrt' is associated with a unit increase in the response. use_hdi : bool, optional Whether to compute the highest density interval (defaults to True) or the quantiles. prob : float, optional The probability for the credibility intervals. Must be between 0 and 1. Defaults to 0.94. Changing the global variable ``az.rcParam["stats.hdi_prob"]`` affects this default. transforms : dict, optional Transformations that are applied to each of the variables being plotted. The keys are the name of the variables, and the values are functions to be applied. Defaults to ``None``. legend : bool, optional Whether to automatically include a legend in the plot. Defaults to ``True``. ax : matplotlib.axes._subplots.AxesSubplot, optional A matplotlib axes object or a sequence of them. If None, this function instantiates a new axes object. Defaults to ``None``. fig_kwargs : optional Keyword arguments passed to the matplotlib figure function as a dict. For example, ``fig_kwargs=dict(figsize=(11, 8)), sharey=True`` would make the figure 11 inches wide by 8 inches high and would share the y-axis values. subplot_kwargs : optional Keyword arguments used to determine the covariates used for the horizontal, group, and panel axes. For example, ``subplot_kwargs=dict(main="x", group="y", panel="z")`` would plot the horizontal axis as ``x``, the color (hue) as ``y``, and the panel axis as ``z``. Returns ------- matplotlib.figure.Figure, matplotlib.axes._subplots.AxesSubplot A tuple with the figure and the axes. Raises ------ ValueError If number of values passed with ``conditional`` is >= 2 and ``average_by`` are both ``None``. If ``conditional`` and ``average_by`` are both ``None``. If length of ``conditional`` is greater than 3 and ``average_by`` is ``None``. If ``slope`` is not one of ('dydx', 'dyex', 'eyex', 'eydx'). """ wrt_name = wrt if isinstance(wrt, dict): wrt_name, wrt_value = next(iter(wrt.items())) if not isinstance(wrt_value, (list, np.ndarray)): wrt_value = [wrt_value] if len(wrt_value) > 2 and average_by is None: raise ValueError( "When plotting with more than 2 values for 'wrt', you must " "pass a covariate to 'average_by'" ) if not isinstance(wrt, dict): if is_categorical_dtype(model.data[wrt_name]) or is_string_dtype(model.data[wrt_name]): num_values = len(model.data[wrt_name].unique()) if num_values > 2 and average_by is None: raise ValueError( "When plotting with more than 2 values for 'wrt', you must " f"pass a covariate to 'average_by'. {wrt_name} has {num_values} values." ) if conditional is None and average_by is None: raise ValueError("Must specify at least one of 'conditional' or 'average_by'.") if conditional is not None: if not isinstance(conditional, str): if len(conditional) > 3 and average_by is None: raise ValueError( "Must specify a covariate to 'average_by' when number of covariates" "passed to 'conditional' is greater than 3." ) if average_by is True: raise ValueError( "Plotting when 'average_by = True' is not possible as 'True' marginalizes " "over all covariates resulting in a single slope estimate. " "Please specify a covariate(s) to 'average_by'." ) if slope not in ("dydx", "dyex", "eyex", "eydx"): raise ValueError("'slope' must be one of ('dydx', 'dyex', 'eyex', 'eydx')") conditional_info = ConditionalInfo(model, conditional) slopes_summary = slopes( model=model, idata=idata, wrt=wrt, conditional=conditional, average_by=average_by, eps=eps, slope=slope, use_hdi=use_hdi, prob=prob, transforms=transforms, ) return _plot_differences( model=model, conditional_info=conditional_info, summary_df=slopes_summary, average_by=average_by, transforms=transforms, legend=legend, ax=ax, fig_kwargs=fig_kwargs, subplot_kwargs=subplot_kwargs, )