#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from typing import Optional
import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.objective import IdentityMCObjective, MCAcquisitionObjective
from botorch.models.model import Model
from .rejection_sampler import RejectionSampler
[docs]class MonotonicMCAcquisition(AcquisitionFunction):
"""
Acquisition function base class for use with the rejection sampling
monotonic GP. This handles the bookkeeping of the derivative
constraint points -- implement specific monotonic MC acquisition
in subclasses.
"""
def __init__(
self,
model: Model,
deriv_constraint_points: torch.Tensor,
num_samples: int = 32,
num_rejection_samples: int = 1024,
objective: Optional[MCAcquisitionObjective] = None,
) -> None:
"""Initialize MonotonicMCAcquisition
Args:
model (Model): Model to use, usually a MonotonicRejectionGP.
num_samples (int): Number of samples to keep from the rejection sampler. Defaults to 32.
num_rejection_samples (int): Number of rejection samples to draw. Defaults to 1024.
objective (MCAcquisitionObjective, optional): Objective transform of the GP output
before evaluating the acquisition. Defaults to identity transform.
"""
super().__init__(model=model)
self.deriv_constraint_points = deriv_constraint_points
self.num_samples = num_samples
self.num_rejection_samples = num_rejection_samples
self.sampler_shape = torch.Size([])
if objective is None:
assert model.num_outputs == 1
objective = IdentityMCObjective()
else:
assert isinstance(objective, MCAcquisitionObjective)
self.add_module("objective", objective)
[docs] def forward(self, X: torch.Tensor) -> torch.Tensor:
"""Evaluate the acquisition function at a set of points.
Args:
X (torch.Tensor): Points at which to evaluate the acquisition function.
Should be (b) x q x d, and q should be 1.
Returns:
torch.Tensor: Acquisition function value at these points.
"""
# This is currently doing joint samples over (b), and requiring q=1
# TODO T68656582 support batches properly.
if len(X.shape) == 3:
assert X.shape[1] == 1, "q must be 1"
Xfull = torch.cat((X[:, 0, :], self.deriv_constraint_points), dim=0)
else:
Xfull = torch.cat((X, self.deriv_constraint_points), dim=0)
if not hasattr(self, "sampler") or Xfull.shape != self.sampler_shape:
self._set_sampler(X.shape)
self.sampler_shape = Xfull.shape
posterior = self.model.posterior(Xfull)
samples = self.sampler(posterior)
assert len(samples.shape) == 3
# Drop derivative samples
samples = samples[:, : X.shape[0], :]
# NOTE: Squeeze below makes sure that we pass in the same `X` that was used
# to generate the `samples`. This is necessitated by `MCAcquisitionObjective`,
# which verifies that `samples` and `X` have the same q-batch size.
obj_samples = self.objective(samples, X=X.squeeze(-2) if X.ndim == 3 else X)
return self.acquisition(obj_samples)
def _set_sampler(self, Xshape: torch.Size) -> None:
"""
Sets up the rejection sampler for generating samples with derivative constraints.
Args:
Xshape (torch.Size): The shape of the input points `X` for which the sampler is set up.
"""
sampler = RejectionSampler(
num_samples=self.num_samples,
num_rejection_samples=self.num_rejection_samples,
constrained_idx=torch.arange(
Xshape[0], Xshape[0] + self.deriv_constraint_points.shape[0]
),
)
self.add_module("sampler", sampler)
[docs] def acquisition(self, obj_samples: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
[docs]class MonotonicMCLSE(MonotonicMCAcquisition):
def __init__(
self,
model: Model,
deriv_constraint_points: torch.Tensor,
target: float,
num_samples: int = 32,
num_rejection_samples: int = 1024,
beta: float = 3.84,
objective: Optional[MCAcquisitionObjective] = None,
) -> None:
"""Level set estimation acquisition function for use with monotonic models.
Args:
model (Model): Underlying model object, usually should be MonotonicRejectionGP.
deriv_constraint_points (torch.Tensor): Points at which the derivative should be constrained.
target (float): Level set value to target (after the objective).
num_samples (int): Number of MC samples to draw in MC acquisition. Defaults to 32.
num_rejection_samples (int): Number of rejection samples from which to subsample monotonic ones. Defaults to 1024.
beta (float): Parameter of the LSE acquisition function that governs exploration vs
exploitation (similarly to the same parameter in UCB). Defaults to 3.84 (1.96 ** 2), which maps to the straddle
heuristic of Bryan et al. 2005.
objective (MCAcquisitionObjective, optional): Objective transform. Defaults to identity transform.
"""
self.beta = beta
self.target = target
super().__init__(
model=model,
deriv_constraint_points=deriv_constraint_points,
num_samples=num_samples,
num_rejection_samples=num_rejection_samples,
objective=objective,
)
[docs] def acquisition(self, obj_samples: torch.Tensor) -> torch.Tensor:
"""
Computes the acquisition function value for level set estimation in monotonic models.
Args:
obj_samples (torch.Tensor): Tensor of samples from the model, transformed by the objective.
Expected shape is samples x batch_shape.
Returns:
torch.Tensor: The acquisition function value, calculated as the difference between an exploration-exploitation term
(based on the variance and `beta` parameter) and the absolute difference between the mean and the target level set.
"""
mean = obj_samples.mean(dim=0)
variance = obj_samples.var(dim=0)
# prevent numerical issues if probit makes all the values 1 or 0
variance = torch.clamp(variance, min=1e-5)
delta = torch.sqrt(self.beta * variance)
return delta - torch.abs(mean - self.target)