# Source code for aepsych.likelihoods.semi_p
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Optional
import torch
from aepsych.acquisition.objective import AEPsychObjective, FloorProbitObjective
from aepsych.config import Config
from gpytorch.likelihoods import _OneDimensionalLikelihood
class LinearBernoulliLikelihood(_OneDimensionalLikelihood):
    """
    A likelihood of the form Bernoulli(sigma(k(x-c))), where k and c are
    GPs and sigma is a flexible link function.

    The batched GP supplies two latent functions: along the size-2 dimension
    of the function samples, index 0 is the offset c and index 1 is the
    slope k (see ``f``).
    """

    def __init__(self, objective: Optional[AEPsychObjective] = None) -> None:
        """Initializes the linear bernoulli likelihood.

        Args:
            objective (AEPsychObjective, optional): Link function to use (sigma in
                the notation above). Defaults to probit with no floor.
        """
        super().__init__()
        # Explicit None check: don't let the truthiness of a module decide
        # whether the caller's objective is replaced by the default.
        self.objective = (
            objective if objective is not None else FloorProbitObjective(floor=0.0)
        )

    def f(self, function_samples: torch.Tensor, Xi: torch.Tensor) -> torch.Tensor:
        """Return the latent function value, k(x-c).

        Args:
            function_samples (torch.Tensor): Samples from a batched GP, of shape
                nsamp x 2 x n, or nsamp x b x 2 x n when an extra batch
                dimension b is present. Along the size-2 dimension, index 0 is
                the offset c and index 1 is the slope k.
            Xi (torch.Tensor): Intensity values; size n, or b x n when the
                batch dimension b is present.

        Returns:
            torch.Tensor: latent function value, of shape nsamp x b x n x 1
            (nsamp x 1 x n x 1 when there is no batch dimension).
        """
        # function_samples is of shape nsamp x (b) x 2 x n
        if function_samples.ndim > 3:
            # The optional batch dimension (b) is present.
            assert function_samples.ndim == 4, (
                f"expected 4-d function_samples, got {function_samples.ndim}-d"
            )
            assert function_samples.shape[2] == 2, (
                "expected size 2 (offset, slope) at dim 2, "
                f"got {function_samples.shape[2]}"
            )
            # In this case, Xi will be of size b x n
            # Offset and slope should be num_samps x b x n
            offset = function_samples[:, :, 0, :]
            slope = function_samples[:, :, 1, :]
            fsamps = slope * (Xi - offset)
            # Expand from (nsamp x b x n) to (nsamp x b x n x 1)
            fsamps = fsamps.unsqueeze(-1)
        else:
            assert function_samples.ndim == 3, (
                f"expected 3-d function_samples, got {function_samples.ndim}-d"
            )
            assert function_samples.shape[1] == 2, (
                "expected size 2 (offset, slope) at dim 1, "
                f"got {function_samples.shape[1]}"
            )
            # Shape is num_samps x 2 x n
            # Offset and slope should be num_samps x n
            # Xi will be of size n
            offset = function_samples[:, 0, :]
            slope = function_samples[:, 1, :]
            fsamps = slope * (Xi - offset)
            # Expand from (nsamp x n) to (nsamp x 1 x n x 1)
            fsamps = fsamps.unsqueeze(1).unsqueeze(-1)
        return fsamps

    def p(self, function_samples: torch.Tensor, Xi: torch.Tensor) -> torch.Tensor:
        """Returns the response probability sigma(k(x-c)).

        Args:
            function_samples (torch.Tensor): Samples from the batched GP (see
                documentation for self.f).
            Xi (torch.Tensor): Intensity Values.

        Returns:
            torch.Tensor: Response probabilities, i.e. self.objective applied
            to the output of self.f.
        """
        fsamps = self.f(function_samples, Xi)
        return self.objective(fsamps)

    def forward(
        self, function_samples: torch.Tensor, Xi: torch.Tensor, **kwargs
    ) -> torch.distributions.Bernoulli:
        """Forward pass for the likelihood.

        Args:
            function_samples (torch.Tensor): Samples from a batched GP of batch
                size 2 (offset and slope; see self.f).
            Xi (torch.Tensor): Intensity values.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            torch.distributions.Bernoulli: Outcome likelihood with
            probabilities given by self.p.
        """
        output_probs = self.p(function_samples, Xi)
        return torch.distributions.Bernoulli(probs=output_probs)

    def expected_log_prob(
        self, observations: torch.Tensor, function_dist: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        """Expected log probability of the observations under function_dist.

        This has to be overridden to fix a bug in gpytorch where the kwargs
        aren't being passed along to self.forward.

        Args:
            observations (torch.Tensor): Observations.
            function_dist (torch.Tensor): Distribution over the latent function,
                integrated over by self.quadrature. NOTE(review): annotated as a
                Tensor, but it is presumably a gpytorch distribution — confirm.
            *args: Ignored.
            **kwargs: Forwarded to self.forward. Because only the function
                samples are passed positionally, this must include ``Xi``.

        Returns:
            torch.Tensor: Expected log probability.
        """

        # modified, TODO fixme upstream (cc @bletham)
        def log_prob_lambda(function_samples: torch.Tensor) -> torch.Tensor:
            """Log probability of the observations for given function samples.

            Args:
                function_samples (torch.Tensor): Function samples.

            Returns:
                torch.Tensor: Log probability.
            """
            return self.forward(function_samples, **kwargs).log_prob(observations)

        log_prob = self.quadrature(log_prob_lambda, function_dist)
        return log_prob

    @classmethod
    def from_config(cls, config: Config) -> "LinearBernoulliLikelihood":
        """Create an instance from a configuration object.

        Reads the ``objective`` option from this class's section of
        ``config``. If the resolved objective provides a ``from_config`` hook,
        that hook builds it; otherwise a class is instantiated with its
        default arguments.

        Args:
            config (Config): Configuration object.

        Returns:
            LinearBernoulliLikelihood: LinearBernoulliLikelihood instance.
        """
        classname = cls.__name__
        objective = config.getobj(classname, "objective")
        if hasattr(objective, "from_config"):
            objective = objective.from_config(config)
        elif isinstance(objective, type):
            # p() calls self.objective(fsamps). A bare class would be invoked
            # as a constructor on the samples instead of as a link function,
            # so build an instance with default arguments.
            objective = objective()
        return cls(objective=objective)