# Source code for aepsych.models.model_list

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple, Union

import numpy as np
import torch
from aepsych.models.base import AEPsychModel
from botorch.models import ModelListGP


class AEPsychModelListGP(AEPsychModel, ModelListGP):
    """A list of AEPsych models combined through botorch's ``ModelListGP``.

    Fitting and probability-space prediction are delegated to each
    sub-model in ``self.models``.
    """

    def fit(self):
        """Fit every sub-model by calling its own ``fit()`` method."""
        for model in self.models:
            model.fit()

    def predict_probability(
        self, x: Union[torch.Tensor, np.ndarray]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Query the model for posterior mean and variance in probability space.

        This method works by calling ``predict_probability`` separately for
        each model in ``self.models``. If a model does not implement
        ``predict_probability``, the output of ``model.predict`` is used
        instead.

        Args:
            x (Union[torch.Tensor, np.ndarray]): Points at which to predict
                from the model. Passed unchanged to every sub-model.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at
            the query points. The per-model results are stacked along a new
            trailing dimension, e.g. shape ``(n, num_models)`` when each
            sub-model returns tensors of shape ``(n,)``. All sub-models are
            assumed to return tensors of the same shape.
        """
        means = []
        variances = []
        for model in self.models:
            if hasattr(model, "predict_probability"):
                mean, var = model.predict_probability(x)
            else:
                # Fallback for models without a probability-space prediction.
                mean, var = model.predict(x)
            means.append(mean)
            variances.append(var)
        # Stack along a new last dimension. This matches the previous
        # unsqueeze(-1) + hstack result for 0-D/1-D outputs, but keeps any
        # leading (batch) dimensions intact instead of concatenating the
        # per-model results into dim 1.
        return torch.stack(means, dim=-1), torch.stack(variances, dim=-1)

    @classmethod
    def get_mll_class(cls):
        """Return ``None``: the list exposes no single marginal log likelihood
        class, and ``fit`` delegates fitting to each sub-model instead."""
        return None