Source code for aepsych.generators.semi_p

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Type

import torch
from aepsych.acquisition.objective.semi_p import SemiPThresholdObjective
from aepsych.generators import OptimizeAcqfGenerator
from aepsych.models.semi_p import SemiParametricGPModel


class IntensityAwareSemiPGenerator(OptimizeAcqfGenerator):
    """Generator for SemiP.

    With botorch machinery, in order to optimize acquisition separately over
    context and intensity, we need two ingredients:

    1. An objective that samples from some posterior w.r.t. the context. From
       the paper, this is ThresholdBALV and needs the threshold posterior.
       `SemiPThresholdObjective` implements this for ThresholdBALV, but
       theoretically this can be any subclass of `SemiPObjectiveBase`.
    2. A way to do acquisition over context and intensity separately, which is
       provided by this class. We optimize the acquisition function over the
       context dimensions, then, conditioned on the optimum, we evaluate the
       objective to obtain the intensity value.

    We only developed ThresholdBALV, which is specific to SemiP and is what we
    tested with this generator. It should work with other similar acquisition
    functions.
    """

    def gen(  # type: ignore[override]
        self,
        num_points: int,
        model: SemiParametricGPModel,  # type: ignore[override]
        context_objective: Type = SemiPThresholdObjective,
    ) -> torch.Tensor:
        """Generate points by optimizing over context, then solving for intensity.

        Args:
            num_points: Number of points to generate; forwarded to the parent
                ``OptimizeAcqfGenerator.gen``.
            model: Semi-parametric model. Its ``stim_dim``, ``likelihood``,
                ``lb`` and ``ub`` attributes are read here.
            context_objective: Class (not instance) of the objective used to
                map the model's mean at the chosen context to an intensity
                value. It is constructed with ``likelihood``, ``stim_dim`` and
                the entries of ``self.acqf_kwargs``. Defaults to
                ``SemiPThresholdObjective``.

        Returns:
            The points produced by the parent generator, with the ``stim_dim``
            column overwritten in place by the objective's output, clamped to
            ``[model.lb[stim_dim], model.ub[stim_dim]]``.
        """
        stim_dim = model.stim_dim

        # Optimize the acquisition over the context dimensions only; the
        # intensity dimension is held fixed at 0 here and solved for below.
        next_x = super().gen(
            num_points=num_points, model=model, fixed_features={stim_dim: 0}
        )

        # To compute intensity, we need the point where f is at the threshold
        # as a function of context. self.acqf_kwargs should contain the
        # remaining objective args (like the threshold target value).
        thresh_objective = context_objective(
            likelihood=model.likelihood, stim_dim=stim_dim, **self.acqf_kwargs
        )

        # The result is detached before being written back, so no autograd
        # graph is needed for this evaluation.
        with torch.no_grad():
            # torch.as_tensor is a no-op for tensors; the legacy
            # torch.Tensor(other) constructor raises TypeError unless `other`
            # is already a CPU float32 tensor (e.g. fails on double/CUDA input).
            kc_mean_at_best_context = model(torch.as_tensor(next_x)).mean
            thresh_at_best_context = thresh_objective(kc_mean_at_best_context)
            # Keep the solved intensity inside the model's stimulus bounds.
            thresh_at_best_context = torch.clamp(
                thresh_at_best_context,
                min=model.lb[stim_dim],
                max=model.ub[stim_dim],
            )

        next_x[..., stim_dim] = thresh_at_best_context.detach()
        return next_x