interpret.plot_predictions

interpret.plot_predictions(model, idata, conditional=None, average_by=None, target='mean', sample_new_groups=False, pps=False, use_hdi=True, prob=None, transforms=None, legend=True, ax=None, fig_kwargs=None, subplot_kwargs=None)

Plot Conditional Adjusted Predictions
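Conditional adjusted predictions are model predictions evaluated on a grid of covariate values. The covariates named in `conditional` take the values you supply, and every other covariate is held fixed. The plain-Python sketch below builds such a grid. It assumes that unlisted covariates are held at their mean (numeric) or mode (categorical). `make_grid` is a hypothetical helper, not part of bambi, and the sketch stops before the step where the model would be evaluated on each row.

```python
import itertools
import statistics


def make_grid(data, conditional):
    """Hypothetical helper: build the rows at which predictions are evaluated.

    data: dict mapping covariate name -> list of observed values.
    conditional: dict mapping covariate name -> values to condition on.
    Covariates missing from `conditional` are held at one typical value:
    the mean if numeric, the mode otherwise (an assumed convention).
    """
    fixed = {}
    for name, values in data.items():
        if name in conditional:
            continue
        if all(isinstance(v, (int, float)) for v in values):
            fixed[name] = statistics.mean(values)
        else:
            fixed[name] = statistics.mode(values)

    names = list(conditional)
    grid = []
    # One row per combination of the conditioning values (Cartesian product).
    for combo in itertools.product(*(conditional[n] for n in names)):
        row = dict(zip(names, combo))
        row.update(fixed)
        grid.append(row)
    return grid


# Toy covariate values, made up for illustration.
data = {
    "hp": [110, 93, 175, 105, 245],
    "wt": [2.62, 2.32, 3.44, 3.46, 3.57],
    "disp": [160, 108, 258, 360, 225],
    "cyl": ["6", "4", "8", "8", "8"],
}
grid = make_grid(data, {"hp": [100, 200], "wt": [2.5, 3.5]})
for row in grid:
    print(row)
```

Here `disp` is held at its mean and `cyl` at its most frequent level, so the four rows differ only in `hp` and `wt`.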

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| model | bambi.Model | The model for which we want to plot the predictions. | required |
| idata | arviz.InferenceData | The InferenceData object that contains the samples from the posterior distribution of the model. | required |
| conditional | (str, list, dict) | The covariates we would like to condition on. If a dict, the keys are the covariate names and the values are the values to condition on. | None |
| average_by | Union[str, list, bool, None] | The covariates we would like to average by. Predictions are averaged (marginalized) over the other covariates in the model, separately for each value of the passed covariate(s). If True, predictions are averaged over all covariates in the model to obtain a single average estimate. Defaults to None. | None |
| target | str | Which model parameter to plot. Defaults to 'mean'. Passing a parameter to target only works when pps is False, because that parameter may not be available in the posterior predictive distribution. | 'mean' |
| sample_new_groups | bool | If the model contains group-level effects and data is passed for unseen groups, whether to sample from the new groups. Defaults to False. | False |
| pps | bool | Whether to plot the posterior predictive samples. Defaults to False. | False |
| use_hdi | bool | Whether to compute the highest density interval (the default) or an interval based on quantiles. | True |
| prob | float | The probability for the credible intervals. Must be between 0 and 1. Defaults to 0.94; changing the global ArviZ setting `az.rcParams["stats.hdi_prob"]` changes this default. | None |
| transforms | dict | Transformations applied to each of the variables being plotted. The keys are the variable names and the values are the functions to apply. Defaults to None. | None |
| legend | bool | Whether to automatically include a legend in the plot. Defaults to True. | True |
| ax | matplotlib.axes._subplots.AxesSubplot | A matplotlib axes object or a sequence of them. If None, this function creates a new axes object. Defaults to None. | None |
| fig_kwargs | dict, optional | Keyword arguments, passed as a dict, to the matplotlib figure function. For example, `fig_kwargs=dict(figsize=(11, 8), sharey=True)` makes the figure 11 inches wide by 8 inches high and shares the y-axis across subplots. | None |
| subplot_kwargs | dict, optional | Keyword arguments that determine which covariates are used for the horizontal, group, and panel axes. For example, `subplot_kwargs=dict(main="x", group="y", panel="z")` plots x on the horizontal axis, y as the color (hue), and z as the panel axis. | None |
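To see what `use_hdi` and `prob` control, the sketch below computes both kinds of interval at the same `prob` from one set of draws. It is plain Python, not bambi's or ArviZ's implementation: `hdi` and `quantile_interval` are hypothetical names, and the quantiles use a simple nearest-rank rule. For a right-skewed posterior the HDI hugs the mode and comes out narrower than the equal-tailed interval; for a symmetric, unimodal posterior the two nearly coincide.

```python
import math


def hdi(samples, prob=0.94):
    """Highest density interval: the narrowest window holding `prob` of the samples."""
    xs = sorted(samples)
    n = len(xs)
    k = math.ceil(prob * n)  # number of samples the window must contain
    # Slide a window of k consecutive sorted samples and keep the narrowest one.
    start = min(range(n - k + 1), key=lambda i: xs[i + k - 1] - xs[i])
    return xs[start], xs[start + k - 1]


def quantile_interval(samples, prob=0.94):
    """Equal-tailed interval: drop (1 - prob) / 2 of the samples from each tail."""
    xs = sorted(samples)
    n = len(xs)
    tail = (1 - prob) / 2
    # Nearest-rank approximation of the lower and upper quantiles.
    lo = xs[round(tail * (n - 1))]
    hi = xs[round((1 - tail) * (n - 1))]
    return lo, hi


# Deterministic stand-in for a right-skewed posterior: Exponential(1) quantiles.
samples = [-math.log(1 - (i + 0.5) / 1000) for i in range(1000)]
print("HDI:      ", hdi(samples))
print("quantiles:", quantile_interval(samples))
```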
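The averaging behind `average_by` can be pictured as follows. Predictions are first obtained for a set of rows (for example the observed data), then pooled within each level of the chosen covariate(s), which marginalizes over everything else; `True` pools all rows into one estimate. This is a sketch of the idea, not bambi's code: `average_predictions` is hypothetical and works on point predictions, whereas a full treatment would average each posterior draw separately.

```python
from collections import defaultdict
from statistics import mean


def average_predictions(rows, predictions, by):
    """Hypothetical sketch of averaging predictions by covariate(s).

    rows: one dict of covariate values per observation.
    predictions: one predicted value per row (same order as rows).
    by: a covariate name, a list of names, or True for a single overall average.
    """
    if by is True:
        return {(): mean(predictions)}
    names = [by] if isinstance(by, str) else list(by)
    groups = defaultdict(list)
    for row, pred in zip(rows, predictions):
        # Rows sharing the same values of the `by` covariates are pooled together.
        groups[tuple(row[n] for n in names)].append(pred)
    return {key: mean(vals) for key, vals in groups.items()}


# Toy rows and predictions, made up for illustration.
rows = [
    {"cyl": "4", "hp": 93},
    {"cyl": "4", "hp": 66},
    {"cyl": "8", "hp": 175},
    {"cyl": "8", "hp": 245},
]
preds = [26.0, 30.0, 17.0, 13.0]
print(average_predictions(rows, preds, "cyl"))  # {('4',): 28.0, ('8',): 15.0}
print(average_predictions(rows, preds, True))   # {(): 21.5}
```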

Returns

| Type | Description |
|---|---|
| (matplotlib.figure.Figure, matplotlib.axes._subplots.AxesSubplot) | A tuple with the figure and the axes. |

Raises

| Type | Description |
|---|---|
| ValueError | If both conditional and average_by are None. |
| ValueError | If average_by is None and conditional contains more than 3 covariates. |
| ValueError | If the main covariate is neither numeric nor categorical. |
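The three ValueError conditions can be restated as a small check. `check_arguments` is a hypothetical function with made-up error messages, written only to make the rules concrete; it is not bambi's validation code, and its numeric/categorical test (all numbers or all strings) is a simplification assumed here.

```python
def check_arguments(conditional, average_by, main_values=None):
    """Hypothetical restatement of the documented ValueError conditions.

    conditional: None, a covariate name, a list of names, or a dict.
    average_by: None, a covariate name, a list of names, or True.
    main_values: optional observed values of the main (horizontal-axis) covariate.
    """
    if conditional is None and average_by is None:
        raise ValueError("Pass at least one of 'conditional' or 'average_by'.")
    if conditional is not None and average_by is None:
        # A single string names one covariate; lists and dicts hold several.
        n_covariates = 1 if isinstance(conditional, str) else len(conditional)
        if n_covariates > 3:
            raise ValueError("At most 3 conditional covariates can be plotted.")
    if main_values is not None:
        numeric = all(isinstance(v, (int, float)) for v in main_values)
        categorical = all(isinstance(v, str) for v in main_values)
        if not (numeric or categorical):
            raise ValueError("The main covariate must be numeric or categorical.")


check_arguments("hp", None, main_values=[110, 93, 175])  # passes silently
try:
    check_arguments(None, None)
except ValueError as err:
    print(err)
```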