#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import abc
import re
import warnings
from inspect import _empty, signature
from typing import Any, Dict, Generic, Optional, Protocol, runtime_checkable, TypeVar
import torch
from aepsych.config import Config, ConfigurableMixin
from aepsych.models.base import AEPsychMixin
from botorch.acquisition import (
AcquisitionFunction,
LogNoisyExpectedImprovement,
NoisyExpectedImprovement,
qLogNoisyExpectedImprovement,
qNoisyExpectedImprovement,
)
from botorch.acquisition.preference import AnalyticExpectedUtilityOfBestOption
from ..models.base import ModelProtocol
AEPsychModelType = TypeVar("AEPsychModelType", bound=AEPsychMixin)
[docs]@runtime_checkable
class AcqArgProtocol(Protocol):
[docs] @classmethod
def from_config(cls, config: Config) -> Any:
pass
[docs]class AEPsychGenerator(abc.ABC, Generic[AEPsychModelType]):
"""Abstract base class for generators, which are responsible for generating new points to sample."""
_requires_model = True
stimuli_per_trial = 1
max_asks: Optional[int] = None
dim: int
def __init__(
self,
) -> None:
pass
[docs] @abc.abstractmethod
def gen(
self,
num_points: int,
model: AEPsychModelType,
fixed_features: Optional[Dict[int, float]] = None,
**kwargs,
) -> torch.Tensor:
pass
[docs]class AcqfGenerator(AEPsychGenerator, ConfigurableMixin):
"""Base class for generators that evaluate acquisition functions."""
_requires_model = True
baseline_requiring_acqfs = [
qNoisyExpectedImprovement,
NoisyExpectedImprovement,
qLogNoisyExpectedImprovement,
LogNoisyExpectedImprovement,
]
acqf: AcquisitionFunction
acqf_kwargs: Dict[str, Any]
def __init__(
self,
acqf: AcquisitionFunction,
acqf_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()
self.acqf = acqf
if acqf_kwargs is None:
acqf_kwargs = {}
self.acqf_kwargs = acqf_kwargs
@classmethod
def _get_acqf_options(
cls, acqf: AcquisitionFunction, config: Config
) -> Dict[str, Any]:
"""Get the extra arguments for the acquisition function from the config.
Args:
acqf (AcquisitionFunction): The acquisition function to get arguments for.
config (Config): The configuration object.
Returns:
Dict[str, Any]: The extra arguments for the acquisition function.
"""
if acqf is not None:
acqf_name = acqf.__name__
# model is not an extra arg, it's a default arg
acqf_kwargs = signature(acqf).parameters
acqf_args_expected = [i for i in list(acqf_kwargs.keys()) if i != "model"]
# this is still very ugly
extra_acqf_args = {}
if acqf_name in config:
full_section = config[acqf_name]
for k in acqf_args_expected:
# if this thing is configured
if k in full_section.keys():
v = config.get(acqf_name, k)
# if it's an object make it an object
if full_section[k] in Config.registered_names.keys():
extra_acqf_args[k] = config.getobj(acqf_name, k)
elif re.search(
r"^\[.*\]$", v, flags=re.DOTALL
): # use regex to check if the value is a list
extra_acqf_args[k] = config._str_to_list(v) # type: ignore
else:
# otherwise try a float
try:
extra_acqf_args[k] = config.getfloat(acqf_name, k)
# finally just return a string
except ValueError:
extra_acqf_args[k] = config.get(acqf_name, k)
# next, do more processing
for k, v in extra_acqf_args.items():
if hasattr(v, "from_config"): # configure if needed
assert isinstance(v, AcqArgProtocol) # make mypy happy
extra_acqf_args[k] = v.from_config(config)
elif isinstance(v, type): # instaniate a class if needed
extra_acqf_args[k] = v()
# Final checks, bandaid
for key, value in acqf_kwargs.items():
if key == "model": # Model is handled separately
continue
if value.default == _empty and key not in extra_acqf_args:
if key not in config["common"]:
# HACK: Not actually sure why some required args can be missing
warnings.warn(
f"{acqf_name} requires the {key} option but we could not find it.",
UserWarning,
)
continue
# A required parameter is missing! Look for it in common
config_str = config.get("common", key)
# An object that we know about
if config_str in Config.registered_names.keys():
extra_acqf_args[key] = config.getobj("common", key)
# Some sequence
if "[" in config_str and "]" in config_str:
# Try to turn it into a tensor or fallback to list
try:
extra_acqf_args[key] = config.gettensor("common", key)
except ValueError:
extra_acqf_args[key] = config.getlist("common", key)
else:
extra_acqf_args = {}
return extra_acqf_args
def _instantiate_acquisition_fn(self, model: ModelProtocol) -> AcquisitionFunction:
"""
Instantiates the acquisition function with the specified model and additional arguments.
Args:
model (ModelProtocol): The model to use with the acquisition function.
Returns:
AcquisitionFunction: Configured acquisition function.
"""
if self.acqf == AnalyticExpectedUtilityOfBestOption:
return self.acqf(pref_model=model)
if hasattr(model, "device"):
if "lb" in self.acqf_kwargs:
if not isinstance(self.acqf_kwargs["lb"], torch.Tensor):
self.acqf_kwargs["lb"] = torch.tensor(self.acqf_kwargs["lb"])
self.acqf_kwargs["lb"] = self.acqf_kwargs["lb"].to(model.device)
if "ub" in self.acqf_kwargs:
if not isinstance(self.acqf_kwargs["ub"], torch.Tensor):
self.acqf_kwargs["ub"] = torch.tensor(self.acqf_kwargs["ub"])
self.acqf_kwargs["ub"] = self.acqf_kwargs["ub"].to(model.device)
if self.acqf in self.baseline_requiring_acqfs:
return self.acqf(model, model.train_inputs[0], **self.acqf_kwargs)
else:
return self.acqf(model=model, **self.acqf_kwargs)
[docs] @classmethod
def get_config_options(
cls,
config: Config,
name: Optional[str] = None,
options: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Get configuration options for the generator.
Args:
config (Config): Configuration object.
name (str, optional): Name of the generator, defaults to None. Ignored.
options (Dict[str, Any], optional): Additional options, defaults to None.
Returns:
Dict[str, Any]: Configuration options for the generator.
"""
options = options or {}
classname = cls.__name__
acqf = config.getobj(classname, "acqf", fallback=None)
extra_acqf_args = cls._get_acqf_options(acqf, config)
options.update(
{
"acqf": acqf,
"acqf_kwargs": extra_acqf_args,
}
)
return options