Source code for aepsych.generators.base

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import abc
from inspect import signature
from typing import Any, Dict, Generic, Protocol, runtime_checkable, TypeVar, Optional
import re

import torch
from aepsych.config import Config
from aepsych.models.base import AEPsychMixin
from botorch.acquisition import (
    AcquisitionFunction,
    NoisyExpectedImprovement,
    qNoisyExpectedImprovement,
    LogNoisyExpectedImprovement,
    qLogNoisyExpectedImprovement,
)


AEPsychModelType = TypeVar("AEPsychModelType", bound=AEPsychMixin)


[docs]@runtime_checkable class AcqArgProtocol(Protocol):
[docs] @classmethod def from_config(cls, config: Config) -> Any: pass
[docs]class AEPsychGenerator(abc.ABC, Generic[AEPsychModelType]): """Abstract base class for generators, which are responsible for generating new points to sample.""" _requires_model = True baseline_requiring_acqfs = [ qNoisyExpectedImprovement, NoisyExpectedImprovement, qLogNoisyExpectedImprovement, LogNoisyExpectedImprovement, ] stimuli_per_trial = 1 max_asks: Optional[int] = None def __init__( self, ) -> None: pass
[docs] @abc.abstractmethod def gen(self, num_points: int, model: AEPsychModelType) -> torch.Tensor: pass
[docs] @classmethod @abc.abstractmethod def from_config(cls, config: Config): pass
@classmethod def _get_acqf_options(cls, acqf: AcquisitionFunction, config: Config): if acqf is not None: acqf_name = acqf.__name__ # model is not an extra arg, it's a default arg acqf_args_expected = [ i for i in list(signature(acqf).parameters.keys()) if i != "model" ] # this is still very ugly extra_acqf_args = {} if acqf_name in config: full_section = config[acqf_name] for k in acqf_args_expected: # if this thing is configured if k in full_section.keys(): v = config.get(acqf_name, k) # if it's an object make it an object if full_section[k] in Config.registered_names.keys(): extra_acqf_args[k] = config.getobj(acqf_name, k) elif re.search( r"^\[.*\]$", v, flags=re.DOTALL ): # use regex to check if the value is a list extra_acqf_args[k] = config._str_to_list(v) # type: ignore else: # otherwise try a float try: extra_acqf_args[k] = config.getfloat(acqf_name, k) # finally just return a string except ValueError: extra_acqf_args[k] = config.get(acqf_name, k) # next, do more processing for k, v in extra_acqf_args.items(): if hasattr(v, "from_config"): # configure if needed assert isinstance(v, AcqArgProtocol) # make mypy happy extra_acqf_args[k] = v.from_config(config) elif isinstance(v, type): # instaniate a class if needed extra_acqf_args[k] = v() else: extra_acqf_args = {} return extra_acqf_args