#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import warnings
from copy import deepcopy
from typing import Any, Dict, Optional, Tuple
import gpytorch
import numpy as np
import torch
from aepsych.config import Config
from aepsych.factory.default import default_mean_covar_factory
from aepsych.models.base import AEPsychModelDeviceMixin
from aepsych.models.inducing_point_allocators import (
AutoAllocator,
DummyAllocator,
KMeansAllocator,
SobolAllocator,
)
from aepsych.models.utils import select_inducing_points
from aepsych.utils import _process_bounds, get_optimizer_options, promote_0d
from aepsych.utils_logging import getLogger
from botorch.models.utils.inducing_point_allocators import InducingPointAllocator
from gpytorch.likelihoods import BernoulliLikelihood, BetaLikelihood, Likelihood
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy
from scipy.special import owens_t
from scipy.stats import norm
from torch.distributions import Normal
logger = getLogger()
class GPClassificationModel(AEPsychModelDeviceMixin, ApproximateGP):
"""Probit-GP model with variational inference.
From a conventional ML perspective this is a GP Classification model,
though in the psychophysics context it can also be thought of as a
nonlinear generalization of the standard linear model for 1AFC or
yes/no trials.
For more on variational inference, see e.g.
https://docs.gpytorch.ai/en/v1.1.1/examples/04_Variational_and_Approximate_GPs/
"""
_batch_size = 1
_num_outputs = 1
stimuli_per_trial = 1
outcome_type = "binary"
def __init__(
self,
lb: torch.Tensor,
ub: torch.Tensor,
inducing_point_method: InducingPointAllocator,
dim: Optional[int] = None,
mean_module: Optional[gpytorch.means.Mean] = None,
covar_module: Optional[gpytorch.kernels.Kernel] = None,
likelihood: Optional[Likelihood] = None,
inducing_size: Optional[int] = None,
max_fit_time: Optional[float] = None,
optimizer_options: Optional[Dict[str, Any]] = None,
inducing_points: Optional[torch.Tensor] = None,
) -> None:
"""Initialize the GP Classification model
Args:
lb (torch.Tensor): Lower bounds of the parameters.
ub (torch.Tensor): Upper bounds of the parameters.
inducing_point_method (InducingPointAllocator): The method to use for selecting inducing points.
dim (int, optional): The number of dimensions in the parameter space. If None, it is inferred from the size
of lb and ub.
mean_module (gpytorch.means.Mean, optional): GP mean class. Defaults to a constant with a normal prior.
covar_module (gpytorch.kernels.Kernel, optional): GP covariance kernel class. Defaults to scaled RBF with a
gamma prior.
likelihood (gpytorch.likelihood.Likelihood, optional): The likelihood function to use. If None defaults to
Bernoulli likelihood.
inducing_size (int, optional): Number of inducing points. Defaults to 99.
max_fit_time (float, optional): The maximum amount of time, in seconds, to spend fitting the model. If None,
there is no limit to the fitting time.
optimizer_options (Dict[str, Any], optional): Optimizer options to pass to the SciPy optimizer during
fitting. Assumes we are using L-BFGS-B.
inducing_points (torch.Tensor, optional): Inducing points to seed the variational strategy. Note that, as
implemented, inducing points are re-selected via inducing_point_method regardless of this argument.
"""
lb, ub, self.dim = _process_bounds(lb, ub, dim)
self.max_fit_time = max_fit_time
self.inducing_size = inducing_size or 99
self.optimizer_options = (
{"options": optimizer_options} if optimizer_options else {"options": {}}
)
if self.inducing_size >= 100:
logger.warning(
(
"inducing_size in GPClassificationModel is >=100, more inducing points "
"can lead to better fits but slower performance in general. Performance "
"at >=100 inducing points is especially slow."
)
)
if likelihood is None:
likelihood = BernoulliLikelihood()
if mean_module is None or covar_module is None:
default_mean, default_covar = default_mean_covar_factory(
dim=self.dim, stimuli_per_trial=self.stimuli_per_trial
)
self.inducing_point_method = inducing_point_method
inducing_points = select_inducing_points(
allocator=self.inducing_point_method,
inducing_size=self.inducing_size,
covar_module=covar_module or default_covar,
)
self.last_inducing_points_method = self.inducing_point_method.allocator_used
self.inducing_points = inducing_points
variational_distribution = CholeskyVariationalDistribution(
inducing_points.size(0), batch_shape=torch.Size([self._batch_size])
).to(inducing_points)
variational_strategy = VariationalStrategy(
self,
inducing_points,
variational_distribution,
learn_inducing_locations=False,
)
super().__init__(variational_strategy)
# Tensors need to be registered as buffers directly; Modules can simply be assigned as attributes
self.register_buffer("lb", lb)
self.register_buffer("ub", ub)
self.likelihood = likelihood
self.mean_module = mean_module or default_mean
self.covar_module = covar_module or default_covar
self._fresh_state_dict = deepcopy(self.state_dict())
self._fresh_likelihood_dict = deepcopy(self.likelihood.state_dict())
@classmethod
def from_config(cls, config: Config) -> GPClassificationModel:
"""Alternate constructor for GPClassification model from a configuration.
This is used when we recursively build a full sampling strategy
from a configuration. TODO: document how this works in some tutorial.
Args:
config (Config): A configuration containing keys/values matching this class
Returns:
GPClassificationModel: Configured class instance.
"""
classname = cls.__name__
inducing_size = config.getint(classname, "inducing_size", fallback=None)
lb = config.gettensor(classname, "lb")
ub = config.gettensor(classname, "ub")
dim = config.getint(classname, "dim", fallback=None)
mean_covar_factory = config.getobj(
classname, "mean_covar_factory", fallback=default_mean_covar_factory
)
mean, covar = mean_covar_factory(config)
max_fit_time = config.getfloat(classname, "max_fit_time", fallback=None)
inducing_point_method_class = config.getobj(
classname, "inducing_point_method", fallback=AutoAllocator
)
# Check if allocator class has a `from_config` method
if hasattr(inducing_point_method_class, "from_config"):
inducing_point_method = inducing_point_method_class.from_config(config)
else:
inducing_point_method = inducing_point_method_class()
likelihood_cls = config.getobj(classname, "likelihood", fallback=None)
if likelihood_cls is not None:
if hasattr(likelihood_cls, "from_config"):
likelihood = likelihood_cls.from_config(config)
else:
likelihood = likelihood_cls()
else:
likelihood = None # fall back to __init__ default
optimizer_options = get_optimizer_options(config, classname)
return cls(
lb=lb,
ub=ub,
dim=dim,
inducing_size=inducing_size,
mean_module=mean,
covar_module=covar,
max_fit_time=max_fit_time,
inducing_point_method=inducing_point_method,
likelihood=likelihood,
optimizer_options=optimizer_options,
)
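# Example of building the model from a config (a minimal sketch; the section and
# option names mirror the config.get* lookups above, while the concrete values and
# the config_str keyword are illustrative assumptions, not a prescribed setup):
#
#   from aepsych.config import Config
#
#   config = Config(config_str="""
#   [GPClassificationModel]
#   lb = [0, 0]
#   ub = [1, 1]
#   inducing_size = 50
#   inducing_point_method = AutoAllocator
#   """)
#   model = GPClassificationModel.from_config(config)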
def _reset_hyperparameters(self) -> None:
"""Reset hyperparameters to their initial values."""
# warmstart_hyperparams affects hyperparams but not the variational strat,
# so we keep the old variational strat (which is only refreshed
# if warmstart_induc=False).
vsd = self.variational_strategy.state_dict() # type: ignore
vsd_hack = {f"variational_strategy.{k}": v for k, v in vsd.items()}
state_dict = deepcopy(self._fresh_state_dict)
state_dict.update(vsd_hack)
self.load_state_dict(state_dict)
self.likelihood.load_state_dict(self._fresh_likelihood_dict)
def _reset_variational_strategy(self) -> None:
if self.train_inputs is not None:
# remember original device
device = self.device
inducing_points = select_inducing_points(
allocator=self.inducing_point_method,
inducing_size=self.inducing_size,
covar_module=self.covar_module,
X=self.train_inputs[0],
bounds=self.bounds,
).to(device)
self.last_inducing_points_method = self.inducing_point_method.allocator_used
variational_distribution = CholeskyVariationalDistribution(
inducing_points.size(0), batch_shape=torch.Size([self._batch_size])
).to(device)
self.variational_strategy = VariationalStrategy(
self,
inducing_points,
variational_distribution,
learn_inducing_locations=False,
).to(device)
def fit(
self,
train_x: torch.Tensor,
train_y: torch.Tensor,
warmstart_hyperparams: bool = False,
warmstart_induc: bool = False,
**kwargs,
) -> None:
"""Fit underlying model.
Args:
train_x (torch.Tensor): Inputs.
train_y (torch.Tensor): Binary responses.
warmstart_hyperparams (bool): Whether to reuse the previous hyperparameters (True) or fit from scratch
(False). Defaults to False.
warmstart_induc (bool): Whether to reuse the previous inducing points or fit from scratch (False).
Defaults to False.
"""
self.set_train_data(train_x, train_y)
# by default we reuse the model state and likelihood. If we
# want a fresh fit (no warm start), copy the state from class initialization.
if not warmstart_hyperparams:
self._reset_hyperparameters()
if not warmstart_induc or (
self.last_inducing_points_method == "DummyAllocator"
and self.inducing_point_method.__class__.__name__ != "DummyAllocator"
):
self._reset_variational_strategy()
n = train_y.shape[0]
mll = gpytorch.mlls.VariationalELBO(self.likelihood, self, n)
if "optimizer_kwargs" in kwargs:
self._fit_mll(mll, **kwargs)
else:
self._fit_mll(mll, optimizer_kwargs=self.optimizer_options, **kwargs)
def sample(self, x: torch.Tensor, num_samples: int, **kwargs) -> torch.Tensor:
"""Sample from underlying model.
Args:
x (torch.Tensor): Points at which to sample.
num_samples (int): Number of samples to return.
Additional keyword arguments are ignored.
Returns:
torch.Tensor: Posterior samples of shape [num_samples x n] for n query points.
"""
x = x.to(self.device)
return self.posterior(x).rsample(torch.Size([num_samples])).detach().squeeze()
def predict(
self, x: torch.Tensor, probability_space: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Query the model for posterior mean and variance.
Args:
x (torch.Tensor): Points at which to predict from the model.
probability_space (bool): Return outputs in units of
response probability instead of latent function value. Defaults to False.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at the queried points.
"""
with torch.no_grad():
x = x.to(self.device)
post = self.posterior(x)
fmean = post.mean.squeeze()
fvar = post.variance.squeeze()
if probability_space:
if isinstance(self.likelihood, BernoulliLikelihood):
# Probability-space mean and variance for Bernoulli-probit models are
# available in closed form; see Proposition 1 in Letham et al. 2022 (AISTATS).
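# With latent f ~ N(fmean, fvar) and probit link Phi, the proposition gives
#   E[Phi(f)]   = Phi(a*),  a* = fmean / sqrt(1 + fvar)
#   Var[Phi(f)] = Phi(a*) - 2 * T(a*, 1 / sqrt(1 + 2 * fvar)) - Phi(a*)^2
# where T is Owen's T function (computed via scipy.special.owens_t below).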
a_star = fmean / torch.sqrt(1 + fvar)
pmean = Normal(0, 1).cdf(a_star)
t_term = torch.tensor(
owens_t(
a_star.cpu().numpy(), 1 / np.sqrt(1 + 2 * fvar.cpu().numpy())
),
dtype=a_star.dtype,
).to(self.device)
pvar = pmean - 2 * t_term - pmean.square()
return promote_0d(pmean), promote_0d(pvar)
else:
fsamps = post.sample(torch.Size([10000]))
if hasattr(self.likelihood, "objective"):
psamps = self.likelihood.objective(fsamps)
else:
psamps = norm.cdf(fsamps)
pmean, pvar = psamps.mean(0), psamps.var(0)
return promote_0d(pmean), promote_0d(pvar)
else:
return promote_0d(fmean), promote_0d(fvar)
def predict_probability(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Query the model for posterior mean and variance in probability space.
Args:
x (torch.Tensor): Points at which to predict from the model.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at the queried points.
"""
return self.predict(x, probability_space=True)
def update(self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs):
"""Perform a warm-start update of the model from previous fit.
Args:
train_x (torch.Tensor): Inputs.
train_y (torch.Tensor): Responses.
"""
return self.fit(
train_x, train_y, warmstart_hyperparams=True, warmstart_induc=True, **kwargs
)
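# Typical direct usage (a minimal sketch; the allocator choice, bounds, and the
# train_x / train_y / test_x tensors are illustrative placeholders):
#
#   lb, ub = torch.tensor([0.0]), torch.tensor([1.0])
#   model = GPClassificationModel(lb=lb, ub=ub, inducing_point_method=AutoAllocator())
#   model.fit(train_x, train_y)                      # train_y holds binary (0/1) responses
#   pmean, pvar = model.predict_probability(test_x)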
class GPBetaRegressionModel(GPClassificationModel):
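"""Beta-regression GP for outcomes on the unit interval (e.g., percentage or
proportion responses). Identical to GPClassificationModel except that the
likelihood defaults to a Beta likelihood."""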
outcome_type = "percentage"
def __init__(
self,
lb: torch.Tensor,
ub: torch.Tensor,
inducing_point_method: InducingPointAllocator,
dim: Optional[int] = None,
mean_module: Optional[gpytorch.means.Mean] = None,
covar_module: Optional[gpytorch.kernels.Kernel] = None,
likelihood: Optional[Likelihood] = None,
inducing_size: Optional[int] = None,
max_fit_time: Optional[float] = None,
optimizer_options: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize the GP Beta Regression model
Args:
lb (torch.Tensor): Lower bounds of the parameters.
ub (torch.Tensor): Upper bounds of the parameters.
inducing_point_method (InducingPointAllocator): The method to use to select the inducing points.
dim (int, optional): The number of dimensions in the parameter space. If None, it is inferred from the size
of lb and ub. Defaults to None.
mean_module (gpytorch.means.Mean, optional): GP mean class. Defaults to a constant with a normal prior.
covar_module (gpytorch.kernels.Kernel, optional): GP covariance kernel class. Defaults to scaled RBF with a
gamma prior.
likelihood (gpytorch.likelihood.Likelihood, optional): The likelihood function to use. If None defaults to
Beta likelihood.
inducing_size (int, optional): Number of inducing points. Defaults to 99.
max_fit_time (float, optional): The maximum amount of time, in seconds, to spend fitting the model. If None,
there is no limit to the fitting time. Defaults to None.
optimizer_options (Dict[str, Any], optional): Optimizer options to pass to the SciPy optimizer during
fitting. Assumes we are using L-BFGS-B.
"""
if likelihood is None:
likelihood = BetaLikelihood()
self.inducing_point_method = inducing_point_method
super().__init__(
lb=lb,
ub=ub,
dim=dim,
mean_module=mean_module,
covar_module=covar_module,
likelihood=likelihood,
inducing_size=inducing_size,
max_fit_time=max_fit_time,
inducing_point_method=inducing_point_method,
optimizer_options=optimizer_options,
)
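# GPBetaRegressionModel is used the same way as GPClassificationModel, but with
# responses in the open unit interval (a minimal sketch; tensors are placeholders):
#
#   model = GPBetaRegressionModel(lb=lb, ub=ub, inducing_point_method=AutoAllocator())
#   model.fit(train_x, train_y)       # train_y in (0, 1), e.g. observed proportions
#   fmean, fvar = model.predict(test_x)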