Source code for aepsych.likelihoods.bernoulli

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, Dict, Optional

import torch
from aepsych.config import ConfigurableMixin
from gpytorch.likelihoods import _OneDimensionalLikelihood


class BernoulliObjectiveLikelihood(_OneDimensionalLikelihood, ConfigurableMixin):
    """Bernoulli likelihood whose link function is an arbitrary callable.

    The callable (which may be a botorch objective) turns latent function
    samples into success probabilities. Those probabilities parameterize the
    returned Bernoulli distribution.
    """

    def __init__(self, objective: Callable) -> None:
        """Create the likelihood around a probability-producing objective.

        Args:
            objective (Callable): Callable applied to function samples to
                obtain Bernoulli success probabilities.
        """
        # Base-class initialization runs first; the objective is stored after.
        super().__init__()
        self.objective = objective

    def forward(
        self, function_samples: torch.Tensor, **kwargs: Any
    ) -> torch.distributions.Bernoulli:
        """Push function samples through the objective into a Bernoulli.

        Args:
            function_samples (torch.Tensor): Samples of the latent function.
            **kwargs: Unused by this implementation.

        Returns:
            torch.distributions.Bernoulli: Distribution whose ``probs`` are
            ``self.objective(function_samples)``.
        """
        probs = self.objective(function_samples)
        bernoulli = torch.distributions.Bernoulli(probs=probs)
        return bernoulli

    @classmethod
    def get_config_options(
        cls,
        config,
        name: Optional[str] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Collect the keyword arguments needed to build this likelihood.

        Args:
            config (Config): Config to read options from.
            name (str, optional): Unused; kept for API conformity.
            options (Dict[str, Any], optional): Pre-supplied options. Any key
                present here takes precedence over the config.

        Returns:
            Dict[str, Any]: Options for initializing the likelihood, with
            ``"objective"`` replaced by the result of its ``from_config``.
        """
        resolved = super().get_config_options(config, name, options)
        # NOTE(review): presumably the "objective" entry is an objective type
        # exposing ``from_config``; it is instantiated from the same config.
        objective_factory = resolved["objective"]
        resolved["objective"] = objective_factory.from_config(config)
        return resolved