Source code for aepsych.acquisition.mutual_information
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from typing import Any, Dict, Optional
import torch
from aepsych.acquisition.objective import ProbitObjective
from botorch.acquisition.input_constructors import acqf_input_constructor
from botorch.acquisition.monte_carlo import MCAcquisitionFunction
from botorch.acquisition.objective import MCAcquisitionObjective
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.transforms import t_batch_mode_transform
from torch import Tensor
from torch.distributions.bernoulli import Bernoulli
def bald_acq(obj_samples: torch.Tensor) -> torch.Tensor:
    """Evaluate the mutual information (BALD) acquisition function.

    With latent function F and X a hypothetical (Bernoulli) observation at a
    new point,

        I(F; X) = I(X; F) = H(X) - H(X | F),
        H(X | F) = E_f[H(X | F = f)],

    i.e., we take the posterior entropy of the (Bernoulli) observation X given
    the current model posterior and subtract the conditional entropy on F,
    that being the mean entropy over the posterior for F. This is equivalent
    to the BALD acquisition function in Houlsby et al. NeurIPS 2012.

    Args:
        obj_samples (torch.Tensor): Objective samples from the GP, of shape
            num_samples x batch_shape x d_out. The values are used directly
            as Bernoulli probabilities, so they must lie in [0, 1].

    Returns:
        torch.Tensor: Value of acquisition at samples. The leading sample
            dimension is averaged out and the trailing dimension is squeezed
            away when it has size 1.
    """
    # H(X): entropy of the Bernoulli whose success probability is the
    # posterior-mean probability (average over the sample dimension).
    mean_p = obj_samples.mean(dim=0)
    posterior_entropies = Bernoulli(mean_p).entropy().squeeze(-1)

    # H(X | F): entropy of each sampled probability, averaged over samples.
    sample_entropies = Bernoulli(obj_samples).entropy()
    conditional_entropies = sample_entropies.mean(dim=0).squeeze(-1)

    return posterior_entropies - conditional_entropies
@acqf_input_constructor(BernoulliMCMutualInformation)
def construct_inputs_mi(
    model: Model,
    training_data: None,
    objective: Optional[MCAcquisitionObjective] = None,
    sampler: Optional[MCSampler] = None,
    **kwargs,
) -> Dict[str, Any]:
    """Construct the input dictionary for BernoulliMCMutualInformation.

    Registered via ``acqf_input_constructor`` as the input constructor for
    the BernoulliMCMutualInformation acquisition function.

    Args:
        model (Model): The fitted model to use.
        training_data (None): Placeholder for compatibility; not used in this
            function.
        objective (MCAcquisitionObjective, optional): Objective function for
            transforming samples (e.g., logit or probit). Forwarded
            unchanged.
        sampler (MCSampler, optional): Sampler for Monte Carlo sampling.
            Forwarded unchanged; this function does not substitute a default,
            so ``None`` is passed through as-is.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        Dict[str, Any]: Dictionary with the keys ``"model"``,
            ``"objective"`` and ``"sampler"`` mapped to the corresponding
            arguments.
    """
    return {
        "model": model,
        "objective": objective,
        "sampler": sampler,
    }