Source code for aepsych.acquisition.lookahead

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple, cast

import numpy as np
import torch
from aepsych.utils import make_scaled_sobol
from botorch.acquisition import AcquisitionFunction
from botorch.acquisition.input_constructors import acqf_input_constructor
from botorch.acquisition.objective import PosteriorTransform
from botorch.models.gpytorch import GPyTorchModel
from botorch.utils.transforms import t_batch_mode_transform
from scipy.stats import norm
from torch import Tensor

from .lookahead_utils import (
    approximate_lookahead_levelset_at_xstar,
    lookahead_levelset_at_xstar,
    lookahead_p_at_xstar,
)


def Hb(p: Tensor) -> Tensor:
    """
    Binary entropy (in bits).

    Args:
        p: Tensor of probabilities.

    Returns:
        Binary entropy for each probability, same shape as p.
    """
    if not p.is_floating_point():
        # torch.finfo only accepts floating-point dtypes.
        p = p.to(torch.get_default_dtype())
    # Use the machine epsilon of p's own dtype, passed as plain Python scalars.
    # A float64 epsilon applied to float32 values makes `1 - eps` round to
    # exactly 1.0, so the upper clamp would be a no-op and p == 1 would reach
    # log2(0), producing NaN gradients even though nan_to_num hides the value.
    eps = torch.finfo(p.dtype).eps
    p = torch.clamp(p, min=eps, max=1 - eps)
    # nan_to_num is kept so that NaN inputs still map to zero entropy.
    return -torch.nan_to_num(p * torch.log2(p) + (1 - p) * torch.log2(1 - p))
def MI_fn(Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
    """
    Average mutual information: H(p) - E_y*[H(p | y*)].

    Args:
        Px: (b x m) Level-set posterior before observation
        P1: (b x m) Level-set posterior given observation of 1
        P0: (b x m) Level-set posterior given observation of 0
        py1: (b x 1) Probability of observing 1

    Returns:
        (b) tensor of mutual information, summed over the last (Xq) dimension.
    """
    entropy_now = Hb(Px)
    # Entropy after each possible outcome, weighted by that outcome's probability.
    weighted_entropy_yes = py1 * Hb(P1)
    weighted_entropy_no = (1 - py1) * Hb(P0)
    information_gain = entropy_now - weighted_entropy_yes - weighted_entropy_no
    return torch.sum(information_gain, dim=-1)
def ClassErr(p: Tensor) -> Tensor:
    """
    Expected classification error, min(p, 1-p).

    Args:
        p: Tensor of probabilities.

    Returns:
        Elementwise minimum of p and its complement, same shape as p.
    """
    complement = 1 - p
    return torch.minimum(p, complement)
def SUR_fn(Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
    """
    Stepwise uncertainty reduction.

    Expected reduction in expected classification error given an observation
    at Xstar, accumulated over Xq.

    Args:
        Px: (b x m) Level-set posterior before observation
        P1: (b x m) Level-set posterior given observation of 1
        P0: (b x m) Level-set posterior given observation of 0
        py1: (b x 1) Probability of observing 1

    Returns:
        (b) tensor of SUR values, summed over the last (Xq) dimension.
    """
    error_now = ClassErr(Px)
    # Error after each possible outcome, weighted by that outcome's probability.
    weighted_error_yes = py1 * ClassErr(P1)
    weighted_error_no = (1 - py1) * ClassErr(P0)
    reduction = error_now - weighted_error_yes - weighted_error_no
    return torch.sum(reduction, dim=-1)
def EAVC_fn(Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
    """
    Expected absolute value change.

    Expected absolute change in expected level-set volume given an
    observation at Xstar.

    Args:
        Px: (b x m) Level-set posterior before observation
        P1: (b x m) Level-set posterior given observation of 1
        P0: (b x m) Level-set posterior given observation of 0
        py1: (b x 1) Probability of observing 1

    Returns:
        (b) tensor of EAVC values.
    """
    # Drop the trailing singleton so the weights line up with the (b) sums.
    prob_yes = py1.squeeze(-1)
    prob_no = (1 - py1).squeeze(-1)
    # Absolute change of the summed posterior under each outcome.
    shift_if_yes = (Px - P1).sum(dim=-1).abs()
    shift_if_no = (Px - P0).sum(dim=-1).abs()
    return prob_yes * shift_if_yes + prob_no * shift_if_no
class LookaheadAcquisitionFunction(AcquisitionFunction):
    def __init__(
        self,
        model: GPyTorchModel,
        target: Optional[float],
        lookahead_type: str = "levelset",
    ) -> None:
        """
        Common base for look-ahead acquisition functions.

        Selects the look-ahead routine stored in `self.lookahead_fn` and the
        threshold `self.gamma` passed to it.

        Args:
            model: The gpytorch model.
            target: Threshold value to target in p-space. Required when
                lookahead_type is "levelset"; it is mapped through the
                standard normal quantile function (norm.ppf) to obtain gamma.
            lookahead_type: Either "levelset" or "posterior".

        Raises:
            AssertionError: if lookahead_type is "levelset" and target is None.
            RuntimeError: if lookahead_type is neither "levelset" nor "posterior".
        """
        super().__init__(model=model)

        if lookahead_type == "posterior":
            # Posterior look-ahead does not use a threshold.
            self.lookahead_fn = lookahead_p_at_xstar
            self.gamma = None
            return

        if lookahead_type != "levelset":
            raise RuntimeError(f"Got unknown lookahead type {lookahead_type}!")

        self.lookahead_fn = lookahead_levelset_at_xstar
        assert target is not None, "Need a target for levelset lookahead!"
        self.gamma = norm.ppf(target)
## Local look-ahead acquisitions
class LocalLookaheadAcquisitionFunction(LookaheadAcquisitionFunction):
    def __init__(
        self,
        model: GPyTorchModel,
        lookahead_type: str = "levelset",
        target: Optional[float] = None,
        posterior_transform: Optional[PosteriorTransform] = None,
    ) -> None:
        """
        A localized look-ahead acquisition function.

        The look-ahead posterior is evaluated at the candidate point itself,
        i.e. the query set equals the candidate.

        Args:
            model: The gpytorch model.
            lookahead_type: Either "levelset" or "posterior".
            target: Threshold value to target in p-space.
            posterior_transform: Optional transform forwarded to the
                look-ahead routine.
        """
        super().__init__(model=model, target=target, lookahead_type=lookahead_type)
        self.posterior_transform = posterior_transform

    @t_batch_mode_transform(expected_q=1)
    def forward(self, X: Tensor) -> Tensor:
        """
        Evaluate acquisition function at X.

        Args:
            X: (b x 1 x d) point at which to evaluate acquisition function.

        Returns:
            (b) tensor of acquisition values.
        """
        # X serves as both the candidate and the query set, so m=1 here.
        prior, post_yes, post_no, prob_yes = self.lookahead_fn(
            model=self.model,
            Xstar=X,
            Xq=X,
            gamma=self.gamma,
            posterior_transform=self.posterior_transform,
        )
        return self._compute_acqf(prior, post_yes, post_no, prob_yes)

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        """Turn look-ahead posteriors into acquisition values; subclasses override."""
        raise NotImplementedError
class LocalMI(LocalLookaheadAcquisitionFunction):
    """Local look-ahead acquisition scored with mutual information (MI_fn)."""

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        """Delegate scoring of the look-ahead posteriors to MI_fn."""
        return MI_fn(Px=Px, P1=P1, P0=P0, py1=py1)
class LocalSUR(LocalLookaheadAcquisitionFunction):
    """Local look-ahead acquisition scored with stepwise uncertainty reduction (SUR_fn)."""

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        """Delegate scoring of the look-ahead posteriors to SUR_fn."""
        return SUR_fn(Px=Px, P1=P1, P0=P0, py1=py1)
@acqf_input_constructor(LocalMI, LocalSUR)
def construct_inputs_local_lookahead(
    model: GPyTorchModel,
    training_data,
    lookahead_type="levelset",
    target: Optional[float] = None,
    posterior_transform: Optional[PosteriorTransform] = None,
    **kwargs,
):
    """
    Build constructor keyword arguments for the local look-ahead acquisitions.

    Args:
        model: The gpytorch model.
        training_data: Accepted but not used here.
        lookahead_type: Either "levelset" or "posterior".
        target: Threshold value to target in p-space.
        posterior_transform: Optional transform forwarded to the acquisition.
        **kwargs: Extra arguments are accepted and ignored.

    Returns:
        Dict with keys "model", "lookahead_type", "target" and
        "posterior_transform".
    """
    return dict(
        model=model,
        lookahead_type=lookahead_type,
        target=target,
        posterior_transform=posterior_transform,
    )


## Global look-ahead acquisitions
class GlobalLookaheadAcquisitionFunction(LookaheadAcquisitionFunction):
    def __init__(
        self,
        model: GPyTorchModel,
        lookahead_type: str = "levelset",
        target: Optional[float] = None,
        posterior_transform: Optional[PosteriorTransform] = None,
        query_set_size: Optional[int] = 256,
        Xq: Optional[Tensor] = None,
    ) -> None:
        """
        A global look-ahead acquisition function.

        Args:
            model: The gpytorch model.
            lookahead_type: Either "levelset" or "posterior".
            target: Threshold value to target in p-space.
            posterior_transform: Optional transform forwarded to the
                look-ahead routine.
            query_set_size: Number of points in the global reference set.
                Used to generate Xq with make_scaled_sobol (from model.lb and
                model.ub) when Xq is not given. If Xq is also given, its
                first dimension must equal this value; note that the default
                is 256, so pass query_set_size=None (or the matching size)
                together with a custom Xq.
            Xq: (m x d) global reference set.

        Raises:
            AssertionError: if both Xq and query_set_size are None, if they
                disagree on the number of points, or if query_set_size is not
                integer-valued when Xq must be generated.
        """
        super().__init__(model=model, target=target, lookahead_type=lookahead_type)
        self.posterior_transform = posterior_transform

        assert (
            Xq is not None or query_set_size is not None
        ), "Must pass either query set size or a query set!"
        if Xq is not None and query_set_size is not None:
            # The second fragment must be an f-string, otherwise the
            # placeholders are printed literally.
            assert Xq.shape[0] == query_set_size, (
                "If passing both Xq and query_set_size, "
                f"first dim of Xq should be query_set_size, got {Xq.shape[0]} != {query_set_size}"
            )
        if Xq is None:
            # cast to an int in case we got a float from Config, which
            # would raise on make_scaled_sobol
            query_set_size = cast(int, query_set_size)  # make mypy happy
            assert int(query_set_size) == query_set_size  # make sure casting is safe
            # if the asserts above pass and Xq is None, query_set_size is not None so this is safe
            query_set_size = int(query_set_size)  # cast
            Xq = make_scaled_sobol(model.lb, model.ub, query_set_size)
        self.register_buffer("Xq", Xq)

    @t_batch_mode_transform(expected_q=1)
    def forward(self, X: Tensor) -> Tensor:
        """
        Evaluate acquisition function at X.

        Args:
            X: (b x 1 x d) point at which to evaluate acquisition function.

        Returns:
            (b) tensor of acquisition values.
        """
        Px, P1, P0, py1 = self._get_lookahead_posterior(X)
        return self._compute_acqf(Px, P1, P0, py1)

    def _get_lookahead_posterior(
        self, X: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Run the look-ahead routine for candidates X against the reference set.

        Args:
            X: Candidate points; the first dimension is the batch size.

        Returns:
            The (Px, P1, P0, py1) tuple produced by self.lookahead_fn.
        """
        # Repeat the shared reference set once per batch element (as a view).
        Xq_batch = self.Xq.expand(X.shape[0], *self.Xq.shape)
        return self.lookahead_fn(
            model=self.model,
            Xstar=X,
            Xq=Xq_batch,
            gamma=self.gamma,
            posterior_transform=self.posterior_transform,
        )

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        """Turn look-ahead posteriors into acquisition values; subclasses override."""
        raise NotImplementedError
class GlobalMI(GlobalLookaheadAcquisitionFunction):
    """Global look-ahead acquisition scored with mutual information (MI_fn)."""

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        """Delegate scoring of the look-ahead posteriors to MI_fn."""
        return MI_fn(Px=Px, P1=P1, P0=P0, py1=py1)
class GlobalSUR(GlobalLookaheadAcquisitionFunction):
    """Global look-ahead acquisition scored with stepwise uncertainty reduction (SUR_fn)."""

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        """Delegate scoring of the look-ahead posteriors to SUR_fn."""
        return SUR_fn(Px=Px, P1=P1, P0=P0, py1=py1)
class ApproxGlobalSUR(GlobalSUR):
    def __init__(
        self,
        model: GPyTorchModel,
        lookahead_type="levelset",
        target: Optional[float] = None,
        query_set_size: Optional[int] = 256,
        Xq: Optional[Tensor] = None,
        posterior_transform: Optional[PosteriorTransform] = None,
    ) -> None:
        """
        Global SUR computed with the approximate level-set look-ahead.

        Args:
            model: The gpytorch model.
            lookahead_type: Must be "levelset".
            target: Threshold value to target in p-space.
            query_set_size: Number of points in the global reference set.
            Xq: (m x d) global reference set.
            posterior_transform: Optional transform forwarded to the
                look-ahead routine. Accepted (as a trailing parameter, so
                existing positional calls are unaffected) because
                construct_inputs_global_lookahead is registered for this
                class and always supplies this key.

        Raises:
            AssertionError: if lookahead_type is not "levelset".
        """
        assert (
            lookahead_type == "levelset"
        ), f"ApproxGlobalSUR only supports lookahead on level set, got {lookahead_type}!"
        super().__init__(
            model=model,
            target=target,
            lookahead_type=lookahead_type,
            posterior_transform=posterior_transform,
            query_set_size=query_set_size,
            Xq=Xq,
        )

    def _get_lookahead_posterior(
        self, X: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Run the approximate level-set look-ahead for candidates X.

        Args:
            X: Candidate points; the first dimension is the batch size.

        Returns:
            The (Px, P1, P0, py1) tuple produced by
            approximate_lookahead_levelset_at_xstar.
        """
        # Repeat the shared reference set once per batch element (as a view).
        Xq_batch = self.Xq.expand(X.shape[0], *self.Xq.shape)
        return approximate_lookahead_levelset_at_xstar(
            model=self.model,
            Xstar=X,
            Xq=Xq_batch,
            gamma=self.gamma,
            posterior_transform=self.posterior_transform,
        )
class EAVC(GlobalLookaheadAcquisitionFunction):
    """Global look-ahead acquisition scored with expected absolute value change (EAVC_fn)."""

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        """Delegate scoring of the look-ahead posteriors to EAVC_fn."""
        return EAVC_fn(Px=Px, P1=P1, P0=P0, py1=py1)
class MOCU(GlobalLookaheadAcquisitionFunction):
    """
    MOCU acquisition function given in expr. 4 of:

        Zhao, Guang, et al. "Uncertainty-aware active learning for optimal
        Bayesian classifier." International Conference on Learning
        Representations (ICLR) 2021.
    """

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        """
        Mean over the query set of the expected change in max(p, 1 - p).

        Args:
            Px: Posterior before observation.
            P1: Posterior given observation of 1.
            P0: Posterior given observation of 0.
            py1: Probability of observing 1.

        Returns:
            Acquisition values, averaged over the last dimension.
        """
        best_now = torch.maximum(Px, 1 - Px)

        # expectation w.r.t. y* of the max of pq
        best_if_yes = torch.maximum(P1, 1 - P1)
        best_if_no = torch.maximum(P0, 1 - P0)
        expected_best = best_if_yes * py1 + best_if_no * (1 - py1)

        improvement = expected_best - best_now
        return improvement.mean(-1)
class SMOCU(GlobalLookaheadAcquisitionFunction):
    """
    SMOCU acquisition function given in expr. 11 of:

        Zhao, Guang, et al. "Bayesian active learning by soft mean objective
        cost of uncertainty." International Conference on Artificial
        Intelligence and Statistics (AISTATS) 2021.
    """

    def __init__(self, k, *args, **kwargs):
        """
        Args:
            k: Scale applied inside (and divided back out of) the
                log-sum-exp soft maximum of the current posterior.
            *args, **kwargs: Forwarded to GlobalLookaheadAcquisitionFunction.
        """
        super().__init__(*args, **kwargs)
        self.k = k

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        """
        Mean over the query set of expected look-ahead max minus current soft max.

        Args:
            Px: Posterior before observation.
            P1: Posterior given observation of 1.
            P0: Posterior given observation of 0.
            py1: Probability of observing 1.

        Returns:
            Acquisition values, averaged over the last dimension.
        """
        # Soft maximum of (p, 1 - p) for the current posterior.
        both_classes = torch.stack((Px, 1 - Px), dim=-1)
        soft_best_now = torch.logsumexp(self.k * both_classes, dim=-1) / self.k

        # expectation w.r.t. y* of the max of pq
        best_if_yes = torch.maximum(P1, 1 - P1)
        best_if_no = torch.maximum(P0, 1 - P0)
        expected_best = best_if_yes * py1 + best_if_no * (1 - py1)

        improvement = expected_best - soft_best_now
        return improvement.mean(-1)
class BEMPS(GlobalLookaheadAcquisitionFunction):
    """
    BEMPS acquisition function given in:

        Tan, Wei, et al. "Diversity Enhanced Active Learning with Strictly
        Proper Scoring Rules." Advances in Neural Information Processing
        Systems 34 (2021).
    """

    def __init__(self, scorefun, *args, **kwargs):
        """
        Args:
            scorefun: Callable applied to each posterior probability tensor
                to produce a score.
            *args, **kwargs: Forwarded to GlobalLookaheadAcquisitionFunction.
        """
        super().__init__(*args, **kwargs)
        self.scorefun = scorefun

    def _compute_acqf(self, Px: Tensor, P1: Tensor, P0: Tensor, py1: Tensor) -> Tensor:
        """
        Mean over the query set of the expected change in score.

        Args:
            Px: Posterior before observation.
            P1: Posterior given observation of 1.
            P0: Posterior given observation of 0.
            py1: Probability of observing 1.

        Returns:
            Acquisition values, averaged over the last dimension.
        """
        score = self.scorefun
        score_now = score(Px)
        score_if_yes = score(P1)
        score_if_no = score(P0)
        # Expectation of the look-ahead score over the two possible outcomes.
        expected_score = score_if_yes * py1 + score_if_no * (1 - py1)
        gain = expected_score - score_now
        return gain.mean(-1)
@acqf_input_constructor(GlobalMI, GlobalSUR, ApproxGlobalSUR, EAVC, MOCU, SMOCU, BEMPS)
def construct_inputs_global_lookahead(
    model: GPyTorchModel,
    training_data,
    lookahead_type="levelset",
    target: Optional[float] = None,
    posterior_transform: Optional[PosteriorTransform] = None,
    query_set_size: Optional[int] = 256,
    Xq: Optional[Tensor] = None,
    **kwargs,
):
    """
    Build constructor keyword arguments for the global look-ahead acquisitions.

    Args:
        model: The gpytorch model.
        training_data: Accepted but not used here.
        lookahead_type: Either "levelset" or "posterior".
        target: Threshold value to target in p-space.
        posterior_transform: Optional transform forwarded to the acquisition.
        query_set_size: Number of reference points to generate when Xq is
            not given.
        Xq: (m x d) global reference set. When provided it is passed through
            unchanged and "bounds" is not required.
        **kwargs: Must contain "bounds" (a sequence of (lower, upper) pairs)
            when Xq is None; other entries are ignored.

    Returns:
        Dict with keys "model", "lookahead_type", "target",
        "posterior_transform", "query_set_size" and "Xq".

    Raises:
        AssertionError: if Xq is None and query_set_size is None or not
            integer-valued.
        KeyError: if Xq is None and "bounds" is missing from kwargs.
    """
    if Xq is None:
        # Bounds are only needed to generate a reference set, so look them up
        # lazily instead of failing when a ready-made Xq was supplied.
        assert (
            query_set_size is not None
        ), "Must pass either query set size or a query set!"
        # query_set_size may arrive as a float from Config; normalise it the
        # same way the acquisition constructor does before sampling.
        assert int(query_set_size) == query_set_size  # make sure casting is safe
        query_set_size = int(query_set_size)
        bounds = kwargs["bounds"]
        lb = [bound[0] for bound in bounds]
        ub = [bound[1] for bound in bounds]
        Xq = make_scaled_sobol(lb, ub, query_set_size)

    return {
        "model": model,
        "lookahead_type": lookahead_type,
        "target": target,
        "posterior_transform": posterior_transform,
        "query_set_size": query_set_size,
        "Xq": Xq,
    }