#!/usr/bin/env python3

# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional, Tuple

import torch
from botorch.acquisition.objective import PosteriorTransform
from gpytorch.models import GP
from torch import Tensor
from torch.distributions import Normal

from .bvn import bvn_cdf

[docs]def posterior_at_xstar_xq(
model: GP,
Xstar: Tensor,
Xq: Tensor,
posterior_transform: Optional[PosteriorTransform] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""
Evaluate the posteriors of f at single point Xstar and set of points Xq.

Args:
model: The model to evaluate.
Xstar: (b x 1 x d) tensor.
Xq: (b x m x d) tensor.

Returns:
Mu_s: (b x 1) mean at Xstar.
Sigma2_s: (b x 1) variance at Xstar.
Mu_q: (b x m) mean at Xq.
Sigma2_q: (b x m) variance at Xq.
Sigma_sq: (b x m) covariance between Xstar and each point in Xq.
"""
# Evaluate posterior and extract needed components
Xext = torch.cat((Xstar, Xq), dim=-2)
posterior = model.posterior(Xext, posterior_transform=posterior_transform)
mu = posterior.mean[..., :, 0]
Mu_s = mu[..., 0].unsqueeze(-1)
Mu_q = mu[..., 1:]
Cov = posterior.distribution.covariance_matrix
Sigma2_s = Cov[..., 0, 0].unsqueeze(-1)
Sigma2_q = torch.diagonal(Cov[..., 1:, 1:], dim1=-1, dim2=-2)
Sigma_sq = Cov[..., 0, 1:]
return Mu_s, Sigma2_s, Mu_q, Sigma2_q, Sigma_sq

model: GP,
Xstar: Tensor,
Xq: Tensor,
posterior_transform: Optional[PosteriorTransform] = None,
**kwargs: Dict[str, Any],
):
"""
Evaluate the look-ahead level-set posterior at Xq given observation at xstar.

Args:
model: The model to evaluate.
Xstar: (b x 1 x d) observation point.
Xq: (b x m x d) reference points.
gamma: Threshold in f-space.

Returns:
Px: (b x m) Level-set posterior at Xq, before observation at xstar.
P1: (b x m) Level-set posterior at Xq, given observation of 1 at xstar.
P0: (b x m) Level-set posterior at Xq, given observation of 0 at xstar.
py1: (b x 1) Probability of observing 1 at xstar.
"""
Mu_s, Sigma2_s, Mu_q, Sigma2_q, Sigma_sq = posterior_at_xstar_xq(
model=model, Xstar=Xstar, Xq=Xq, posterior_transform=posterior_transform
)

try:
gamma = kwargs.get("gamma")
except KeyError:

Norm = torch.distributions.Normal(0, 1)
Sigma_q = torch.sqrt(Sigma2_q)
b_q = (gamma - Mu_q) / Sigma_q
Phi_bq = Norm.cdf(b_q)
denom = torch.sqrt(1 + Sigma2_s)
a_s = Mu_s / denom
Phi_as = Norm.cdf(a_s)
Z_rho = -Sigma_sq / (Sigma_q * denom)
Z_qs = bvn_cdf(a_s, b_q, Z_rho)

Px = Phi_bq
py1 = Phi_as
P1 = Z_qs / py1
P0 = (Phi_bq - Z_qs) / (1 - py1)
return Px, P1, P0, py1

model: GP,
Xstar: Tensor,
Xq: Tensor,
posterior_transform: Optional[PosteriorTransform] = None,
**kwargs: Dict[str, Any],
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""
Evaluate the look-ahead response probability posterior at Xq given observation at xstar.

Uses the approximation given in expr. 9 in:
Zhao, Guang, et al. "Efficient active learning for Gaussian process classification by
error reduction." Advances in Neural Information Processing Systems 34 (2021): 9734-9746.

Args:
model: The model to evaluate.
Xstar: (b x 1 x d) observation point.
Xq: (b x m x d) reference points.
kwargs: ignored (here for compatibility with other kinds of lookahead)

Returns:
Px: (b x m) Response posterior at Xq, before observation at xstar.
P1: (b x m) Response posterior at Xq, given observation of 1 at xstar.
P0: (b x m) Response posterior at Xq, given observation of 0 at xstar.
py1: (b x 1) Probability of observing 1 at xstar.
"""
Mu_s, Sigma2_s, Mu_q, Sigma2_q, Sigma_sq = posterior_at_xstar_xq(
model=model, Xstar=Xstar, Xq=Xq, posterior_transform=posterior_transform
)

probit = Normal(0, 1).cdf

mu_tilde_star = Mu_s + (f_q - Mu_q) * Sigma_sq / Sigma2_q
sigma_tilde_star = Sigma2_s - (Sigma_sq**2) / Sigma2_q
return probit(mu_tilde_star / torch.sqrt(sigma_tilde_star + 1)) * probit(f_q)

pstar_marginal_1 = probit(Mu_s / torch.sqrt(1 + Sigma2_s))
pstar_marginal_0 = 1 - pstar_marginal_1
pq_marginal_1 = probit(Mu_q / torch.sqrt(1 + Sigma2_q))

fq_mvn = Normal(Mu_q, torch.sqrt(Sigma2_q))
joint_ystar0_yq1 = pq_marginal_1 - joint_ystar1_yq1

# now we need from the joint to the marginal on xq

model: GP,
Xstar: Tensor,
Xq: Tensor,
gamma: float,
posterior_transform: Optional[PosteriorTransform] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""
The look-ahead posterior approximation of Lyu et al.

Args:
model: The model to evaluate.
Xstar: (b x 1 x d) observation point.
Xq: (b x m x d) reference points.
gamma: Threshold in f-space.

Returns:
Px: (b x m) Level-set posterior at Xq, before observation at xstar.
P1: (b x m) Level-set posterior at Xq, given observation of 1 at xstar.
P0: (b x m) Level-set posterior at Xq, given observation of 0 at xstar.
py1: (b x 1) Probability of observing 1 at xstar.
"""
Mu_s, Sigma2_s, Mu_q, Sigma2_q, Sigma_sq = posterior_at_xstar_xq(
model=model, Xstar=Xstar, Xq=Xq, posterior_transform=posterior_transform
)

Norm = torch.distributions.Normal(0, 1)
Mu_s_pdf = torch.exp(Norm.log_prob(Mu_s))
Mu_s_cdf = Norm.cdf(Mu_s)

# Formulae from the supplement of the paper (Result 2)
vnp1_p = Mu_s_pdf**2 / Mu_s_cdf**2 + Mu_s * Mu_s_pdf / Mu_s_cdf  # (C.4)
p_p = Norm.cdf(Mu_s / torch.sqrt(1 + Sigma2_s))  # (C.5)

vnp1_n = Mu_s_pdf**2 / (1 - Mu_s_cdf) ** 2 - Mu_s * Mu_s_pdf / (
1 - Mu_s_cdf
)  # (C.6)
p_n = 1 - p_p  # (C.7)

vtild = vnp1_p * p_p + vnp1_n * p_n

Sigma2_q_np1 = Sigma2_q - Sigma_sq**2 / ((1 / vtild) + Sigma2_s)  # (C.8)

Px = Norm.cdf((gamma - Mu_q) / torch.sqrt(Sigma2_q))
P1 = Norm.cdf((gamma - Mu_q) / torch.sqrt(Sigma2_q_np1))
P0 = P1  # Same because we ignore value of y in this approximation
py1 = 0.5 * torch.ones(*Px.shape[:-1], 1)  # Value doesn't matter because P1 = P0
return Px, P1, P0, py1