#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from copy import deepcopy
from typing import Dict, Optional, Tuple, Union
import gpytorch
import numpy as np
import torch
from aepsych.config import Config
from aepsych.factory.default import default_mean_covar_factory
from aepsych.models.base import AEPsychMixin
from aepsych.utils import _process_bounds, promote_0d
from aepsych.utils_logging import getLogger
from gpytorch.likelihoods import GaussianLikelihood, Likelihood
from gpytorch.models import ExactGP
logger = getLogger()
class GPRegressionModel(AEPsychMixin, ExactGP):
    """GP Regression model for continuous outcomes, using exact inference."""

    # Class-level descriptors of the model; `stimuli_per_trial` is also
    # forwarded to `default_mean_covar_factory` in `__init__`.
    _num_outputs = 1
    _batch_size = 1
    stimuli_per_trial = 1
    outcome_type = "continuous"

    def __init__(
        self,
        lb: Union[np.ndarray, torch.Tensor],
        ub: Union[np.ndarray, torch.Tensor],
        dim: Optional[int] = None,
        mean_module: Optional[gpytorch.means.Mean] = None,
        covar_module: Optional[gpytorch.kernels.Kernel] = None,
        likelihood: Optional[Likelihood] = None,
        max_fit_time: Optional[float] = None,
    ):
        """Initialize the GP regression model.

        Args:
            lb (Union[numpy.ndarray, torch.Tensor]): Lower bounds of the parameters.
            ub (Union[numpy.ndarray, torch.Tensor]): Upper bounds of the parameters.
            dim (int, optional): The number of dimensions in the parameter space.
                Passed to `_process_bounds` together with `lb` and `ub`; if None it
                is expected to be inferred from the size of the bounds.
            mean_module (gpytorch.means.Mean, optional): GP mean module. If None,
                the mean returned by `default_mean_covar_factory` is used.
            covar_module (gpytorch.kernels.Kernel, optional): GP covariance kernel.
                If None, the kernel returned by `default_mean_covar_factory` is used.
            likelihood (gpytorch.likelihoods.Likelihood, optional): The likelihood
                to use. If None, defaults to `GaussianLikelihood()`.
            max_fit_time (float, optional): The maximum amount of time, in seconds,
                to spend fitting the model. Stored on the instance as
                `self.max_fit_time`; None is meant to signify no limit.
        """
        if likelihood is None:
            likelihood = GaussianLikelihood()

        # No training data is available at construction time; it is attached
        # later in `fit` via `set_train_data`.
        super().__init__(None, None, likelihood)

        self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim)
        self.max_fit_time = max_fit_time

        # Only build the defaults when at least one of them is needed, and pick
        # them with explicit `is None` checks: a user-supplied module must never
        # be swapped out just because it happens to evaluate as falsy.
        if mean_module is None or covar_module is None:
            default_mean, default_covar = default_mean_covar_factory(
                dim=self.dim, stimuli_per_trial=self.stimuli_per_trial
            )
            if mean_module is None:
                mean_module = default_mean
            if covar_module is None:
                covar_module = default_covar

        self.mean_module = mean_module
        self.covar_module = covar_module

        # Deep-copied snapshots of the untrained parameters (model and
        # likelihood), presumably so the model can be reset later -- the
        # code that consumes them is not in this class.
        self._fresh_state_dict = deepcopy(self.state_dict())
        self._fresh_likelihood_dict = deepcopy(self.likelihood.state_dict())

    @classmethod
    def from_config(cls, config: Config) -> GPRegressionModel:
        """Alternate constructor for GP regression model.

        This is used when we recursively build a full sampling strategy
        from a configuration. TODO: document how this works in some tutorial.

        Args:
            config (Config): A configuration containing keys/values matching this
                class. It is turned into constructor keyword arguments by
                `cls.construct_inputs`, which is defined outside this class body.

        Returns:
            GPRegressionModel: Configured class instance.
        """
        args = cls.construct_inputs(config)
        return cls(**args)

    def fit(self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs) -> None:
        """Fit underlying model.

        Replaces the model's training data with `train_x` / `train_y` and
        optimizes the exact marginal log likelihood.

        Args:
            train_x (torch.Tensor): Inputs.
            train_y (torch.Tensor): Responses (this model's `outcome_type` is
                "continuous").
            **kwargs: Forwarded unchanged to `self._fit_mll`.

        Returns:
            Whatever `self._fit_mll` returns (defined outside this class body).
        """
        self.set_train_data(train_x, train_y)
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self)
        return self._fit_mll(mll, **kwargs)

    def sample(
        self, x: Union[torch.Tensor, np.ndarray], num_samples: int
    ) -> torch.Tensor:
        """Sample from underlying model.

        Args:
            x (Union[torch.Tensor, np.ndarray]): Points at which to sample.
            num_samples (int): Number of samples to draw (required).

        Returns:
            torch.Tensor: Detached posterior samples drawn at `x`, with every
                singleton dimension squeezed out. NOTE: because of the
                unconditional `squeeze()`, the sample axis disappears when
                `num_samples == 1`, as does any other axis of length 1.
        """
        return self.posterior(x).rsample(torch.Size([num_samples])).detach().squeeze()

    def update(self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs):
        """Update the model with new data by delegating to `fit`.

        No separate warm-start logic lives here: the call is forwarded as-is,
        so any warm-start behaviour comes from `fit` / `_fit_mll` continuing
        from the parameters currently held by the model.

        Args:
            train_x (torch.Tensor): Inputs.
            train_y (torch.Tensor): Responses.
            **kwargs: Forwarded unchanged to `fit`.

        Returns:
            The return value of `fit`.
        """
        return self.fit(train_x, train_y, **kwargs)

    def predict(
        self, x: Union[torch.Tensor, np.ndarray], **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Query the model for posterior mean and variance.

        Args:
            x (Union[torch.Tensor, np.ndarray]): Points at which to predict from
                the model.
            **kwargs: Accepted but not used by this implementation.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at the
                query points, each squeezed and then passed through
                `promote_0d` (presumably so a scalar result is still
                indexable -- confirm against `aepsych.utils`).
        """
        # Pure inference: no gradients are needed for the returned statistics.
        with torch.no_grad():
            post = self.posterior(x)
        fmean = post.mean.squeeze()
        fvar = post.variance.squeeze()
        return promote_0d(fmean), promote_0d(fvar)