#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from copy import deepcopy
from typing import Any, Dict, Optional, Tuple

import gpytorch
import numpy as np
import torch
from aepsych.config import Config
from aepsych.factory.default import default_mean_covar_factory
from aepsych.models.base import AEPsychModelDeviceMixin
from aepsych.models.inducing_points import GreedyVarianceReduction
from aepsych.models.inducing_points.base import InducingPointAllocator
from aepsych.utils import get_dims, get_optimizer_options, promote_0d
from aepsych.utils_logging import getLogger
from gpytorch.likelihoods import BernoulliLikelihood, BetaLikelihood, Likelihood
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy
from scipy.special import owens_t
from scipy.stats import norm
from torch.distributions import Normal

logger = getLogger()


class GPClassificationModel(AEPsychModelDeviceMixin, ApproximateGP):
    """Probit-GP model with variational inference.

    From a conventional ML perspective this is a GP Classification model,
    though in the psychophysics context it can also be thought of as a
    nonlinear generalization of the standard linear model for 1AFC or
    yes/no trials.

    For more on variational inference, see e.g.
    https://docs.gpytorch.ai/en/v1.1.1/examples/04_Variational_and_Approximate_GPs/
    """

    _batch_size = 1
    _num_outputs = 1
    stimuli_per_trial = 1
    outcome_type = "binary"

    def __init__(
        self,
        dim: int,
        mean_module: Optional[gpytorch.means.Mean] = None,
        covar_module: Optional[gpytorch.kernels.Kernel] = None,
        likelihood: Optional[Likelihood] = None,
        inducing_point_method: Optional[InducingPointAllocator] = None,
        inducing_size: int = 100,
        max_fit_time: Optional[float] = None,
        optimizer_options: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initialize the GP Classification model.

        Args:
            dim (int): The number of dimensions in the parameter space.
            mean_module (gpytorch.means.Mean, optional): GP mean class. Defaults to a
                constant with a normal prior.
            covar_module (gpytorch.kernels.Kernel, optional): GP covariance kernel class.
                Defaults to scaled RBF with a gamma prior.
            likelihood (gpytorch.likelihoods.Likelihood, optional): The likelihood function
                to use. If None, defaults to a Bernoulli likelihood.
            inducing_point_method (InducingPointAllocator, optional): The method to use for
                selecting inducing points. If not set, a GreedyVarianceReduction is made.
            inducing_size (int): Number of inducing points. Defaults to 100.
            max_fit_time (float, optional): The maximum amount of time, in seconds, to spend
                fitting the model. If None, there is no limit to the fitting time.
            optimizer_options (Dict[str, Any], optional): Optimizer options to pass to the
                SciPy optimizer during fitting. Assumes we are using L-BFGS-B.
        """
        self.dim = dim
        self.max_fit_time = max_fit_time
        self.inducing_size = inducing_size

        self.optimizer_options = (
            {"options": optimizer_options} if optimizer_options else {"options": {}}
        )

        if self.inducing_size >= 100:
            logger.warning(
                (
                    "inducing_size in GPClassificationModel is >=100, more inducing points "
                    "can lead to better fits but slower performance in general. Performance "
                    "at >=100 inducing points is especially slow."
                )
            )

        if likelihood is None:
            likelihood = BernoulliLikelihood()

        if mean_module is None or covar_module is None:
            default_mean, default_covar = default_mean_covar_factory(
                dim=self.dim, stimuli_per_trial=self.stimuli_per_trial
            )

        self.inducing_point_method = inducing_point_method or GreedyVarianceReduction(
            dim=self.dim
        )
        inducing_points = self.inducing_point_method.allocate_inducing_points(
            num_inducing=self.inducing_size,
            covar_module=covar_module or default_covar,
        )

        variational_distribution = CholeskyVariationalDistribution(
            inducing_points.size(0), batch_shape=torch.Size([self._batch_size])
        ).to(inducing_points)
        variational_strategy = VariationalStrategy(
            self,
            inducing_points,
            variational_distribution,
            learn_inducing_locations=False,
        )
        super().__init__(variational_strategy)

        self.likelihood = likelihood
        self.mean_module = mean_module or default_mean
        self.covar_module = covar_module or default_covar

        self._fresh_state_dict = deepcopy(self.state_dict())
        self._fresh_likelihood_dict = deepcopy(self.likelihood.state_dict())

    @classmethod
    def from_config(cls, config: Config) -> GPClassificationModel:
        """Alternate constructor for GPClassification model from a configuration.

        This is used when we recursively build a full sampling strategy
        from a configuration. TODO: document how this works in some tutorial.

        Args:
            config (Config): A configuration containing keys/values matching this class.

        Returns:
            GPClassificationModel: Configured class instance.
        """
        classname = cls.__name__

        inducing_size = config.getint(classname, "inducing_size", fallback=100)

        dim = config.getint(classname, "dim", fallback=None)
        if dim is None:
            dim = get_dims(config)

        mean_covar_factory = config.getobj(
            classname, "mean_covar_factory", fallback=default_mean_covar_factory
        )
        mean, covar = mean_covar_factory(config)

        max_fit_time = config.getfloat(classname, "max_fit_time", fallback=None)

        inducing_point_method_class = config.getobj(
            classname, "inducing_point_method", fallback=GreedyVarianceReduction
        )
        # Check if allocator class has a `from_config` method
        if hasattr(inducing_point_method_class, "from_config"):
            inducing_point_method = inducing_point_method_class.from_config(config)
        else:
            inducing_point_method = inducing_point_method_class()

        likelihood_cls = config.getobj(classname, "likelihood", fallback=None)

        if likelihood_cls is not None:
            if hasattr(likelihood_cls, "from_config"):
                likelihood = likelihood_cls.from_config(config)
            else:
                likelihood = likelihood_cls()
        else:
            likelihood = None  # fall back to __init__ default

        optimizer_options = get_optimizer_options(config, classname)

        return cls(
            dim=dim,
            inducing_size=inducing_size,
            mean_module=mean,
            covar_module=covar,
            max_fit_time=max_fit_time,
            inducing_point_method=inducing_point_method,
            likelihood=likelihood,
            optimizer_options=optimizer_options,
        )
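
    # A minimal config sketch for `from_config` (illustrative only; the section name and
    # keys mirror the lookups above, and the rest of the experiment config is assumed to
    # live in other sections, e.g. [common]):
    #
    #   [GPClassificationModel]
    #   inducing_size = 50
    #   max_fit_time = 10.0
    #   inducing_point_method = GreedyVarianceReduction
    #   mean_covar_factory = default_mean_covar_factory
    #
    # The model is then built with `GPClassificationModel.from_config(config)`, where
    # `config` is an aepsych `Config` parsed from such a file or string.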

    def _reset_hyperparameters(self) -> None:
        """Reset hyperparameters to their initial values."""
        # warmstart_hyperparams affects hyperparams but not the variational strat,
        # so we keep the old variational strat (which is only refreshed
        # if warmstart_induc=False).
        vsd = self.variational_strategy.state_dict()  # type: ignore
        vsd_hack = {f"variational_strategy.{k}": v for k, v in vsd.items()}
        state_dict = deepcopy(self._fresh_state_dict)
        state_dict.update(vsd_hack)
        self.load_state_dict(state_dict)
        self.likelihood.load_state_dict(self._fresh_likelihood_dict)

    def _reset_variational_strategy(self) -> None:
        """Rebuild the variational strategy and inducing points from the current training inputs."""
        if self.train_inputs is not None:
            # remember original device
            device = self.device
            inducing_points = self.inducing_point_method.allocate_inducing_points(
                num_inducing=self.inducing_size,
                covar_module=self.covar_module,
                inputs=self.train_inputs[0],
            ).to(device)
            variational_distribution = CholeskyVariationalDistribution(
                inducing_points.size(0), batch_shape=torch.Size([self._batch_size])
            ).to(device)
            self.variational_strategy = VariationalStrategy(
                self,
                inducing_points,
                variational_distribution,
                learn_inducing_locations=False,
            ).to(device)

    def fit(
        self,
        train_x: torch.Tensor,
        train_y: torch.Tensor,
        warmstart_hyperparams: bool = False,
        warmstart_induc: bool = False,
        **kwargs,
    ) -> None:
        """Fit underlying model.

        Args:
            train_x (torch.Tensor): Inputs.
            train_y (torch.LongTensor): Responses.
            warmstart_hyperparams (bool): Whether to reuse the previous hyperparameters (True)
                or fit from scratch (False). Defaults to False.
            warmstart_induc (bool): Whether to reuse the previous inducing points (True)
                or fit from scratch (False). Defaults to False.
        """
        self.set_train_data(train_x, train_y)

        # By default we reuse the model state and likelihood. If we
        # want a fresh fit (no warm start), copy the state from class initialization.
        if not warmstart_hyperparams:
            self._reset_hyperparameters()

        if not warmstart_induc or (
            self.inducing_point_method.last_allocator_used is None
        ):
            self._reset_variational_strategy()

        n = train_y.shape[0]
        mll = gpytorch.mlls.VariationalELBO(self.likelihood, self, n)

        if "optimizer_kwargs" in kwargs:
            self._fit_mll(mll, **kwargs)
        else:
            self._fit_mll(mll, optimizer_kwargs=self.optimizer_options, **kwargs)
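
    # Warm-start sketch (illustrative only; `x_new`, `y_new` are hypothetical tensors
    # holding the accumulated data after new trials):
    #
    #   model.fit(train_x, train_y)                          # fresh fit from the initial state
    #   model.fit(x_new, y_new, warmstart_hyperparams=True)  # keep previously learned hyperparameters
    #   model.update(x_new, y_new)                           # warm-start hyperparameters and inducing points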

    def sample(self, x: torch.Tensor, num_samples: int) -> torch.Tensor:
        """Sample from underlying model.

        Args:
            x (torch.Tensor): Points at which to sample.
            num_samples (int): Number of samples to return.

        Returns:
            torch.Tensor: Posterior samples [num_samples x dim]
        """
        x = x.to(self.device)
        return self.posterior(x).rsample(torch.Size([num_samples])).detach().squeeze()

    def predict(
        self, x: torch.Tensor, probability_space: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Query the model for posterior mean and variance.

        Args:
            x (torch.Tensor): Points at which to predict from the model.
            probability_space (bool): Return outputs in units of response probability
                instead of latent function value. Defaults to False.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at the
                queried points.
        """
        with torch.no_grad():
            x = x.to(self.device)
            post = self.posterior(x)
            fmean = post.mean.squeeze()
            fvar = post.variance.squeeze()

        if probability_space:
            if isinstance(self.likelihood, BernoulliLikelihood):
                # Probability-space mean and variance for Bernoulli-probit models is
                # available in closed form, Proposition 1 in Letham et al. 2022 (AISTATS).
                a_star = fmean / torch.sqrt(1 + fvar)
                pmean = Normal(0, 1).cdf(a_star)
                t_term = torch.tensor(
                    owens_t(
                        a_star.cpu().numpy(), 1 / np.sqrt(1 + 2 * fvar.cpu().numpy())
                    ),
                    dtype=a_star.dtype,
                ).to(self.device)
                pvar = pmean - 2 * t_term - pmean.square()
                return promote_0d(pmean), promote_0d(pvar)
            else:
                # No closed form for other likelihoods; estimate by Monte Carlo.
                fsamps = post.sample(torch.Size([10000]))
                if hasattr(self.likelihood, "objective"):
                    psamps = self.likelihood.objective(fsamps)
                else:
                    psamps = norm.cdf(fsamps)
                pmean, pvar = psamps.mean(0), psamps.var(0)
                return promote_0d(pmean), promote_0d(pvar)
        else:
            return promote_0d(fmean), promote_0d(fvar)
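
    # The closed form used above (Proposition 1, Letham et al. 2022): with latent posterior
    # mean mu and variance sigma^2 at a query point, Phi the standard normal CDF, and T
    # Owen's T function,
    #
    #   a*     = mu / sqrt(1 + sigma^2)
    #   E[p]   = Phi(a*)
    #   Var[p] = Phi(a*) - 2 * T(a*, 1 / sqrt(1 + 2 * sigma^2)) - Phi(a*)^2
    #
    # which is what `pmean` and `pvar` compute via `scipy.special.owens_t`.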

    def predict_probability(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Query the model for posterior mean and variance in probability space.

        Args:
            x (torch.Tensor): Points at which to predict from the model.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at the
                queried points.
        """
        return self.predict(x, probability_space=True)

    def update(self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs):
        """Perform a warm-start update of the model from a previous fit.

        Args:
            train_x (torch.Tensor): Inputs.
            train_y (torch.Tensor): Responses.
        """
        return self.fit(
            train_x,
            train_y,
            warmstart_hyperparams=True,
            warmstart_induc=True,
            **kwargs,
        )


class GPBetaRegressionModel(GPClassificationModel):
    """GPClassificationModel with a Beta response likelihood, for continuous
    outcomes on the unit interval (e.g., percentage correct)."""

    outcome_type = "percentage"

    def __init__(
        self,
        dim: int,
        mean_module: Optional[gpytorch.means.Mean] = None,
        covar_module: Optional[gpytorch.kernels.Kernel] = None,
        likelihood: Optional[Likelihood] = None,
        inducing_point_method: Optional[InducingPointAllocator] = None,
        inducing_size: int = 100,
        max_fit_time: Optional[float] = None,
        optimizer_options: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initialize the GP Beta Regression model.

        Args:
            dim (int): The number of dimensions in the parameter space.
            mean_module (gpytorch.means.Mean, optional): GP mean class. Defaults to a
                constant with a normal prior.
            covar_module (gpytorch.kernels.Kernel, optional): GP covariance kernel class.
                Defaults to scaled RBF with a gamma prior.
            likelihood (gpytorch.likelihoods.Likelihood, optional): The likelihood function
                to use. If None, defaults to a Beta likelihood.
            inducing_point_method (InducingPointAllocator, optional): The method to use for
                selecting inducing points. If not set, a GreedyVarianceReduction is made.
            inducing_size (int): Number of inducing points. Defaults to 100.
            max_fit_time (float, optional): The maximum amount of time, in seconds, to spend
                fitting the model. If None, there is no limit to the fitting time.
                Defaults to None.
            optimizer_options (Dict[str, Any], optional): Optimizer options to pass to the
                SciPy optimizer during fitting. Assumes we are using L-BFGS-B.
        """
        if likelihood is None:
            likelihood = BetaLikelihood()

        super().__init__(
            dim=dim,
            mean_module=mean_module,
            covar_module=covar_module,
            likelihood=likelihood,
            inducing_size=inducing_size,
            max_fit_time=max_fit_time,
            inducing_point_method=inducing_point_method,
            optimizer_options=optimizer_options,
        )
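

# A minimal end-to-end usage sketch (not part of the library; synthetic random data and a
# small inducing set, assuming a working aepsych install):
if __name__ == "__main__":
    train_x = torch.rand(30, 2)  # 30 trials in a 2D parameter space, scaled to [0, 1]
    train_y = torch.bernoulli(torch.full((30,), 0.5))  # random binary (yes/no) responses

    model = GPClassificationModel(dim=2, inducing_size=20)
    model.fit(train_x, train_y)

    test_x = torch.rand(5, 2)
    pmean, pvar = model.predict_probability(test_x)  # response-probability mean/variance
    fsamps = model.sample(test_x, num_samples=100)  # samples from the latent posterior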