Source code for aepsych.acquisition.objective.semi_p

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Optional

import torch
from aepsych.config import Config
from aepsych.likelihoods import LinearBernoulliLikelihood
from botorch.acquisition.objective import MCAcquisitionObjective
from gpytorch.likelihoods import Likelihood
from torch import Tensor


class SemiPObjectiveBase(MCAcquisitionObjective):
    """Common base for objectives built on the semi-parametric (SemiP) transform.

    It only records which input dimension holds the stimulus intensity;
    concrete subclasses define ``forward``.
    """

    # SemiP samples carry an extra batch dimension, which makes every
    # q-batch output shape check fail, so the verification is turned off.
    _verify_output_shape: bool = False

    def __init__(self, stim_dim: int = 0):
        """Store the stimulus dimension.

        Args:
            stim_dim (int): Index (along the last dimension of the inputs) of
                the stimulus intensity dimension. Defaults to 0.
        """
        super().__init__()
        self.stim_dim = stim_dim


class SemiPProbabilityObjective(SemiPObjectiveBase):
    """Wraps the semi-parametric transform into an objective
    that gives outcome probabilities
    """

    def __init__(self, likelihood: Optional[Likelihood] = None, *args, **kwargs):
        """Build the probability objective.

        Args:
            likelihood (Likelihood, optional): Underlying SemiP likelihood
                (which we use for its objective/link). When None, a
                LinearBernoulliLikelihood is created.

        Other arguments are passed to the base class (notably, stim_dim).
        """
        super().__init__(*args, **kwargs)
        # Compare against None explicitly: relying on truthiness would
        # silently discard a caller-supplied likelihood that evaluates falsy.
        if likelihood is None:
            likelihood = LinearBernoulliLikelihood()
        self.likelihood = likelihood

    def forward(self, samples: Tensor, X: Tensor) -> Tensor:
        """Evaluates the probability objective.

        Args:
            samples (Tensor): GP samples.
            X (Tensor): Inputs at which to evaluate objective. Unlike most
                AEPsych objectives, we need X here to split out the
                intensity dimension.

        Returns:
            Tensor: Response probabilities at the specific X values and
                function samples.
        """
        # Pull the stimulus intensity out of the last dimension of X.
        Xi = X[..., self.stim_dim]
        # the output of LinearBernoulliLikelihood is (nsamp x b x n x 1)
        # but the output of MCAcquisitionObjective should be
        # `nsamp x *batch_shape x q`, so we remove the final dim
        probabilities = self.likelihood.p(function_samples=samples, Xi=Xi)
        return probabilities.squeeze(-1)

    @classmethod
    def from_config(cls, config: Config) -> SemiPProbabilityObjective:
        """Create the objective from a config.

        The ``likelihood`` option is looked up in the section named after
        this class. If the configured likelihood class has its own
        ``from_config`` it is built through it, otherwise it is instantiated
        with no arguments. When the option is absent, ``__init__`` picks its
        default likelihood.

        Args:
            config (Config): Config to read the options from.

        Returns:
            SemiPProbabilityObjective: The configured objective.
        """
        classname = cls.__name__

        likelihood_cls = config.getobj(classname, "likelihood", fallback=None)
        if likelihood_cls is None:
            likelihood = None  # fall back to __init__ default
        elif hasattr(likelihood_cls, "from_config"):
            likelihood = likelihood_cls.from_config(config)
        else:
            likelihood = likelihood_cls()

        return cls(likelihood=likelihood)
class SemiPThresholdObjective(SemiPObjectiveBase):
    """Wraps the semi-parametric transform into an objective
    that gives the threshold distribution.
    """

    def __init__(
        self,
        target: float,
        likelihood: Optional[Likelihood] = None,
        *args,
        **kwargs,
    ):
        """Build the threshold objective.

        Args:
            target (float): the threshold to evaluate.
            likelihood (Likelihood, optional): Underlying SemiP likelihood
                (which we use for its inverse link). When None, a
                LinearBernoulliLikelihood is created.

        Other arguments are passed to the base class (notably, stim_dim).
        """
        super().__init__(*args, **kwargs)
        # Compare against None explicitly: relying on truthiness would
        # silently discard a caller-supplied likelihood that evaluates falsy.
        if likelihood is None:
            likelihood = LinearBernoulliLikelihood()
        self.likelihood = likelihood
        # Target mapped through the inverse of the likelihood's objective,
        # computed once here and reused by every forward() call.
        self.fspace_target = self.likelihood.objective.inverse(torch.tensor(target))

    def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
        """Evaluates the threshold objective.

        Args:
            samples (Tensor): GP samples. Indices 0 and 1 along the
                second-to-last dimension are read as the offset and the
                slope, respectively.
            X (Tensor, optional): Ignored, here for compatibility with the
                objective API.

        Returns:
            Tensor: Thresholds at the specific GP sample values, computed as
                ``(fspace_target + slope * offset) / slope``.
        """
        offset = samples[..., 0, :]
        slope = samples[..., 1, :]
        # NOTE(review): algebraically fspace_target / slope + offset, which
        # presumably inverts a slope * (x - offset) link -- confirm against
        # the likelihood in use.
        return (self.fspace_target + slope * offset) / slope

    @classmethod
    def from_config(cls, config: Config) -> SemiPThresholdObjective:
        """Create the objective from a config.

        The ``likelihood`` and ``target`` options are looked up in the
        section named after this class. If the configured likelihood class
        has its own ``from_config`` it is built through it, otherwise it is
        instantiated with no arguments. When ``likelihood`` is absent,
        ``__init__`` picks its default; ``target`` falls back to 0.75.

        Args:
            config (Config): Config to read the options from.

        Returns:
            SemiPThresholdObjective: The configured objective.
        """
        classname = cls.__name__

        likelihood_cls = config.getobj(classname, "likelihood", fallback=None)
        if likelihood_cls is None:
            likelihood = None  # fall back to __init__ default
        elif hasattr(likelihood_cls, "from_config"):
            likelihood = likelihood_cls.from_config(config)
        else:
            likelihood = likelihood_cls()

        target = config.getfloat(classname, "target", fallback=0.75)

        return cls(likelihood=likelihood, target=target)