Source code for aepsych.acquisition.mc_posterior_variance

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

import torch
from aepsych.acquisition.monotonic_rejection import MonotonicMCAcquisition
from aepsych.acquisition.objective import ProbitObjective
from botorch.acquisition.input_constructors import acqf_input_constructor
from botorch.acquisition.monte_carlo import MCAcquisitionFunction
from botorch.acquisition.objective import MCAcquisitionObjective
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.transforms import t_batch_mode_transform


def balv_acq(obj_samps: torch.Tensor) -> torch.Tensor:
    """Compute the BALV score, i.e. the variance across posterior samples.

    Args:
        obj_samps (torch.Tensor): Samples from the GP after the objective has
            been applied. Expected layout is samples x batch_shape, optionally
            followed by a trailing singleton output dimension.

    Returns:
        torch.Tensor: Variance taken over the leading (sample) dimension. A
            trailing dimension of size 1 is then squeezed away.
    """
    # The objective is expected to have projected the output dimension down to
    # (at most) a singleton, so the only reduction needed is over samples (dim 0).
    sample_variance = torch.var(obj_samps, dim=0)
    # squeeze(-1) is a no-op unless the last dimension has size 1.
    return sample_variance.squeeze(dim=-1)
class MCPosteriorVariance(MCAcquisitionFunction):
    r"""Posterior variance, computed using samples so we can use objective/transform"""

    def __init__(
        self,
        model: Model,
        objective: Optional[MCAcquisitionObjective] = None,
        sampler: Optional[MCSampler] = None,
    ) -> None:
        r"""Posterior Variance of Link Function

        Args:
            model (Model): A fitted model.
            objective (MCAcquisitionObjective, optional): An MCAcquisitionObjective
                representing the link function (e.g., logistic or probit) that is
                applied to the posterior samples drawn in `forward`. Can be
                implemented via GenericMCObjective. Defaults to ProbitObjective.
            sampler (MCSampler, optional): The sampler used for drawing MC samples.
                Defaults to a SobolQMCNormalSampler drawing 512 samples.
        """
        if sampler is None:
            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([512]))
        if objective is None:
            objective = ProbitObjective()
        # The base class is initialised with objective=None and the real
        # objective is attached afterwards, so the base class never processes it.
        # NOTE(review): presumably this sidesteps botorch's objective handling or
        # validation in MCAcquisitionFunction.__init__ — confirm before changing.
        super().__init__(model=model, sampler=sampler, objective=None, X_pending=None)
        self.objective = objective

    @t_batch_mode_transform()
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        r"""Evaluate MCPosteriorVariance on the candidate set `X`.

        Args:
            X (torch.Tensor): A `batch_size x q x d`-dim Tensor

        Returns:
            torch.Tensor: Posterior variance of link function at X that active
                learning hopes to maximize
        """
        # the output is of shape batch_shape x q x d_out
        post = self.model.posterior(X)
        samples = self.sampler(post)  # num_samples x batch_shape x q x d_out
        # The objective also receives X, so it may depend on the inputs.
        return self.acquisition(self.objective(samples, X))

    def acquisition(self, obj_samples: torch.Tensor) -> torch.Tensor:
        """Evaluate the acquisition based on objective samples.

        Args:
            obj_samples (torch.Tensor): Samples from the GP, transformed by the
                objective. Should be samples x batch_shape.

        Returns:
            torch.Tensor: Acquisition function at the sampled values.
        """
        # RejectionSampler drops the final dim so we reaugment it
        # here for compatibility with non-Monotonic MCAcquisition
        if obj_samples.dim() == 2:
            obj_samples = obj_samples.unsqueeze(-1)
        return balv_acq(obj_samples)
@acqf_input_constructor(MCPosteriorVariance)
def construct_inputs(
    model: Model,
    training_data: None,
    objective: Optional[MCAcquisitionObjective] = None,
    sampler: Optional[MCSampler] = None,
    **kwargs,
) -> Dict[str, Any]:
    """Build the keyword arguments used to instantiate MCPosteriorVariance.

    Args:
        model (Model): The fitted model the acquisition function will query.
        training_data (None): Accepted only for signature compatibility with the
            input-constructor registry; it is not used.
        objective (MCAcquisitionObjective, optional): Objective (e.g. a logistic
            or probit link) applied to the posterior samples. It is forwarded
            unchanged; None leaves the choice of default to MCPosteriorVariance.
        sampler (MCSampler, optional): Monte Carlo sampler. It is forwarded
            unchanged; None leaves the choice of default to MCPosteriorVariance.
        **kwargs: Ignored extra keyword arguments.

    Returns:
        Dict[str, Any]: Mapping with the keys "model", "objective" and "sampler".
    """
    # training_data and **kwargs are intentionally dropped; only these three
    # values are forwarded to the acquisition function's constructor.
    return dict(model=model, objective=objective, sampler=sampler)
class MonotonicMCPosteriorVariance(MonotonicMCAcquisition):
    """Posterior-variance (BALV) acquisition built on MonotonicMCAcquisition."""

    def acquisition(self, obj_samples: torch.Tensor) -> torch.Tensor:
        """Score objective samples by their variance over the sample dimension.

        Args:
            obj_samples (torch.Tensor): Samples from the GP after the objective
                has been applied, expected as samples x batch_shape.

        Returns:
            torch.Tensor: The BALV value, i.e. the posterior variance computed
                over the sample (leading) dimension.
        """
        # NOTE(review): unlike the non-monotonic variant, no trailing dimension
        # is re-added here, so if obj_samples is samples x batch with batch == 1,
        # balv_acq's squeeze(-1) collapses the result to a 0-d tensor. Confirm
        # that this matches what MonotonicMCAcquisition expects.
        bal_variance = balv_acq(obj_samples)
        return bal_variance