# Source code for bambi.families.link

from collections import namedtuple

import numpy as np

from scipy import special

from bambi.utils import multilinify, indentify


def force_within_unit_interval(x):
    """Make sure data in unit interval is in (0, 1).

    Values exactly equal to 0 are replaced by machine epsilon and values
    exactly equal to 1 are replaced by ``1 - eps``. Every other value passes
    through unchanged.

    Parameters
    ----------
    x : scalar or array-like
        Values expected to lie in the closed interval [0, 1].

    Returns
    -------
    A new object of the same kind as ``x``. ``x`` itself is never modified.
    """
    eps = np.finfo(float).eps
    # Arithmetic instead of masked in-place assignment, so that:
    #  * the caller's data is never mutated (e.g. ``logit(mu)`` used to
    #    overwrite ``mu`` in place),
    #  * scalars work (item assignment on float / np.float64 raises TypeError),
    #  * array-likes whose arithmetic preserves their type keep that type.
    # Adding/subtracting 0.0 leaves every other value exactly as it was.
    return x + eps * (x == 0) - eps * (x == 1)


def force_greater_than_zero(x):
    """Make sure data in positive reals is in (0, infty).

    Values exactly equal to 0 are replaced by machine epsilon. Every other
    value passes through unchanged.

    Parameters
    ----------
    x : scalar or array-like
        Values expected to be non-negative.

    Returns
    -------
    A new object of the same kind as ``x``. ``x`` itself is never modified.
    """
    eps = np.finfo(float).eps
    # Arithmetic instead of masked in-place assignment: the caller's data is
    # not mutated and scalar inputs work. Adding 0.0 leaves non-zero values
    # exactly as they were.
    return x + eps * (x == 0)


def identity(x):
    """Return the argument untouched; serves as both the identity link and its inverse."""
    unchanged = x
    return unchanged


def cloglog(mu):
    """Complementary log-log link, ``log(-log(1 - mu))``.

    The cloglog link is only finite on the open interval (0, 1). ``mu == 0``
    gives ``-inf`` and ``mu == 1`` gives ``+inf``, so the input is first
    nudged into (0, 1), as in ``logit`` and ``probit``.

    Parameters
    ----------
    mu : array-like
        Values expected to lie in [0, 1].

    Returns
    -------
    The link evaluated element-wise.
    """
    # Previously only ``mu > 0`` was enforced, leaving ``mu == 1`` -> +inf.
    mu = force_within_unit_interval(mu)
    # log1p(-mu) == log(1 - mu) but keeps precision for tiny mu, where
    # ``1 - mu`` would round to exactly 1 and the result to -inf.
    return np.log(-np.log1p(-mu))


def invcloglog(eta):
    """Inverse of the cloglog function that ensures result is in (0, 1).

    Evaluates ``1 - exp(-exp(eta))`` as ``-expm1(-exp(eta))``. Mathematically
    this is the same value, but it avoids catastrophic cancellation when
    ``exp(eta)`` is tiny: at ``eta = -40`` the naive form returns exactly 0.

    Parameters
    ----------
    eta : array-like
        Linear predictor values.

    Returns
    -------
    Probabilities; exact 0/1 results are nudged into (0, 1).
    """
    result = -np.expm1(-np.exp(eta))
    return force_within_unit_interval(result)


def probit(mu):
    """Probit function that ensures the input is in (0, 1).

    Computes the inverse standard-normal CDF directly with ``special.ndtri``.
    The previous ``sqrt(2) * erfinv(2 * mu - 1)`` form rounds ``2 * mu - 1``
    to exactly -1 for ``mu`` below about 1e-16 and so returned ``-inf``.

    Parameters
    ----------
    mu : array-like
        Probabilities expected to lie in [0, 1].

    Returns
    -------
    Standard-normal quantiles of ``mu``.
    """
    mu = force_within_unit_interval(mu)
    return special.ndtri(mu)  # pylint: disable=no-member


def invprobit(eta):
    """Inverse of the probit function that ensures result is in (0, 1).

    Uses the standard-normal CDF ``special.ndtr``. The previous
    ``0.5 + 0.5 * erf(eta / sqrt(2))`` form cancels to exactly 0 in the lower
    tail (around ``eta < -8.3``), while ``ndtr`` stays accurate there.

    Parameters
    ----------
    eta : array-like
        Linear predictor values.

    Returns
    -------
    Probabilities; exact 0/1 results are nudged into (0, 1).
    """
    result = special.ndtr(eta)  # pylint: disable=no-member
    return force_within_unit_interval(result)


def expit(eta):
    """Logistic (inverse-logit) function whose output is kept strictly inside (0, 1).

    ``special.expit`` may round to exactly 0 or 1 for large ``|eta|``; such
    endpoint values are nudged inward by ``force_within_unit_interval``.
    """
    return force_within_unit_interval(special.expit(eta))  # pylint: disable=no-member


def logit(mu):
    """Logit link, ``log(mu / (1 - mu))``.

    Exact 0/1 inputs are first moved into (0, 1) so the result stays finite.
    """
    safe_mu = force_within_unit_interval(mu)
    log_odds = special.logit(safe_mu)  # pylint: disable=no-member
    return log_odds


def softmax(eta, axis=-1):
    """Softmax of ``eta`` along ``axis``, with exact 0/1 outputs nudged into (0, 1).

    Parameters
    ----------
    eta : array-like
        Linear predictor values.
    axis : int, default -1
        Axis along which the softmax is normalized (forwarded to ``special.softmax``).
    """
    probs = special.softmax(eta, axis=axis)  # pylint: disable=no-member
    return force_within_unit_interval(probs)


def inverse_squared(mu):
    """Inverse-squared link, ``1 / mu**2``."""
    denominator = mu**2
    return 1 / denominator


def inv_inverse_squared(eta):
    """Inverse of the inverse-squared link, ``1 / sqrt(eta)``."""
    root = np.sqrt(eta)
    return 1 / root


def arctan_2(eta):
    """Inverse of ``tan_2``: ``2 * arctan(eta)``, mapping finite reals into (-pi, pi)."""
    half_angle = np.arctan(eta)
    return 2 * half_angle


def tan_2(mu):
    """Half-angle tangent link, ``tan(mu / 2)``."""
    half = mu / 2
    return np.tan(half)


def inverse(mu):
    """Inverse (reciprocal) link, ``1 / mu``."""
    reciprocal = 1 / mu
    return reciprocal


def inv_inverse(eta):
    """Inverse of the reciprocal link, which is again ``1 / eta``."""
    reciprocal = 1 / eta
    return reciprocal


def link_not_implemented(*args, **kwargs):
    """Placeholder for a link direction that has no implementation; always raises.

    Any positional and keyword arguments are accepted (and ignored) so the
    placeholder can stand in for any link function's signature.

    Raises
    ------
    NotImplementedError
        Unconditionally.
    """
    del args, kwargs  # accepted only for signature compatibility
    message = "link not implemented"
    raise NotImplementedError(message)


# Pair of functions describing one link:
#   link    -- known as g in the GLM literature; maps the response scale to the
#              linear predictor scale.
#   linkinv -- known as g^(-1) in the GLM literature; maps the linear predictor
#              back to the response scale.
LinksContainer = namedtuple("LinksContainer", "link linkinv")

# Registry of the available links, keyed by name.
LINKS = {
    "cloglog": LinksContainer(link=cloglog, linkinv=invcloglog),
    "identity": LinksContainer(link=identity, linkinv=identity),
    "inverse_squared": LinksContainer(link=inverse_squared, linkinv=inv_inverse_squared),
    "inverse": LinksContainer(link=inverse, linkinv=inv_inverse),
    "log": LinksContainer(link=np.log, linkinv=np.exp),
    "logit": LinksContainer(link=logit, linkinv=expit),
    "probit": LinksContainer(link=probit, linkinv=invprobit),
    # Only the inverse direction exists for softmax; calling ``link`` raises
    # NotImplementedError.
    "softmax": LinksContainer(link=link_not_implemented, linkinv=softmax),
    "tan_2": LinksContainer(link=tan_2, linkinv=arctan_2),
}