# Source code for aepsych.acquisition.lookahead

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple, cast

import numpy as np
import torch
from aepsych.utils import make_scaled_sobol
from botorch.acquisition import AcquisitionFunction
from botorch.acquisition.input_constructors import acqf_input_constructor
from botorch.acquisition.objective import PosteriorTransform
from botorch.models.gpytorch import GPyTorchModel
from botorch.utils.transforms import t_batch_mode_transform
from scipy.stats import norm
from torch import Tensor

from .lookahead_utils import (
approximate_lookahead_levelset_at_xstar,
lookahead_levelset_at_xstar,
lookahead_p_at_xstar,
)

def Hb(p: Tensor):
    """
    Binary entropy, measured in bits (log base 2).

    Probabilities are first clamped to [eps, 1 - eps], where eps is the float64
    machine epsilon. Any NaN in the intermediate sum is then replaced by zero via
    ``torch.nan_to_num``. This matters for lower-precision inputs, where
    ``1 - eps`` rounds back to 1 and the upper clamp does nothing.

    Args:
        p: Tensor of probabilities.

    Returns: Binary entropy for each probability (same shape as ``p``).
    """
    machine_eps = torch.tensor(np.finfo(float).eps)
    safe_p = torch.clamp(p, min=machine_eps, max=1 - machine_eps)
    complement = 1 - safe_p
    signed_sum = safe_p * torch.log2(safe_p) + complement * torch.log2(complement)
    return -torch.nan_to_num(signed_sum)

def MI_fn(Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
    """
    Mutual information between the observation at Xstar and the level-set posterior.

    Computes H(p) - E_y*[H(p | y*)] per query point, where the expectation over
    the binary outcome y* is weighted by py1 and 1 - py1.

    Args:
        Px: (b x m) Level-set posterior before observation
        P1: (b x m) Level-set posterior given observation of 1
        P0: (b x m) Level-set posterior given observation of 0
        py1: (b x 1) Probability of observing 1

    Returns: (b) tensor of mutual information, summed (not averaged) over the
        last (query) dimension.
    """
    entropy_before = Hb(Px)
    per_query_gain = entropy_before - py1 * Hb(P1) - (1 - py1) * Hb(P0)
    return per_query_gain.sum(dim=-1)

def ClassErr(p: Tensor) -> Tensor:
    """
    Expected classification error of a Bernoulli(p) label, min(p, 1 - p).

    Args:
        p: Tensor of probabilities.

    Returns: Elementwise minimum of ``p`` and ``1 - p``.
    """
    return torch.minimum(p, 1 - p)

def SUR_fn(Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
    """
    Stepwise uncertainty reduction.

    Expected reduction in classification error, min(p, 1 - p), from observing
    the binary outcome at Xstar. The expectation over the outcome is weighted by
    py1 and 1 - py1.

    Args:
        Px: (b x m) Level-set posterior before observation
        P1: (b x m) Level-set posterior given observation of 1
        P0: (b x m) Level-set posterior given observation of 0
        py1: (b x 1) Probability of observing 1

    Returns: (b) tensor of SUR values, summed over the last (query) dimension.
    """
    error_before = ClassErr(Px)
    per_query_reduction = error_before - py1 * ClassErr(P1) - (1 - py1) * ClassErr(P0)
    return per_query_reduction.sum(dim=-1)

def EAVC_fn(Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
    """
    Expected absolute value change.

    For each outcome y* in {0, 1}, take the absolute change in the summed
    level-set probability over the query points (the expected level-set volume).
    Then average those two changes, weighted by the outcome probabilities.

    Args:
        Px: (b x m) Level-set posterior before observation
        P1: (b x m) Level-set posterior given observation of 1
        P0: (b x m) Level-set posterior given observation of 0
        py1: (b x 1) Probability of observing 1

    Returns: (b) tensor of EAVC values.
    """
    volume_shift_if_1 = torch.abs((Px - P1).sum(dim=-1))
    volume_shift_if_0 = torch.abs((Px - P0).sum(dim=-1))
    weight_1 = py1.squeeze(-1)
    weight_0 = (1 - py1).squeeze(-1)
    return weight_1 * volume_shift_if_1 + weight_0 * volume_shift_if_0

class LookaheadAcquisitionFunction(AcquisitionFunction):
    def __init__(
        self,
        model: GPyTorchModel,
        target: Optional[float],
        lookahead_type: str = "levelset",
    ) -> None:
        """
        Shared base for look-ahead acquisition functions.

        Selects the look-ahead routine and, for level-set look-ahead, precomputes
        ``gamma`` as the standard-normal quantile (``norm.ppf``) of ``target``.

        Args:
            model: The gpytorch model.
            target: Threshold value to target in p-space. Required when
                lookahead_type is "levelset"; unused for "posterior".
            lookahead_type: "levelset" (uses lookahead_levelset_at_xstar) or
                "posterior" (uses lookahead_p_at_xstar).

        Raises:
            AssertionError: lookahead_type is "levelset" and target is None.
            RuntimeError: lookahead_type is not one of the supported values.
        """
        super().__init__(model=model)
        if lookahead_type == "posterior":
            self.lookahead_fn = lookahead_p_at_xstar
            # Posterior look-ahead has no threshold to convert.
            self.gamma = None
            return
        if lookahead_type != "levelset":
            raise RuntimeError(f"Got unknown lookahead type {lookahead_type}!")
        self.lookahead_fn = lookahead_levelset_at_xstar
        assert target is not None, "Need a target for levelset lookahead!"
        self.gamma = norm.ppf(target)

## Local look-ahead acquisitions
class LocalLookaheadAcquisitionFunction(LookaheadAcquisitionFunction):
    def __init__(
        self,
        model: GPyTorchModel,
        lookahead_type: str = "levelset",
        target: Optional[float] = None,
        posterior_transform: Optional[PosteriorTransform] = None,
    ) -> None:
        """
        A localized look-ahead acquisition function.

        The look-ahead is evaluated only at the candidate point itself, so the
        query set is X.

        Args:
            model: The gpytorch model.
            lookahead_type: "levelset" or "posterior"; forwarded to the base class.
            target: Threshold value to target in p-space.
            posterior_transform: Optional transform forwarded to the look-ahead
                routine.
        """
        super().__init__(model=model, target=target, lookahead_type=lookahead_type)
        self.posterior_transform = posterior_transform

    @t_batch_mode_transform(expected_q=1)
    def forward(self, X: Tensor) -> Tensor:
        """
        Evaluate acquisition function at X.

        Args:
            X: (b x 1 x d) point at which to evaluate acquisition function.

        Returns: (b) tensor of acquisition values.
        """
        # Xq is the candidate itself, so the returned tensors have m=1.
        prior, post_if_1, post_if_0, prob_1 = self.lookahead_fn(
            model=self.model,
            Xstar=X,
            Xq=X,
            gamma=self.gamma,
            posterior_transform=self.posterior_transform,
        )
        return self._compute_acqf(prior, post_if_1, post_if_0, prob_1)

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        # Subclasses supply the scoring rule applied to the look-ahead outputs.
        raise NotImplementedError

class LocalMI(LocalLookaheadAcquisitionFunction):
    """Local look-ahead acquisition scored by mutual information (MI_fn)."""

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        return MI_fn(Px=Px, P1=P1, P0=P0, py1=py1)

class LocalSUR(LocalLookaheadAcquisitionFunction):
    """Local look-ahead acquisition scored by stepwise uncertainty reduction (SUR_fn)."""

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        return SUR_fn(Px=Px, P1=P1, P0=P0, py1=py1)

@acqf_input_constructor(LocalMI, LocalSUR)
def construct_inputs_local_lookahead(
    model: GPyTorchModel,
    training_data,
    lookahead_type="levelset",
    target: Optional[float] = None,
    posterior_transform: Optional[PosteriorTransform] = None,
    **kwargs,
):
    """
    Input constructor for LocalMI and LocalSUR.

    ``training_data`` and any extra keyword arguments are accepted but ignored.

    Returns: Keyword arguments for the acquisition class constructor.
    """
    return dict(
        model=model,
        lookahead_type=lookahead_type,
        target=target,
        posterior_transform=posterior_transform,
    )

## Global look-ahead acquisitions
class GlobalLookaheadAcquisitionFunction(LookaheadAcquisitionFunction):
    def __init__(
        self,
        model: GPyTorchModel,
        lookahead_type: str = "levelset",
        target: Optional[float] = None,
        posterior_transform: Optional[PosteriorTransform] = None,
        query_set_size: Optional[int] = 256,
        Xq: Optional[Tensor] = None,
    ) -> None:
        """
        A global look-ahead acquisition function.

        The look-ahead effect of observing at a candidate point is evaluated on a
        fixed global reference (query) set. That set is stored as the buffer
        ``Xq``.

        Args:
            model: The gpytorch model. When Xq is not given, ``model.lb`` and
                ``model.ub`` are used to build a scaled Sobol query set.
            lookahead_type: "levelset" or "posterior"; forwarded to the base class.
            target: Threshold value to target in p-space.
            posterior_transform: Optional transform forwarded to the look-ahead
                routine.
            query_set_size: Number of query points to generate when Xq is None.
                It must be integral, though an integral float (e.g. from Config)
                is accepted.
            Xq: (m x d) global reference set.

        Raises:
            AssertionError: both Xq and query_set_size are None; both are given
                but disagree on m; or query_set_size is not integral.
        """
        super().__init__(model=model, target=target, lookahead_type=lookahead_type)
        self.posterior_transform = posterior_transform
        assert (
            Xq is not None or query_set_size is not None
        ), "Must pass either query set size or a query set!"
        if Xq is not None and query_set_size is not None:
            # The second piece must be an f-string; previously the placeholders
            # were printed literally and the two pieces ran together.
            assert Xq.shape[0] == query_set_size, (
                "If passing both Xq and query_set_size, "
                + f"first dim of Xq should be query_set_size, got {Xq.shape[0]} != {query_set_size}"
            )
        if Xq is None:
            # cast to an int in case we got a float from Config, which
            # would raise on make_scaled_sobol
            query_set_size = cast(int, query_set_size)  # make mypy happy
            assert int(query_set_size) == query_set_size  # make sure casting is safe
            # if the asserts above pass and Xq is None, query_set_size is not None so this is safe
            query_set_size = int(query_set_size)  # cast
            Xq = make_scaled_sobol(model.lb, model.ub, query_set_size)
        self.register_buffer("Xq", Xq)

    @t_batch_mode_transform(expected_q=1)
    def forward(self, X: Tensor) -> Tensor:
        """
        Evaluate acquisition function at X.

        Args:
            X: (b x 1 x d) point at which to evaluate acquisition function.

        Returns: (b) tensor of acquisition values.
        """
        Px, P1, P0, py1 = self._get_lookahead_posterior(X)
        return self._compute_acqf(Px, P1, P0, py1)

    def _get_lookahead_posterior(
        self, X: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Run the look-ahead routine with the query set broadcast to every candidate."""
        # (m x d) -> (b x m x d): one copy of the query set per candidate (view, no copy).
        Xq_batch = self.Xq.expand(X.shape[0], *self.Xq.shape)

        return self.lookahead_fn(
            model=self.model,
            Xstar=X,
            Xq=Xq_batch,
            gamma=self.gamma,
            posterior_transform=self.posterior_transform,
        )

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        # Subclasses supply the scoring rule applied to the look-ahead outputs.
        raise NotImplementedError

class GlobalMI(GlobalLookaheadAcquisitionFunction):
    """Global look-ahead acquisition scored by mutual information (MI_fn)."""

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        return MI_fn(Px=Px, P1=P1, P0=P0, py1=py1)

class GlobalSUR(GlobalLookaheadAcquisitionFunction):
    """Global look-ahead acquisition scored by stepwise uncertainty reduction (SUR_fn)."""

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        return SUR_fn(Px=Px, P1=P1, P0=P0, py1=py1)

class ApproxGlobalSUR(GlobalSUR):
    """
    GlobalSUR variant whose level-set look-ahead comes from
    ``approximate_lookahead_levelset_at_xstar`` rather than the exact routine.
    """

    def __init__(
        self,
        model: GPyTorchModel,
        lookahead_type="levelset",
        target: Optional[float] = None,
        query_set_size: Optional[int] = 256,
        Xq: Optional[Tensor] = None,
        posterior_transform: Optional[PosteriorTransform] = None,
    ) -> None:
        """
        Args:
            model: The gpytorch model.
            lookahead_type: Must be "levelset".
            target: Threshold value to target in p-space.
            query_set_size: Number of query points to generate when Xq is None.
            Xq: (m x d) global reference set.
            posterior_transform: Optional transform forwarded to the approximate
                look-ahead routine. It is accepted (appended last, so positional
                callers are unaffected) because construct_inputs_global_lookahead
                always passes it. Without it, that constructor raised TypeError.

        Raises:
            AssertionError: lookahead_type is not "levelset".
        """
        assert (
            lookahead_type == "levelset"
        ), f"ApproxGlobalSUR only supports lookahead on level set, got {lookahead_type}!"
        super().__init__(
            model=model,
            target=target,
            lookahead_type=lookahead_type,
            posterior_transform=posterior_transform,
            query_set_size=query_set_size,
            Xq=Xq,
        )

    def _get_lookahead_posterior(
        self, X: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Approximate level-set look-ahead with the query set broadcast per candidate."""
        # (m x d) -> (b x m x d)
        Xq_batch = self.Xq.expand(X.shape[0], *self.Xq.shape)

        return approximate_lookahead_levelset_at_xstar(
            model=self.model,
            Xstar=X,
            Xq=Xq_batch,
            gamma=self.gamma,
            posterior_transform=self.posterior_transform,
        )

class EAVC(GlobalLookaheadAcquisitionFunction):
    """Global look-ahead acquisition scored by expected absolute value change (EAVC_fn)."""

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        return EAVC_fn(Px=Px, P1=P1, P0=P0, py1=py1)

class MOCU(GlobalLookaheadAcquisitionFunction):
    """
    MOCU acquisition function given in expr. 4 of:

    Zhao, Guang, et al. "Uncertainty-aware active learning for optimal Bayesian classifier."
    International Conference on Learning Representations (ICLR) 2021.
    """

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        """
        Expected gain in max(p, 1 - p) from observing y* at the candidate.

        Returns: (b) tensor, averaged (mean) over the last (query) dimension.
        """

        def confident_prob(prob: Tensor) -> Tensor:
            return torch.maximum(prob, 1 - prob)

        before = confident_prob(Px)
        # expectation over the binary outcome y*, weighted by py1 / 1 - py1
        after = confident_prob(P1) * py1 + confident_prob(P0) * (1 - py1)
        return (after - before).mean(-1)

class SMOCU(GlobalLookaheadAcquisitionFunction):
    """
    SMOCU acquisition function given in expr. 11 of:

    Zhao, Guang, et al. "Bayesian active learning by soft mean objective cost of uncertainty."
    International Conference on Artificial Intelligence and Statistics (AISTATS) 2021.
    """

    def __init__(
        self,
        model: GPyTorchModel,
        lookahead_type="posterior",
        target: Optional[float] = None,
        query_set_size: Optional[int] = 256,
        Xq: Optional[Tensor] = None,
        k: float = 20.0,
        posterior_transform: Optional[PosteriorTransform] = None,
    ):
        """
        Args:
            model: The gpytorch model.
            lookahead_type: "posterior" (default) or "levelset".
            target: Threshold value to target in p-space (needed for "levelset").
            query_set_size: Number of query points to generate when Xq is None.
            Xq: (m x d) global reference set.
            k: Sharpness of the soft maximum logsumexp(k * x) / k. Larger values
                move it closer to the hard max. It must be a number; None is not
                usable.
            posterior_transform: Optional transform forwarded to the look-ahead
                routine. It is accepted (appended last, so positional callers are
                unaffected) because construct_inputs_global_lookahead always
                passes it. Without it, that constructor raised TypeError.
        """
        super().__init__(
            model=model,
            target=target,
            lookahead_type=lookahead_type,
            posterior_transform=posterior_transform,
            query_set_size=query_set_size,
            Xq=Xq,
        )
        self.k = k

    def _soft_max_query(self, p: Tensor) -> Tensor:
        """Smooth approximation of max(p, 1 - p): logsumexp(k * [p, 1 - p]) / k."""
        return torch.logsumexp(self.k * torch.stack((p, 1 - p), dim=-1), dim=-1) / self.k

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        """
        Expected gain in the soft max of (p, 1 - p) from observing y* at the candidate.

        Returns: (b) tensor, averaged (mean) over the last (query) dimension.
        """
        current_softmax_query = self._soft_max_query(Px)
        # expectation w.r.t. y* of the soft max of pq
        lookahead_pq1_softmax = self._soft_max_query(P1)
        lookahead_pq0_softmax = self._soft_max_query(P0)
        lookahead_softmax_query = lookahead_pq1_softmax * py1 + lookahead_pq0_softmax * (
            1 - py1
        )
        return (lookahead_softmax_query - current_softmax_query).mean(-1)

class BEMPS(GlobalLookaheadAcquisitionFunction):
    """
    BEMPS acquisition function given in:

    Tan, Wei, et al. "Diversity Enhanced Active Learning with Strictly Proper Scoring Rules."
    Advances in Neural Information Processing Systems 34 (2021).
    """

    def __init__(self, scorefun, *args, **kwargs):
        """
        Args:
            scorefun: Callable applied to a probability tensor (Px, P1 or P0) and
                returning a score tensor that broadcasts against py1.
            *args, **kwargs: Forwarded unchanged to the global look-ahead base class.
        """
        super().__init__(*args, **kwargs)
        self.scorefun = scorefun

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        """
        Expected change in score from observing y* at the candidate.

        Returns: Mean over the last dimension of E_y*[score(P_y*)] - score(Px).
        """
        score = self.scorefun
        baseline = score(Px)
        score_if_1 = score(P1)
        score_if_0 = score(P0)
        expected_after = score_if_1 * py1 + score_if_0 * (1 - py1)
        return (expected_after - baseline).mean(-1)

@acqf_input_constructor(GlobalMI, GlobalSUR, ApproxGlobalSUR, EAVC, MOCU, SMOCU, BEMPS)
def construct_inputs_global_lookahead(
    model: GPyTorchModel,
    training_data,
    lookahead_type="levelset",
    target: Optional[float] = None,
    posterior_transform: Optional[PosteriorTransform] = None,
    query_set_size: Optional[int] = 256,
    Xq: Optional[Tensor] = None,
    **kwargs,
):
    """
    Input constructor for the global look-ahead acquisition functions.

    When no Xq is supplied, a scaled Sobol query set of size query_set_size is
    built from ``kwargs["bounds"]``. The bounds are only read in that case, so an
    explicit Xq no longer requires a ``bounds`` keyword.

    NOTE(review): BEMPS requires a ``scorefun`` argument that is not part of the
    returned dict. Confirm how BEMPS is expected to be built through this path.

    Raises:
        KeyError: Xq is None and ``bounds`` is not among the keyword arguments.

    Returns: Keyword arguments for the acquisition class constructor.
    """
    if Xq is None:
        # assumes bounds iterates as per-dimension (lower, upper) pairs — TODO confirm
        bounds = kwargs["bounds"]
        lb = [dim_bounds[0] for dim_bounds in bounds]
        ub = [dim_bounds[1] for dim_bounds in bounds]
        Xq = make_scaled_sobol(lb, ub, query_set_size)

    return {
        "model": model,
        "lookahead_type": lookahead_type,
        "target": target,
        "posterior_transform": posterior_transform,
        "query_set_size": query_set_size,
        "Xq": Xq,
    }