# Source code for aepsych.acquisition.lse
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Union
import torch
from aepsych.acquisition.objective import ProbitObjective
from botorch.acquisition.input_constructors import acqf_input_constructor
from botorch.acquisition.monte_carlo import (
    MCAcquisitionFunction,
    MCAcquisitionObjective,
    MCSampler,
)
from botorch.models.model import Model
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.transforms import t_batch_mode_transform
from torch import Tensor
class MCLevelSetEstimation(MCAcquisitionFunction):
    """Monte-Carlo level set estimation (LSE) acquisition function.

    For each candidate point, posterior samples are drawn from ``model`` via
    ``sampler``, transformed by ``objective`` (probit by default), and
    summarized by their sample mean ``mu`` and sample variance ``sigma^2``.
    The acquisition value is

        sqrt(beta * sigma^2) - |mu - target|

    so a point scores highly when its transformed prediction is uncertain
    (large ``sigma``) and/or close to the level ``target``. With the default
    ``beta = 3.84`` (approximately 1.96**2) the exploration term is roughly
    ``1.96 * sigma``.
    """

    def __init__(
        self,
        model: Model,
        target: Union[float, Tensor] = 0.75,
        beta: Union[float, Tensor] = 3.84,
        objective: Optional[MCAcquisitionObjective] = None,
        sampler: Optional[MCSampler] = None,
    ) -> None:
        r"""Monte-carlo level set estimation.

        Args:
            model: A fitted model.
            target: The level set (after objective transform) to be estimated.
                Defaults to 0.75.
            beta: A parameter that governs the explore-exploit tradeoff; it
                scales the posterior variance inside the square root.
                Defaults to 3.84.
            objective: An MCAcquisitionObjective representing the link function
                (e.g., logistic or probit) applied to the samples.
                Can be implemented via GenericMCObjective. Defaults to
                ``ProbitObjective()``.
            sampler: The sampler used for drawing MC samples. Defaults to a
                ``SobolQMCNormalSampler`` drawing 512 samples.
        """
        if sampler is None:
            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([512]))
        if objective is None:
            objective = ProbitObjective()
        # The objective is deliberately not handed to the base-class
        # constructor and is attached right afterwards instead.
        # NOTE(review): presumably this sidesteps the base class's own
        # objective handling/validation — confirm against the botorch version.
        super().__init__(model=model, sampler=sampler, objective=None, X_pending=None)
        self.objective = objective
        self.beta = beta
        self.target = target

    def acquisition(self, obj_samples: torch.Tensor) -> torch.Tensor:
        """Evaluate the acquisition based on objective samples.

        Usually you should not call this directly unless you are
        subclassing this class and modifying how objective samples
        are generated.

        Args:
            obj_samples (torch.Tensor): Samples from the model, transformed
                by the objective. Should be samples x batch_shape; statistics
                are reduced over the leading (samples) dimension.

        Returns:
            torch.Tensor: Acquisition function at the sampled values, with the
                leading samples dimension reduced away.
        """
        mean = obj_samples.mean(dim=0)
        variance = obj_samples.var(dim=0)
        # prevent numerical issues if probit makes all the values 1 or 0
        # (zero variance would otherwise remove the exploration term entirely)
        variance = torch.clamp(variance, min=1e-5)
        delta = torch.sqrt(self.beta * variance)
        return delta - torch.abs(mean - self.target)

    @t_batch_mode_transform()
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Evaluate the acquisition function.

        Args:
            X (torch.Tensor): Points at which to evaluate.

        Returns:
            torch.Tensor: Value of the acquisition function at these points,
                with the trailing singleton dimension squeezed out.
        """
        post = self.model.posterior(X)
        samples = self.sampler(post)  # num_samples x batch_shape x q x d_out
        return self.acquisition(self.objective(samples, X)).squeeze(-1)
@acqf_input_constructor(MCLevelSetEstimation)
def construct_inputs_lse(
    model,
    training_data,
    objective=None,
    target=0.75,
    beta=3.84,
    sampler=None,
    **kwargs,
):
    """Assemble constructor keyword arguments for ``MCLevelSetEstimation``.

    Registered for ``MCLevelSetEstimation`` through ``acqf_input_constructor``.

    Args:
        model: The fitted model, forwarded unchanged.
        training_data: Accepted for the input-constructor calling convention;
            not used.
        objective: Forwarded unchanged (may be None).
        target: Forwarded unchanged. Defaults to 0.75.
        beta: Forwarded unchanged. Defaults to 3.84.
        sampler: Forwarded unchanged (may be None).
        **kwargs: Any extra keyword arguments are accepted and ignored.

    Returns:
        dict: The keys ``model``, ``objective``, ``target``, ``beta`` and
        ``sampler`` mapped to the corresponding arguments.
    """
    lse_kwargs = dict(
        model=model,
        objective=objective,
        target=target,
        beta=beta,
        sampler=sampler,
    )
    return lse_kwargs