# Source code for bambi.priors.scaler

import numpy as np

from bambi.families.univariate import Gaussian, StudentT, VonMises

from .prior import Prior

class PriorScaler:
    """Scale prior distribution parameters for the terms of a model.

    Automatic priors follow a weakly-informative heuristic: Normal priors on
    intercept/slopes are centered and scaled from the observed response and
    predictors, and auxiliary parameters (e.g. the residual ``sigma``) get a
    HalfStudentT prior scaled by the response standard deviation.

    Parameters
    ----------
    model : bambi.Model
        The model whose priors are to be scaled. Assumed to expose
        ``family``, ``response``, ``terms``, ``common_terms``,
        ``intercept_term`` and ``group_specific_terms``.
    """

    # Standard deviation multiplier for all automatically scaled priors.
    STD = 2.5

    def __init__(self, model):
        self.model = model
        self.has_intercept = model.intercept_term is not None
        # Filled by scale_common(); maps term name -> {"mu": ..., "sigma": ...}
        # so the intercept prior can account for already-scaled slopes.
        self.priors = {}

        # Compute mean and std of the response. Only families with a location
        # response on the observed scale use the empirical moments; for other
        # families fall back to a standard (0, 1) reference.
        if isinstance(self.model.family, (Gaussian, StudentT)):
            self.response_mean = np.mean(model.response.data)
            self.response_std = np.std(model.response.data)
        else:
            self.response_mean = 0
            self.response_std = 1

    def get_intercept_stats(self):
        """Return ``(mu, sigma)`` for the intercept's Normal prior.

        If common-term priors have already been scaled, widen sigma so the
        intercept prior covers the shift induced by centered predictors:
        ``sigma = sqrt(sigma**2 + sum(slope_sigma**2 * x_mean**2))``.
        """
        mu = self.response_mean
        sigma = self.STD * self.response_std

        # Only adjust mu and sigma if there is at least one Normal prior for a common term.
        if self.priors:
            sigmas = np.hstack([prior["sigma"] for prior in self.priors.values()])
            x_mean = np.hstack([self.model.terms[term].data.mean(axis=0) for term in self.priors])
            sigma = (sigma**2 + np.dot(sigmas**2, x_mean**2)) ** 0.5

        return mu, sigma

    def get_slope_sigma(self, x):
        """Return the prior sigma for a slope on predictor ``x``.

        Scales with the response/predictor standard-deviation ratio, so the
        prior is invariant to the units of ``x``.
        """
        return self.STD * (self.response_std / np.std(x))

    def scale_response(self):
        """Scale priors of the likelihood's auxiliary parameters."""
        # TODO: add cases for other families.
        priors = self.model.family.likelihood.priors
        if isinstance(self.model.family, (Gaussian, StudentT)):
            if priors["sigma"].auto_scale:
                priors["sigma"] = Prior("HalfStudentT", nu=4, sigma=self.response_std)
        elif isinstance(self.model.family, VonMises):
            if priors["kappa"].auto_scale:
                priors["kappa"] = Prior("HalfStudentT", nu=4, sigma=self.response_std)

    def scale_intercept(self, term):
        """Scale the prior of the intercept term (Normal priors only)."""
        if term.prior.name != "Normal":
            return
        mu, sigma = self.get_intercept_stats()
        term.prior.update(mu=mu, sigma=sigma)

    def scale_common(self, term):
        """Scale the prior of a common (fixed-effect) term (Normal priors only)."""
        if term.prior.name != "Normal":
            return

        # As many zeros as columns in the data.
        # It can be greater than 1 for categorical variables.
        mu = np.zeros(term.data.shape[1])
        sigma = np.zeros(term.data.shape[1])

        # Iterate over columns in the data
        for i, x in enumerate(term.data.T):
            sigma[i] = self.get_slope_sigma(x)

        # Save (for get_intercept_stats) and set the prior.
        self.priors.update({term.name: {"mu": mu, "sigma": sigma}})
        term.prior.update(mu=mu, sigma=sigma)

    def scale_group_specific(self, term):
        """Scale the hyperprior sigma of a group-specific (random-effect) term."""
        if term.prior.args["sigma"].name != "HalfNormal":
            return

        # Handle intercepts
        if term.kind == "intercept":
            _, sigma = self.get_intercept_stats()
        # Handle slopes
        else:
            # Recreate the corresponding common effect data
            if len(term.predictor.shape) == 2:
                data_as_common = term.predictor
            else:
                data_as_common = term.predictor[:, None]
            sigma = np.zeros(data_as_common.shape[1])
            for i, x in enumerate(data_as_common.T):
                sigma[i] = self.get_slope_sigma(x)

        term.prior.args["sigma"].update(sigma=np.squeeze(np.atleast_1d(sigma)))

    def scale(self):
        """Scale the priors of all terms in the model that opted into auto-scaling."""
        # Scale response (auxiliary likelihood parameters).
        self.scale_response()

        # Scale common terms first, so the intercept can account for them.
        for term in self.model.common_terms.values():
            if term.prior.auto_scale:
                self.scale_common(term)

        # Scale intercept
        if self.has_intercept:
            term = self.model.intercept_term
            if term.prior.auto_scale:
                self.scale_intercept(term)

        # Scale group-specific terms
        for term in self.model.group_specific_terms.values():
            if term.prior.auto_scale:
                self.scale_group_specific(term)