#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, Literal, cast, Dict, Optional, Tuple
import numpy as np
import torch
from aepsych.utils import make_scaled_sobol
from botorch.acquisition import AcquisitionFunction
from botorch.acquisition.input_constructors import acqf_input_constructor
from botorch.acquisition.objective import PosteriorTransform
from botorch.models.gpytorch import GPyTorchModel
from botorch.utils.transforms import t_batch_mode_transform
from scipy.stats import norm
from .lookahead_utils import (
approximate_lookahead_levelset_at_xstar,
lookahead_levelset_at_xstar,
lookahead_p_at_xstar,
)
def Hb(p: torch.Tensor) -> torch.Tensor:
"""
Binary entropy.
Args:
p (torch.Tensor): torch.Tensor of probabilities.
    Returns:
        torch.Tensor: Binary entropy for each probability.
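
    Example (illustrative; outputs are approximate)::

        probs = torch.tensor([0.1, 0.5, 0.9])
        Hb(probs)  # ~[0.47, 1.00, 0.47] bits; inputs are clamped away from exact 0 and 1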
"""
epsilon = torch.tensor(np.finfo(float).eps).to(p)
p = torch.clamp(p, min=epsilon, max=1 - epsilon)
return -torch.nan_to_num(p * torch.log2(p) + (1 - p) * torch.log2(1 - p))
def MI_fn(Px: torch.Tensor, P1: torch.Tensor, P0: torch.Tensor, py1: torch.Tensor) -> torch.Tensor:
"""
Average mutual information.
H(p) - E_y*[H(p | y*)]
Args:
Px (torch.Tensor): (b x m) Level-set posterior before observation
P1 (torch.Tensor): (b x m) Level-set posterior given observation of 1
P0 (torch.Tensor): (b x m) Level-set posterior given observation of 0
py1 (torch.Tensor): (b x 1) Probability of observing 1
    Returns:
        torch.Tensor: (b) Mutual information summed over the query set Xq.
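
    Example (shape sketch with made-up posterior values)::

        Px = torch.rand(2, 5)               # b=2 candidate points, m=5 query points
        P1, P0 = torch.rand(2, 5), torch.rand(2, 5)
        py1 = torch.rand(2, 1)
        MI_fn(Px, P1, P0, py1).shape        # torch.Size([2])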
"""
mi = Hb(Px) - py1 * Hb(P1) - (1 - py1) * Hb(P0)
return mi.sum(dim=-1)
def ClassErr(p: torch.Tensor) -> torch.Tensor:
"""
Expected classification error, min(p, 1-p).
Args:
p (torch.Tensor): torch.Tensor of predicted probabilities.
Returns:
torch.Tensor: Expected classification error for each probability, computed as min(p, 1 - p).
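
    Example::

        ClassErr(torch.tensor([0.2, 0.5, 0.9]))  # ~tensor([0.2, 0.5, 0.1])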
"""
return torch.min(p, 1 - p)
def SUR_fn(Px: torch.Tensor, P1: torch.Tensor, P0: torch.Tensor, py1: torch.Tensor) -> torch.Tensor:
"""
Stepwise uncertainty reduction.
    Expected reduction in expected classification error given an observation at Xstar,
    summed over the query set Xq.
Args:
Px (torch.Tensor): (b x m) Level-set posterior before observation
P1 (torch.Tensor): (b x m) Level-set posterior given observation of 1
P0 (torch.Tensor): (b x m) Level-set posterior given observation of 0
py1 (torch.Tensor): (b x 1) Probability of observing 1
Returns:
(b) torch.Tensor of SUR values.
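
    Example (shape sketch with made-up posteriors)::

        Px, P1, P0 = (torch.rand(4, 8) for _ in range(3))
        SUR_fn(Px, P1, P0, py1=torch.rand(4, 1)).shape  # torch.Size([4])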
"""
    P1 = P1.to(Px)
    P0 = P0.to(Px)
    py1 = py1.to(Px)
sur = ClassErr(Px) - py1 * ClassErr(P1) - (1 - py1) * ClassErr(P0)
return sur.sum(dim=-1)
def EAVC_fn(Px: torch.Tensor, P1: torch.Tensor, P0: torch.Tensor, py1: torch.Tensor) -> torch.Tensor:
"""
Expected absolute value change.
Expected absolute change in expected level-set volume given observation at Xstar.
Args:
Px (torch.Tensor): (b x m) Level-set posterior before observation
P1 (torch.Tensor): (b x m) Level-set posterior given observation of 1
P0 (torch.Tensor): (b x m) Level-set posterior given observation of 0
py1 (torch.Tensor): (b x 1) Probability of observing 1
Returns:
(b) torch.Tensor of EAVC values.
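
    Example (illustrative sketch; the sums of Px, P1, and P0 over Xq act as level-set
    volume estimates)::

        Px = torch.full((1, 100), 0.5)  # current volume estimate: 50
        P1 = torch.full((1, 100), 0.6)  # volume estimate if y* = 1: 60
        P0 = torch.full((1, 100), 0.4)  # volume estimate if y* = 0: 40
        EAVC_fn(Px, P1, P0, py1=torch.tensor([[0.5]]))  # ~tensor([10.])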
"""
avc1 = torch.abs((Px - P1).sum(dim=-1))
avc0 = torch.abs((Px - P0).sum(dim=-1))
return py1.squeeze(-1) * avc1 + (1 - py1).squeeze(-1) * avc0
class LookaheadAcquisitionFunction(AcquisitionFunction):
def __init__(
self,
model: GPyTorchModel,
target: Optional[float],
lookahead_type: Literal["levelset", "posterior"] = "levelset",
) -> None:
"""
        Base class for look-ahead acquisition functions.
Args:
model (GPyTorchModel): The gpytorch model to use.
target (float, optional): Threshold value to target in p-space.
lookahead_type (Literal["levelset", "posterior"]): The type of look-ahead to perform (default is "levelset").
- If the lookahead_type is "levelset", the acqf will consider the posterior probability that a point is above or below the target level set.
- If the lookahead_type is "posterior", the acqf will consider the posterior probability that a point will be detected or not.
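
        For example, with lookahead_type="levelset" and target=0.75, the level set is
        defined in latent space by the threshold gamma = norm.ppf(0.75) ≈ 0.674.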
"""
super().__init__(model=model)
if lookahead_type == "levelset":
self.lookahead_fn = lookahead_levelset_at_xstar
assert target is not None, "Need a target for levelset lookahead!"
self.gamma = norm.ppf(target)
elif lookahead_type == "posterior":
self.lookahead_fn = lookahead_p_at_xstar # type: ignore
self.gamma = None
else:
raise RuntimeError(f"Got unknown lookahead type {lookahead_type}!")
## Local look-ahead acquisitions
class LocalLookaheadAcquisitionFunction(LookaheadAcquisitionFunction):
def __init__(
self,
model: GPyTorchModel,
lookahead_type: Literal["levelset", "posterior"] = "levelset",
target: Optional[float] = None,
posterior_transform: Optional[PosteriorTransform] = None,
) -> None:
"""
A localized look-ahead acquisition function.
Args:
model (GPyTorchModel): The gpytorch model to use.
lookahead_type (Literal["levelset", "posterior"]): The type of look-ahead to perform (default is "levelset").
- If the lookahead_type is "levelset", the acqf will consider the posterior probability that a point is above or below the target level set.
- If the lookahead_type is "posterior", the acqf will consider the posterior probability that a point will be detected or not.
target (float, optional): Threshold value to target in p-space.
posterior_transform (PosteriorTransform, optional): Optional transformation to apply to the posterior. Default: None.
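
        Example (minimal usage sketch using the LocalMI subclass defined below; assumes
        ``model`` is a fitted AEPsych probit-GP model with input dimension ``d``)::

            acqf = LocalMI(model=model, lookahead_type="levelset", target=0.75)
            vals = acqf(torch.rand(32, 1, d))  # (32,) acquisition values, one per candidate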
"""
super().__init__(model=model, target=target, lookahead_type=lookahead_type)
self.posterior_transform = posterior_transform
@t_batch_mode_transform(expected_q=1)
def forward(self, X: torch.Tensor) -> torch.Tensor:
"""
Evaluate acquisition function at X.
Args:
            X (torch.Tensor): (b x 1 x d) point at which to evaluate the acquisition function.
Returns:
(b) torch.Tensor of acquisition values.
"""
Px, P1, P0, py1 = self.lookahead_fn(
model=self.model,
Xstar=X,
Xq=X,
gamma=self.gamma,
posterior_transform=self.posterior_transform,
) # Return shape here has m=1.
return self._compute_acqf(Px, P1, P0, py1)
def _compute_acqf(self, Px: torch.Tensor, P1: torch.Tensor, P0: torch.Tensor, py1: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
class LocalMI(LocalLookaheadAcquisitionFunction):
def _compute_acqf(self, Px: torch.Tensor, P1: torch.Tensor, P0: torch.Tensor, py1: torch.Tensor) -> torch.Tensor:
return MI_fn(Px, P1, P0, py1)
class LocalSUR(LocalLookaheadAcquisitionFunction):
def _compute_acqf(self, Px: torch.Tensor, P1: torch.Tensor, P0: torch.Tensor, py1: torch.Tensor) -> torch.Tensor:
return SUR_fn(Px, P1, P0, py1)
@acqf_input_constructor(LocalMI, LocalSUR)
def construct_inputs_local_lookahead(
model: GPyTorchModel,
training_data: None,
lookahead_type: Literal["levelset", "posterior"] = "levelset",
target: Optional[float] = None,
posterior_transform: Optional[PosteriorTransform] = None,
**kwargs,
) -> Dict[str, Any]:
"""Constructs the input dictionary for initializing local lookahead acquisition functions.
Args:
model (GPyTorchModel): The gpytorch model to use.
training_data (None): Placeholder for compatibility; not used in this function.
lookahead_type (Literal["levelset", "posterior"]): Type of look-ahead to perform. Default is "levelset".
- If the lookahead_type is "levelset", the acqf will consider the posterior probability that a point is above or below the target level set.
- If the lookahead_type is "posterior", the acqf will consider the posterior probability that a point will be detected or not.
target (float, optional): Target threshold value in probability space. Default is None.
posterior_transform (PosteriorTransform, optional): Optional transformation to apply to the posterior. Default is None.
Returns:
Dict[str, Any]: Dictionary of constructed inputs for local lookahead acquisition functions.
"""
return {
"model": model,
"lookahead_type": lookahead_type,
"target": target,
"posterior_transform": posterior_transform,
}
## Global look-ahead acquisitions
class GlobalLookaheadAcquisitionFunction(LookaheadAcquisitionFunction):
def __init__(
self,
model: GPyTorchModel,
lookahead_type: Literal["levelset", "posterior"] = "levelset",
target: Optional[float] = None,
posterior_transform: Optional[PosteriorTransform] = None,
query_set_size: Optional[int] = 256,
Xq: Optional[torch.Tensor] = None,
) -> None:
"""
A global look-ahead acquisition function.
Args:
model (GPyTorchModel): The gpytorch model to use.
lookahead_type (Literal["levelset", "posterior"]): The type of look-ahead to perform (default is "levelset").
- If the lookahead_type is "levelset", the acqf will consider the posterior probability that a point is above or below the target level set.
- If the lookahead_type is "posterior", the acqf will consider the posterior probability that a point will be detected or not.
target (float, optional): Threshold value to target in p-space.
posterior_transform (PosteriorTransform, optional): Optional transformation to apply to the posterior.
query_set_size (int, optional): Number of points in the query set.
Xq (torch.Tensor, optional): (m x d) global reference set.
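
        Example (minimal usage sketch using the GlobalMI subclass defined below; assumes
        ``model`` is a fitted AEPsych probit-GP model exposing ``lb``/``ub`` bounds, which
        are needed to generate Xq when it is not supplied)::

            acqf = GlobalMI(model=model, lookahead_type="levelset", target=0.75, query_set_size=256)
            vals = acqf(torch.rand(32, 1, d))  # d = input dimension; returns (32,) values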
"""
super().__init__(model=model, target=target, lookahead_type=lookahead_type)
self.posterior_transform = posterior_transform
assert (
Xq is not None or query_set_size is not None
), "Must pass either query set size or a query set!"
if Xq is not None and query_set_size is not None:
            assert Xq.shape[0] == query_set_size, (
                "If passing both Xq and query_set_size, "
                f"first dim of Xq should be query_set_size, got {Xq.shape[0]} != {query_set_size}"
            )
if Xq is None:
# cast to an int in case we got a float from Config, which
# would raise on make_scaled_sobol
query_set_size = cast(int, query_set_size) # make mypy happy
assert int(query_set_size) == query_set_size # make sure casting is safe
# if the asserts above pass and Xq is None, query_set_size is not None so this is safe
query_set_size = int(query_set_size) # cast
Xq = make_scaled_sobol(model.lb, model.ub, query_set_size)
self.register_buffer("Xq", Xq)
@t_batch_mode_transform(expected_q=1)
def forward(self, X: torch.Tensor) -> torch.Tensor:
"""
Evaluate acquisition function at X.
Args:
            X (torch.Tensor): (b x 1 x d) point at which to evaluate the acquisition function.
Returns:
(b) torch.Tensor of acquisition values.
"""
Px, P1, P0, py1 = self._get_lookahead_posterior(X)
return self._compute_acqf(Px, P1, P0, py1)
def _get_lookahead_posterior(
self, X: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
Xq_batch = self.Xq.expand(X.shape[0], *self.Xq.shape)
return self.lookahead_fn(
model=self.model,
Xstar=X,
Xq=Xq_batch,
gamma=self.gamma,
posterior_transform=self.posterior_transform,
)
def _compute_acqf(self, Px: torch.Tensor, P1: torch.Tensor, P0: torch.Tensor, py1: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
class GlobalMI(GlobalLookaheadAcquisitionFunction):
def _compute_acqf(self, Px: torch.Tensor, P1: torch.Tensor, P0: torch.Tensor, py1: torch.Tensor) -> torch.Tensor:
return MI_fn(Px, P1, P0, py1)
class GlobalSUR(GlobalLookaheadAcquisitionFunction):
def _compute_acqf(self, Px: torch.Tensor, P1: torch.Tensor, P0: torch.Tensor, py1: torch.Tensor) -> torch.Tensor:
return SUR_fn(Px, P1, P0, py1)
class ApproxGlobalSUR(GlobalSUR):
def __init__(
self,
model: GPyTorchModel,
        lookahead_type: Literal["levelset"] = "levelset",
target: Optional[float] = None,
query_set_size: Optional[int] = 256,
Xq: Optional[torch.Tensor] = None,
) -> None:
"""
An approximate global look-ahead acquisition function.
Args:
model (GPyTorchModel): The gpytorch model to use.
            lookahead_type (str): The type of look-ahead to perform; only "levelset" is supported (default is "levelset").
target (float, optional): Threshold value to target in p-space.
query_set_size (int, optional): Number of points in the query set.
Xq (torch.Tensor, optional): (m x d) global reference set.
"""
assert (
lookahead_type == "levelset"
), f"ApproxGlobalSUR only supports lookahead on level set, got {lookahead_type}!"
super().__init__(
model=model,
target=target,
lookahead_type=lookahead_type,
query_set_size=query_set_size,
Xq=Xq,
)
def _get_lookahead_posterior(
self, X: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the look-ahead posterior distribution for given points.
Args:
X (torch.Tensor): Input torch.Tensor representing the points at which to evaluate the look-ahead posterior.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing tensors corresponding to the posterior's computed values.
"""
Xq_batch = self.Xq.expand(X.shape[0], *self.Xq.shape)
return approximate_lookahead_levelset_at_xstar(
model=self.model,
Xstar=X,
Xq=Xq_batch,
gamma=self.gamma,
posterior_transform=self.posterior_transform,
)
class EAVC(GlobalLookaheadAcquisitionFunction):
def _compute_acqf(self, Px: torch.Tensor, P1: torch.Tensor, P0: torch.Tensor, py1: torch.Tensor) -> torch.Tensor:
return EAVC_fn(Px, P1, P0, py1)
class MOCU(GlobalLookaheadAcquisitionFunction):
"""
MOCU acquisition function given in expr. 4 of:
Zhao, Guang, et al. "Uncertainty-aware active learning for optimal Bayesian classifier."
International Conference on Learning Representations (ICLR) 2021.
"""
def _compute_acqf(self, Px: torch.Tensor, P1: torch.Tensor, P0: torch.Tensor, py1: torch.Tensor) -> torch.Tensor:
"""
        Computes the MOCU (mean objective cost of uncertainty) acquisition function.
Args:
            Px (torch.Tensor): (b x m) Current posterior probabilities at the query set Xq.
            P1 (torch.Tensor): (b x m) Look-ahead posterior probabilities given an observation of 1 at Xstar.
            P0 (torch.Tensor): (b x m) Look-ahead posterior probabilities given an observation of 0 at Xstar.
            py1 (torch.Tensor): (b x 1) Probability of observing 1 at Xstar.
Returns:
            torch.Tensor: (b) Expected increase in max(p, 1 - p) after the observation (equivalently, the expected reduction in classification error), averaged over the query set.
"""
current_max_query = torch.maximum(Px, 1 - Px)
# expectation w.r.t. y* of the max of pq
lookahead_pq1_max = torch.maximum(P1, 1 - P1)
lookahead_pq0_max = torch.maximum(P0, 1 - P0)
lookahead_max_query = lookahead_pq1_max * py1 + lookahead_pq0_max * (1 - py1)
return (lookahead_max_query - current_max_query).mean(-1)
class SMOCU(GlobalLookaheadAcquisitionFunction):
"""
SMOCU acquisition function given in expr. 11 of:
Zhao, Guang, et al. "Bayesian active learning by soft mean objective cost of uncertainty."
International Conference on Artificial Intelligence and Statistics (AISTATS) 2021.
"""
def __init__(
self,
model: GPyTorchModel,
        lookahead_type: Literal["levelset", "posterior"] = "posterior",
target: Optional[float] = None,
query_set_size: Optional[int] = 256,
Xq: Optional[torch.Tensor] = None,
        k: float = 20.0,
) -> None:
"""
model (GPyTorchModel): The gpytorch model to use.
lookahead_type (str): The type of look-ahead to perform (default is "posterior").
target (float, optional): Threshold value to target in p-space. Default is None.
query_set_size (int, optional): Number of points in the query set. Default is 256.
Xq (torch.Tensor, optional): (m x d) global reference set. Default is None.
k (float, optional): Scaling factor for the softmax approximation, controlling the "softness" of the maximum operation. Default is 20.0.
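
        As k grows, ``torch.logsumexp(k * torch.stack((p, 1 - p), dim=-1), dim=-1) / k``
        approaches ``max(p, 1 - p)``: for p = 0.7 it is roughly 0.725 at k = 5 and 0.700
        at k = 20, so SMOCU with large k closely approximates MOCU.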
"""
super().__init__(
model=model,
target=target,
lookahead_type=lookahead_type,
query_set_size=query_set_size,
Xq=Xq,
)
self.k = k
def _compute_acqf(self, Px: torch.Tensor, P1: torch.Tensor, P0: torch.Tensor, py1: torch.Tensor) -> torch.Tensor:
"""
Computes the SMOCU acquisition function.
Args:
            Px (torch.Tensor): (b x m) Current posterior probabilities at the query set Xq.
            P1 (torch.Tensor): (b x m) Look-ahead posterior probabilities given an observation of 1 at Xstar.
            P0 (torch.Tensor): (b x m) Look-ahead posterior probabilities given an observation of 0 at Xstar.
            py1 (torch.Tensor): (b x 1) Probability of observing 1 at Xstar.
Returns:
torch.Tensor: Expected reduction in uncertainty, based on the softmax approximation of the maximum query.
"""
current_softmax_query = (
torch.logsumexp(self.k * torch.stack((Px, 1 - Px), dim=-1), dim=-1) / self.k
)
# expectation w.r.t. y* of the max of pq
lookahead_pq1_softmax = (
torch.logsumexp(self.k * torch.stack((P1, 1 - P1), dim=-1), dim=-1) / self.k
)
lookahead_pq0_softmax = (
torch.logsumexp(self.k * torch.stack((P0, 1 - P0), dim=-1), dim=-1) / self.k
)
lookahead_softmax_query = (
lookahead_pq1_softmax * py1 + lookahead_pq0_softmax * (1 - py1)
)
return (lookahead_softmax_query - current_softmax_query).mean(-1)
class BEMPS(GlobalLookaheadAcquisitionFunction):
"""
BEMPS acquisition function given in:
Tan, Wei, et al. "Diversity Enhanced Active Learning with Strictly Proper Scoring Rules."
Advances in Neural Information Processing Systems 34 (2021).
"""
def __init__(self, scorefun: Callable, *args, **kwargs) -> None:
"""
scorefun (Callable): Scoring function to use for the BEMPS acquisition function.
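
        Example (illustrative sketch; ``neg_brier`` below is one possible strictly proper
        scoring rule, not a prescribed choice, and ``model`` is assumed to be a fitted
        AEPsych GP model with ``lb``/``ub`` bounds)::

            def neg_brier(p: torch.Tensor) -> torch.Tensor:
                # Negative expected Brier score when the event rate is p: -p * (1 - p).
                return -(p * (1 - p) ** 2 + (1 - p) * p ** 2)

            acqf = BEMPS(neg_brier, model=model, lookahead_type="posterior")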
"""
super().__init__(*args, **kwargs)
self.scorefun = scorefun
def _compute_acqf(self, Px: torch.Tensor, P1: torch.Tensor, P0: torch.Tensor, py1: torch.Tensor) -> torch.Tensor:
"""
Computes the BEMPS acquisition function.
Args:
            Px (torch.Tensor): (b x m) Current posterior probabilities at the query set Xq.
            P1 (torch.Tensor): (b x m) Look-ahead posterior probabilities given an observation of 1 at Xstar.
            P0 (torch.Tensor): (b x m) Look-ahead posterior probabilities given an observation of 0 at Xstar.
            py1 (torch.Tensor): (b x 1) Probability of observing 1 at Xstar.
Returns:
torch.Tensor: Expected improvement in the scoring function values, based on the current and look-ahead states.
"""
current_score = self.scorefun(Px)
lookahead_pq1_score = self.scorefun(P1)
lookahead_pq0_score = self.scorefun(P0)
lookahead_expected_score = lookahead_pq1_score * py1 + lookahead_pq0_score * (
1 - py1
)
return (lookahead_expected_score - current_score).mean(-1)
@acqf_input_constructor(GlobalMI, GlobalSUR, ApproxGlobalSUR, EAVC, MOCU, SMOCU, BEMPS)
def construct_inputs_global_lookahead(
model: GPyTorchModel,
training_data: None,
    lookahead_type: Literal["levelset", "posterior"] = "levelset",
target: Optional[float] = None,
posterior_transform: Optional[PosteriorTransform] = None,
query_set_size: Optional[int] = 256,
Xq: Optional[torch.Tensor] = None,
**kwargs,
) -> Dict[str, Any]:
"""
Constructs the input dictionary for initializing global lookahead acquisition functions.
Args:
model (GPyTorchModel): The gpytorch model to use.
training_data (None): Placeholder for compatibility; not used in this function.
lookahead_type (str): Type of look-ahead to perform. Default is "levelset".
target (float, optional): Target threshold value in probability space. Default is None.
posterior_transform (PosteriorTransform, optional): Optional transformation to apply to the posterior. Default is None.
query_set_size (int, optional): Number of points in the query set. Default is 256.
        Xq (torch.Tensor, optional): (m x d) global reference set. If not provided, a Sobol sequence is generated over the bounds. Default is None.
        **kwargs: Additional keyword arguments; must include "bounds", a sequence of per-dimension (lower, upper) pairs, when Xq is not provided.
Returns:
Dict[str, Any]: Dictionary of constructed inputs for global lookahead acquisition functions.
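
    Example (usage sketch; ``model`` is assumed to be a fitted GPyTorchModel, and the
    "bounds" kwarg is a sequence of per-dimension (lower, upper) pairs as read by the
    implementation below)::

        inputs = construct_inputs_global_lookahead(
            model=model,
            training_data=None,
            lookahead_type="levelset",
            target=0.75,
            bounds=[(0.0, 1.0), (0.0, 1.0)],
        )
        acqf = GlobalMI(**inputs)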
"""
    if Xq is None and query_set_size is None:
        raise ValueError("Either Xq or query_set_size must be provided.")
    if Xq is None:
        # Generate a Sobol query set over the bounds passed via kwargs
        # ("bounds" is a sequence of per-dimension (lower, upper) pairs).
        lb = torch.tensor([bounds[0] for bounds in kwargs["bounds"]])
        ub = torch.tensor([bounds[1] for bounds in kwargs["bounds"]])
        Xq = make_scaled_sobol(lb, ub, query_set_size)
return {
"model": model,
"lookahead_type": lookahead_type,
"target": target,
"posterior_transform": posterior_transform,
"query_set_size": query_set_size,
"Xq": Xq,
}