Source code for aepsych.acquisition.mc_posterior_variance
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
from aepsych.acquisition.monotonic_rejection import MonotonicMCAcquisition
from aepsych.acquisition.objective import ProbitObjective
from botorch.acquisition.input_constructors import acqf_input_constructor
from botorch.acquisition.monte_carlo import MCAcquisitionFunction
from botorch.acquisition.objective import MCAcquisitionObjective
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.transforms import t_batch_mode_transform
from torch import Tensor
def balv_acq(obj_samps: torch.Tensor) -> torch.Tensor:
    """Evaluate BALV (posterior variance) on a set of objective samples.

    Args:
        obj_samps (torch.Tensor): Samples from the GP, transformed by the objective.
            Should be samples x batch_shape.

    Returns:
        torch.Tensor: Acquisition function value: the variance taken over the
            sample dimension (dim 0), with a trailing singleton dimension
            squeezed away when present.
    """
    # the output of objective is of shape num_samples x batch_shape x d_out
    # objective should project the last dimension to 1d,
    # so incoming should be samples x batch_shape, we take var in samp dim
    variance = obj_samps.var(dim=0)
    # squeeze(-1) only removes the last dim when it has size 1; otherwise no-op
    return variance.squeeze(-1)
class MCPosteriorVariance(MCAcquisitionFunction):
    r"""Posterior variance, computed using samples so we can use objective/transform"""

    def __init__(
        self,
        model: Model,
        objective: Optional[MCAcquisitionObjective] = None,
        sampler: Optional[MCSampler] = None,
    ) -> None:
        r"""Posterior Variance of Link Function

        Args:
            model: A fitted model.
            objective: An MCAcquisitionObjective representing the link function
                (e.g., logistic or probit) applied to the posterior samples.
                Can be implemented via GenericMCObjective. Defaults to a
                ProbitObjective when None.
            sampler: The sampler used for drawing MC samples. Defaults to a
                SobolQMCNormalSampler with sample shape 512 when None.
        """
        if sampler is None:
            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([512]))
        if objective is None:
            objective = ProbitObjective()
        # The parent constructor is deliberately called with objective=None;
        # the resolved objective is attached to the instance afterwards.
        super().__init__(model=model, sampler=sampler, objective=None, X_pending=None)
        self.objective = objective

    @t_batch_mode_transform()
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate MCPosteriorVariance on the candidate set `X`.

        Args:
            X: A `batch_size x q x d`-dim Tensor

        Returns:
            Posterior variance of link function at X that active learning
            hopes to maximize
        """
        # the output is of shape batch_shape x q x d_out
        post = self.model.posterior(X)
        samples = self.sampler(post)  # num_samples x batch_shape x q x d_out
        return self.acquisition(self.objective(samples, X))

    def acquisition(self, obj_samples: torch.Tensor) -> torch.Tensor:
        """Compute the acquisition value from objective-transformed samples.

        Args:
            obj_samples (torch.Tensor): Samples from the model, transformed by
                the objective.

        Returns:
            torch.Tensor: Acquisition function value as computed by balv_acq.
        """
        # RejectionSampler drops the final dim so we reaugment it
        # here for compatibility with non-Monotonic MCAcquisition
        if obj_samples.ndim == 2:
            obj_samples = obj_samples[..., None]
        return balv_acq(obj_samples)
@acqf_input_constructor(MCPosteriorVariance)
def construct_inputs(
    model,
    training_data,
    objective=None,
    sampler=None,
    **kwargs,
):
    """Assemble the constructor keyword arguments for MCPosteriorVariance.

    Args:
        model: Forwarded unchanged under the "model" key.
        training_data: Accepted but not used by this constructor.
        objective: Forwarded unchanged under the "objective" key.
        sampler: Forwarded unchanged under the "sampler" key.
        **kwargs: Any additional keyword arguments are accepted and ignored.

    Returns:
        dict: Mapping with the keys "model", "objective" and "sampler".
    """
    return dict(model=model, objective=objective, sampler=sampler)
class MonotonicMCPosteriorVariance(MonotonicMCAcquisition):
    """Posterior-variance (BALV) acquisition built on MonotonicMCAcquisition."""

    def acquisition(self, obj_samples: torch.Tensor) -> torch.Tensor:
        """Compute the acquisition value from objective-transformed samples.

        Args:
            obj_samples (torch.Tensor): Samples from the model, transformed by
                the objective.

        Returns:
            torch.Tensor: Acquisition function value as computed by balv_acq.
        """
        return balv_acq(obj_samples)