#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import time
from collections.abc import Iterable
from copy import deepcopy
from typing import Any, Callable, Dict, List, Mapping, Optional, Protocol, Tuple
import gpytorch
import torch
from aepsych.utils_logging import getLogger
from botorch.fit import fit_gpytorch_mll, fit_gpytorch_mll_scipy
from botorch.models.gpytorch import GPyTorchModel
from botorch.posteriors import GPyTorchPosterior
from gpytorch.likelihoods import Likelihood
from gpytorch.mlls import MarginalLogLikelihood
logger = getLogger()
class ModelProtocol(Protocol):
@property
def _num_outputs(self) -> int:
pass
@property
def outcome_type(self) -> str:
pass
@property
def extremum_solver(self) -> str:
pass
@property
def train_inputs(self) -> torch.Tensor:
pass
@property
def lb(self) -> torch.Tensor:
pass
@property
def ub(self) -> torch.Tensor:
pass
@property
def bounds(self) -> torch.Tensor:
pass
@property
def dim(self) -> int:
pass
@property
def device(self) -> torch.device:
pass
    def posterior(self, X: torch.Tensor) -> GPyTorchPosterior:
pass
    def predict(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
pass
    def predict_probability(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
pass
@property
def stimuli_per_trial(self) -> int:
pass
@property
def likelihood(self) -> Likelihood:
pass
    def sample(self, x: torch.Tensor, num_samples: int) -> torch.Tensor:
pass
def _get_extremum(
self,
extremum_type: str,
locked_dims: Optional[Mapping[int, List[float]]],
        n_samples: int = 1000,
) -> Tuple[float, torch.Tensor]:
pass
    def dim_grid(self, gridsize: int = 30) -> torch.Tensor:
pass
    def fit(self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs: Any) -> None:
pass
    def update(
self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs: Any
) -> None:
pass
    def p_below_threshold(
self, x: torch.Tensor, f_thresh: torch.Tensor
) -> torch.Tensor:
pass
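# Because ModelProtocol is a typing.Protocol, conformance is structural: any
# model exposing the members above satisfies it without inheriting from it.
# A minimal sketch (MyGPModel is a hypothetical conforming class):
#
#     def best_guess(model: ModelProtocol) -> torch.Tensor:
#         return model.predict(model.dim_grid(gridsize=10))
#
#     model: ModelProtocol = MyGPModel()  # type-checks if the interface matches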
class AEPsychMixin(GPyTorchModel):
"""Mixin class that provides AEPsych-specific utility methods."""
extremum_solver = "Nelder-Mead"
outcome_types: List[str] = []
train_inputs: Optional[Tuple[torch.Tensor]]
train_targets: Optional[torch.Tensor]
    def set_train_data(
        self,
        inputs: Optional[torch.Tensor] = None,
        targets: Optional[torch.Tensor] = None,
        strict: bool = False,
    ) -> None:
        """Set the training data for the model.
        Args:
            inputs (torch.Tensor, optional): The new training inputs.
            targets (torch.Tensor, optional): The new training targets.
            strict (bool): Currently ignored; accepted only for compatibility
                with input transforms. Defaults to False. TODO: actually use
                this arg or change input transforms to not require it.
        """
if inputs is not None:
self.train_inputs = (inputs,)
if targets is not None:
self.train_targets = targets
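    # Inputs are stored as a 1-tuple to match gpytorch's train_inputs
    # convention. A hypothetical usage sketch:
    #
    #     model.set_train_data(inputs=torch.rand(10, 2), targets=torch.rand(10))
    #     assert model.train_inputs[0].shape == (10, 2)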
    def forward(self, x: torch.Tensor) -> gpytorch.distributions.MultivariateNormal:
        """Evaluate the GP at a set of points.
        Args:
            x (torch.Tensor): Tensor of points at which the GP should be evaluated.
        Returns:
            gpytorch.distributions.MultivariateNormal: Distribution object
                holding the mean and covariance at x.
        """
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
return pred
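    # Sketch of the forward contract, assuming a hypothetical fitted model `m`
    # with mean_module and covar_module defined:
    #
    #     dist = m(torch.rand(5, 2))  # MultivariateNormal over 5 points
    #     mu, cov = dist.mean, dist.covariance_matrix  # shapes (5,) and (5, 5)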
    def _fit_mll(
        self,
        mll: MarginalLogLikelihood,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        optimizer: Callable = fit_gpytorch_mll_scipy,
        **kwargs,
    ) -> MarginalLogLikelihood:
        """Fits the model by maximizing the marginal log likelihood.
        Args:
            mll (MarginalLogLikelihood): Marginal log likelihood object.
            optimizer_kwargs (Dict[str, Any], optional): Keyword arguments for the optimizer.
            optimizer (Callable): Optimizer to use. Defaults to fit_gpytorch_mll_scipy.
        Returns:
            MarginalLogLikelihood: The fitted marginal log likelihood.
        """
        self.train()
        train_x, train_y = mll.model.train_inputs[0], mll.model.train_targets
        optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs.copy()
        max_fit_time = kwargs.pop("max_fit_time", self.max_fit_time)
        if max_fit_time is not None:
            if "options" not in optimizer_kwargs:
                optimizer_kwargs["options"] = {}
            # time a single evaluation of the mll to estimate how many
            # evaluations fit within the max_fit_time budget
            starttime = time.time()
            _ = mll(self(train_x), train_y)
            single_eval_time = (
                time.time() - starttime + 1e-6
            )  # add an epsilon to avoid divide by zero
            n_eval = int(max_fit_time / single_eval_time)
            optimizer_kwargs["options"]["maxfun"] = n_eval
            logger.info(f"fit maxfun is {n_eval}")
        res = fit_gpytorch_mll(
            mll, optimizer=optimizer, optimizer_kwargs=optimizer_kwargs, **kwargs
        )
        return res
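    # Worked example of the time budget above (hypothetical numbers): if one
    # mll evaluation takes ~5 ms and max_fit_time is 2 s, the optimizer is
    # capped at int(2 / 0.005) = 400 evaluations via options["maxfun"].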
    def p_below_threshold(
self, x: torch.Tensor, f_thresh: torch.Tensor
) -> torch.Tensor:
"""Compute the probability that the latent function is below a threshold.
Args:
x (torch.Tensor): Points at which to evaluate the probability.
f_thresh (torch.Tensor): Threshold value.
Returns:
torch.Tensor: Probability that the latent function is below the threshold.
"""
f, var = self.predict(x)
f_thresh = f_thresh.reshape(-1, 1)
f = f.reshape(1, -1)
var = var.reshape(1, -1)
z = (f_thresh - f) / var.sqrt()
return torch.distributions.Normal(0, 1).cdf(z) # Use PyTorch's CDF equivalent
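    # Shape sketch for the broadcasting above: with k thresholds and n points,
    # f_thresh becomes (k, 1) while f and var become (1, n), so the returned
    # probabilities have shape (k, n). For a hypothetical fitted model `m`:
    #
    #     p = m.p_below_threshold(torch.rand(50, 2), torch.tensor([0.65, 0.75]))
    #     assert p.shape == (2, 50)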
class AEPsychModelDeviceMixin(AEPsychMixin):
    """Mixin that keeps a model's training data on the same device as its parameters."""
_train_inputs: Optional[Tuple[torch.Tensor]]
_train_targets: Optional[torch.Tensor]
    def set_train_data(
        self,
        inputs: Optional[torch.Tensor] = None,
        targets: Optional[torch.Tensor] = None,
        strict: bool = False,
    ) -> None:
        """Set the training data for the model, moving it to the model's device.
        Args:
            inputs (torch.Tensor, optional): The new training inputs X.
            targets (torch.Tensor, optional): The new training targets Y.
            strict (bool): Currently ignored; accepted only for compatibility
                with input transforms. Defaults to False. TODO: actually use
                this arg or change input transforms to not require it.
        """
        # move the data to the model's device before storing it
        if inputs is not None:
            self._train_inputs = (inputs.to(self.device),)
        if targets is not None:
            self._train_targets = targets.to(self.device)
@property
def device(self) -> torch.device:
"""Get the device of the model.
Returns:
torch.device: Device of the model.
"""
        # We assume all model parameters live on a single device. There is
        # deliberately no setter: move the model with .to() instead.
        return next(self.parameters()).device
    @property
    def train_inputs(self) -> Optional[Tuple[torch.Tensor]]:
        """Get the training inputs.
        Returns:
            Optional[Tuple[torch.Tensor]]: Training inputs.
        """
        if self._train_inputs is None:
            return None
        # Tensor.to() is not in-place, so rebuild the tuple to ensure the
        # inputs are on the right device
        self._train_inputs = tuple(
            input.to(self.device) for input in self._train_inputs
        )
        return self._train_inputs
    @train_inputs.setter
    def train_inputs(self, train_inputs: Optional[Tuple[torch.Tensor]]) -> None:
        """Set the training inputs.
        Args:
            train_inputs (Tuple[torch.Tensor], optional): Training inputs.
        """
        if train_inputs is None:
            self._train_inputs = None
        else:
            # copy before moving so the caller's tensors are left untouched;
            # Tensor.to() returns a new tensor rather than modifying in place
            self._train_inputs = tuple(
                deepcopy(input).to(self.device) for input in train_inputs
            )
@property
def train_targets(self) -> Optional[torch.Tensor]:
"""Get the training targets.
Returns:
Optional[torch.Tensor]: Training targets.
"""
if self._train_targets is None:
return None
# make sure the tensors are on the right device
self._train_targets = self._train_targets.to(self.device)
return self._train_targets
@train_targets.setter
def train_targets(self, train_targets: Optional[torch.Tensor]) -> None:
"""Set the training targets.
Args:
train_targets (torch.Tensor, optional): Training targets.
"""
if train_targets is None:
self._train_targets = None
else:
            # copy before moving so the caller's tensor is left untouched
train_targets = deepcopy(train_targets).to(self.device)
self._train_targets = train_targets
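# Device-handling sketch for AEPsychModelDeviceMixin, assuming a hypothetical
# subclass `MyModel` and an available CUDA device:
#
#     model = MyModel(...)
#     model.set_train_data(torch.rand(10, 2), torch.rand(10))
#     model.to("cuda")                   # moves the parameters
#     model.train_inputs[0].device.type  # "cuda" -- the data follows the model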