Source code for aepsych.models.monotonic_rejection_gp

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union

import gpytorch
import numpy as np
import torch
from aepsych.acquisition.rejection_sampler import RejectionSampler
from aepsych.config import Config
from aepsych.factory.factory import monotonic_mean_covar_factory
from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad
from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad
from aepsych.models.base import AEPsychMixin
from aepsych.models.utils import select_inducing_points
from aepsych.utils import _process_bounds, promote_0d
from botorch.fit import fit_gpytorch_mll
from gpytorch.kernels import Kernel
from gpytorch.likelihoods import BernoulliLikelihood, Likelihood
from gpytorch.means import Mean
from gpytorch.mlls.variational_elbo import VariationalELBO
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy
from scipy.stats import norm
from torch import Tensor


class MonotonicRejectionGP(AEPsychMixin, ApproximateGP):
    """A monotonic GP using rejection sampling.

    This takes the same insight as in e.g. Riihimäki & Vehtari 2010 (that the
    derivative of a GP is likewise a GP) but instead of approximately
    optimizing the likelihood of the model using EP, we optimize an
    unconstrained model by VI and then draw monotonic samples by rejection
    sampling.

    References:
        Riihimäki, J., & Vehtari, A. (2010). Gaussian processes with
            monotonicity information. Journal of Machine Learning Research,
            9, 645–652.
    """

    _num_outputs = 1
    stimuli_per_trial = 1
    outcome_type = "binary"

    def __init__(
        self,
        monotonic_idxs: Sequence[int],
        lb: Union[np.ndarray, torch.Tensor],
        ub: Union[np.ndarray, torch.Tensor],
        dim: Optional[int] = None,
        mean_module: Optional[Mean] = None,
        covar_module: Optional[Kernel] = None,
        likelihood: Optional[Likelihood] = None,
        fixed_prior_mean: Optional[float] = None,
        num_induc: int = 25,
        num_samples: int = 250,
        num_rejection_samples: int = 5000,
        inducing_point_method: str = "auto",
    ) -> None:
        """Initialize MonotonicRejectionGP.

        Args:
            monotonic_idxs (Sequence[int]): Which columns of x should be given
                monotonicity constraints.
            lb (Union[np.ndarray, torch.Tensor]): Lower bounds of the
                parameters.
            ub (Union[np.ndarray, torch.Tensor]): Upper bounds of the
                parameters.
            dim (Optional[int], optional): Number of input dimensions. Passed
                to `_process_bounds` together with the bounds. Defaults to
                None.
            mean_module (Optional[Mean], optional): Mean module to use
                (default: `ConstantMeanPartialObsGrad`).
            covar_module (Optional[Kernel], optional): Covariance kernel to
                use (default: scaled `RBFKernelPartialObsGrad`).
            likelihood (Optional[Likelihood], optional): Likelihood to use
                (default: `BernoulliLikelihood`).
            fixed_prior_mean (Optional[float], optional): Fixed prior mean. If
                the likelihood is Bernoulli, this should be the prior
                classification probability (not the latent function value);
                it is mapped to latent space with the probit inverse.
                Defaults to None.
            num_induc (int, optional): Number of inducing points for the
                variational GP. Defaults to 25.
            num_samples (int, optional): Number of samples for estimating the
                posterior on predict or acquisition function evaluation.
                Defaults to 250.
            num_rejection_samples (int, optional): Number of samples used for
                rejection sampling. Defaults to 5000.
            inducing_point_method (str, optional): Method handed to
                `select_inducing_points` when the model is fit. The initial
                inducing points built here always use "sobol". Defaults to
                "auto".
        """
        self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim)
        if likelihood is None:
            likelihood = BernoulliLikelihood()

        self.inducing_size = num_induc
        self.inducing_point_method = inducing_point_method

        # No data is available yet, so the initial inducing points are drawn
        # with "sobol" over the bounds.
        inducing_points = select_inducing_points(
            inducing_size=self.inducing_size,
            bounds=self.bounds,
            method="sobol",
        )

        # Derivative index 0 marks plain function values (not derivatives).
        inducing_points_aug = self._augment_with_deriv_index(inducing_points, 0)
        variational_distribution = CholeskyVariationalDistribution(
            inducing_points_aug.size(0)
        )
        variational_strategy = VariationalStrategy(
            model=self,
            inducing_points=inducing_points_aug,
            variational_distribution=variational_distribution,
            learn_inducing_locations=False,
        )

        if mean_module is None:
            mean_module = ConstantMeanPartialObsGrad()

        if fixed_prior_mean is not None:
            if isinstance(likelihood, BernoulliLikelihood):
                # Convert the prior probability into a latent (probit) value.
                fixed_prior_mean = norm.ppf(fixed_prior_mean)
            mean_module.constant.requires_grad_(False)
            mean_module.constant.copy_(torch.tensor(fixed_prior_mean))

        if covar_module is None:
            ls_prior = gpytorch.priors.GammaPrior(
                concentration=4.6, rate=1.0, transform=lambda x: 1 / x
            )
            ls_prior_mode = ls_prior.rate / (ls_prior.concentration + 1)
            ls_constraint = gpytorch.constraints.GreaterThan(
                lower_bound=1e-4, transform=None, initial_value=ls_prior_mode
            )

            # NOTE(review): this uses the raw `dim` argument, which may be
            # None, rather than the `self.dim` returned by `_process_bounds`
            # -- confirm whether ARD is intended when `dim` is omitted.
            covar_module = gpytorch.kernels.ScaleKernel(
                RBFKernelPartialObsGrad(
                    lengthscale_prior=ls_prior,
                    lengthscale_constraint=ls_constraint,
                    ard_num_dims=dim,
                ),
                outputscale_prior=gpytorch.priors.SmoothedBoxPrior(a=1, b=4),
            )

        super().__init__(variational_strategy)

        self.bounds_ = torch.stack([self.lb, self.ub])
        self.mean_module = mean_module
        self.covar_module = covar_module
        self.likelihood = likelihood

        self.num_induc = num_induc
        self.monotonic_idxs = monotonic_idxs
        self.num_samples = num_samples
        self.num_rejection_samples = num_rejection_samples
        self.fixed_prior_mean = fixed_prior_mean
        self.inducing_points = inducing_points

    def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None:
        """Fit the model

        Args:
            train_x (Tensor): Training x points
            train_y (Tensor): Training y points. Should be (n x 1).
            **kwargs: Accepted for interface compatibility; not used here.
        """
        self.set_train_data(train_x, train_y)

        # Re-select inducing points now that data is available.
        # NOTE(review): only `self.inducing_points` (used for the derivative
        # constraint points) is refreshed here; nothing in this method pushes
        # the new points into the variational strategy -- confirm intended.
        self.inducing_points = select_inducing_points(
            inducing_size=self.inducing_size,
            covar_module=self.covar_module,
            X=self.train_inputs[0],
            bounds=self.bounds,
            method=self.inducing_point_method,
        )
        self._set_model(train_x, train_y)

    def _set_model(
        self,
        train_x: Tensor,
        train_y: Tensor,
        model_state_dict: Optional[Dict[str, Tensor]] = None,
        likelihood_state_dict: Optional[Dict[str, Tensor]] = None,
    ) -> None:
        """Store derivative-augmented training data and fit the variational ELBO.

        Args:
            train_x (Tensor): Training x points (without derivative index).
            train_y (Tensor): Training y points.
            model_state_dict (Optional[Dict[str, Tensor]]): If given, loaded
                into the model before fitting.
            likelihood_state_dict (Optional[Dict[str, Tensor]]): If given,
                loaded into the likelihood before fitting.
        """
        # All observations are function values, i.e. derivative index 0.
        train_x_aug = self._augment_with_deriv_index(train_x, 0)
        self.set_train_data(train_x_aug, train_y)

        # Set model parameters
        if model_state_dict is not None:
            self.load_state_dict(model_state_dict)
        if likelihood_state_dict is not None:
            self.likelihood.load_state_dict(likelihood_state_dict)

        # Fit!
        mll = VariationalELBO(
            likelihood=self.likelihood, model=self, num_data=train_y.numel()
        )
        fit_gpytorch_mll(mll)

    def update(self, train_x: Tensor, train_y: Tensor, warmstart: bool = True) -> None:
        """
        Update the model with new data.

        Expects the full set of data, not the incremental new data.

        Args:
            train_x (Tensor): Train X.
            train_y (Tensor): Train Y. Should be (n x 1).
            warmstart (bool): If True, warm-start model fitting with current
                parameters.
        """
        if warmstart:
            model_state_dict = self.state_dict()
            likelihood_state_dict = self.likelihood.state_dict()
        else:
            model_state_dict = None
            likelihood_state_dict = None

        self._set_model(
            train_x=train_x,
            train_y=train_y,
            model_state_dict=model_state_dict,
            likelihood_state_dict=likelihood_state_dict,
        )

    def sample(
        self,
        x: Tensor,
        num_samples: Optional[int] = None,
        num_rejection_samples: Optional[int] = None,
    ) -> torch.Tensor:
        """Sample from monotonic GP

        Args:
            x (Tensor): tensor of n points at which to sample
            num_samples (int, optional): how many points to sample
                (default: self.num_samples)
            num_rejection_samples (int, optional): how many samples to use
                for rejection sampling (default: self.num_rejection_samples)

        Returns:
            a Tensor of shape [n_samp, n]
        """
        if num_samples is None:
            num_samples = self.num_samples
        if num_rejection_samples is None:
            num_rejection_samples = self.num_rejection_samples

        rejection_ratio = 20
        if num_samples * rejection_ratio > num_rejection_samples:
            warnings.warn(
                f"num_rejection_samples should be at least {rejection_ratio} times greater than num_samples."
            )

        n = x.shape[0]
        # Augment with derivative index
        x_aug = self._augment_with_deriv_index(x, 0)
        # Add in monotonicity constraint points
        deriv_cp = self._get_deriv_constraint_points()
        x_aug = torch.cat((x_aug, deriv_cp), dim=0)

        # One block of inducing points is appended per monotonic dimension.
        n_constraint_points = len(self.monotonic_idxs) * self.inducing_points.shape[0]
        assert x_aug.shape[0] == x.shape[0] + n_constraint_points

        # Rows after the first n are the derivative constraint points.
        constrained_idx = torch.arange(n, x_aug.shape[0])

        with torch.no_grad():
            posterior = self.posterior(x_aug)
        sampler = RejectionSampler(
            num_samples=num_samples,
            num_rejection_samples=num_rejection_samples,
            constrained_idx=constrained_idx,
        )
        samples = sampler(posterior)

        # Keep only the function values at the n query points.
        samples_f = samples[:, :n, 0].detach().cpu()
        return samples_f

    def predict(
        self, x: Tensor, probability_space: bool = False
    ) -> Tuple[Tensor, Tensor]:
        """Predict

        Args:
            x: tensor of n points at which to predict.
            probability_space (bool): If True, return the prediction in
                probability space (through the standard normal CDF) instead
                of latent space. Defaults to False.

        Returns:
            tuple (f, var) where f is (n,) and var is (n,)
        """
        samples_f = self.sample(x)
        mean = torch.mean(samples_f, dim=0).squeeze()

        if probability_space:
            # A variance cannot be mapped through the CDF directly (a zero
            # latent variance would become 0.5). Transform the samples first
            # and take their variance in probability space instead.
            prob_samples = torch.Tensor(norm.cdf(samples_f))
            prob_variance = torch.var(prob_samples, dim=0).clamp_min(0)
            return (
                torch.Tensor(promote_0d(norm.cdf(mean))),
                prob_variance.reshape(-1),
            )

        variance = torch.var(samples_f, dim=0).clamp_min(0).squeeze()
        return mean, variance

    def predict_probability(
        self, x: Union[torch.Tensor, np.ndarray]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict in probability space; shorthand for `predict(x, True)`.

        Args:
            x: points at which to predict.

        Returns:
            tuple (p, var) as returned by `predict` with
            `probability_space=True`.
        """
        return self.predict(x, probability_space=True)

    def _augment_with_deriv_index(self, x: Tensor, indx):
        """Append a constant column holding the derivative index `indx` to x.

        Args:
            x (Tensor): 2-D tensor of points, one row per point.
            indx: Derivative index to write into the new last column.

        Returns:
            Tensor with one more column than x.
        """
        # Build the index column on x's device so the concatenation does not
        # mix devices.
        index_column = indx * torch.ones(x.shape[0], 1, device=x.device)
        return torch.cat((x, index_column), dim=1)

    def _get_deriv_constraint_points(self):
        """Stack the inducing points once per monotonic dimension.

        Each copy is tagged with derivative index `i + 1` for monotonic
        column `i` (index 0 is used for plain function values).

        Returns:
            Tensor of constraint points, or an empty tensor when there are no
            monotonic indices.
        """
        blocks = [
            self._augment_with_deriv_index(self.inducing_points, i + 1)
            for i in self.monotonic_idxs
        ]
        if not blocks:
            return torch.tensor([])
        return torch.cat(blocks, dim=0)

    @classmethod
    def from_config(cls, config: Config) -> MonotonicRejectionGP:
        """Build a MonotonicRejectionGP from a Config.

        Options are read from the section named after this class.

        Args:
            config (Config): Configuration to read from.

        Returns:
            MonotonicRejectionGP: The configured model.
        """
        classname = cls.__name__
        # These are integer counts, so read them with getint like
        # num_rejection_samples below.
        num_induc = config.getint(classname, "num_induc", fallback=25)
        num_samples = config.getint(classname, "num_samples", fallback=250)
        num_rejection_samples = config.getint(
            classname, "num_rejection_samples", fallback=5000
        )

        lb = config.gettensor(classname, "lb")
        ub = config.gettensor(classname, "ub")
        dim = config.getint(classname, "dim", fallback=None)

        mean_covar_factory = config.getobj(
            classname, "mean_covar_factory", fallback=monotonic_mean_covar_factory
        )

        mean, covar = mean_covar_factory(config)

        monotonic_idxs: List[int] = config.getlist(
            classname, "monotonic_idxs", fallback=[-1]
        )

        return cls(
            monotonic_idxs=monotonic_idxs,
            lb=lb,
            ub=ub,
            dim=dim,
            num_induc=num_induc,
            num_samples=num_samples,
            num_rejection_samples=num_rejection_samples,
            mean_module=mean,
            covar_module=covar,
        )

    def forward(self, x: torch.Tensor) -> gpytorch.distributions.MultivariateNormal:
        """Evaluate GP

        Args:
            x (torch.Tensor): Tensor of points at which GP should be evaluated.

        Returns:
            gpytorch.distributions.MultivariateNormal: Distribution object
                holding mean and covariance at x.
        """
        # final dim is deriv index, we only normalize the "real" dims
        transformed_x = x.clone()
        transformed_x[..., :-1] = self.normalize_inputs(transformed_x[..., :-1])
        mean_x = self.mean_module(transformed_x)
        covar_x = self.covar_module(transformed_x)
        latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        return latent_pred