Source code for aepsych.likelihoods.bernoulli
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable
import torch
from aepsych.config import Config
from gpytorch.likelihoods import _OneDimensionalLikelihood
class BernoulliObjectiveLikelihood(_OneDimensionalLikelihood):
    """Bernoulli likelihood whose link function is an arbitrary callable.

    The callable (which may be a botorch objective) turns latent function
    samples into success probabilities. Those probabilities parameterize a
    Bernoulli distribution.
    """

    def __init__(self, objective: Callable) -> None:
        """Build the likelihood around a link callable.

        Args:
            objective (Callable): Maps latent function samples to
                probabilities.
        """
        super().__init__()
        self.objective = objective

    def forward(
        self, function_samples: torch.Tensor, **kwargs: Any
    ) -> torch.distributions.Bernoulli:
        """Push latent samples through the link and wrap them as a Bernoulli.

        Args:
            function_samples (torch.Tensor): Samples of the latent function.
            **kwargs: Accepted for interface compatibility; ignored here.

        Returns:
            torch.distributions.Bernoulli: Distribution whose ``probs`` are
            the link-transformed samples.
        """
        link = self.objective
        success_probs = link(function_samples)
        return torch.distributions.Bernoulli(probs=success_probs)

    @classmethod
    def from_config(cls, config: Config) -> "BernoulliObjectiveLikelihood":
        """Instantiate the likelihood from a configuration object.

        The steps are:
        1. Read the ``objective`` entry from the section named after this class.
        2. Resolve that entry to a class.
        3. Build an instance of it through its own ``from_config``.
        4. Wrap the resulting objective.

        Args:
            config (Config): Configuration object.

        Returns:
            BernoulliObjectiveLikelihood: A new likelihood using the
            configured objective.
        """
        section = cls.__name__
        link_cls = config.getobj(section, "objective")
        return cls(objective=link_cls.from_config(config))