#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import abc
import ast
import configparser
import inspect
import json
import logging
import re
import typing
import warnings
from types import ModuleType, NoneType
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
Mapping,
Optional,
Sequence,
TypeVar,
)
import botorch
import gpytorch
import numpy as np
import torch
from aepsych.version import __version__
_T = TypeVar("_T")
_ET = TypeVar("_ET")
DEPRECATED_OBJS = [
"MonotonicRejectionGenerator",
"MonotonicMCPosteriorVariance",
"MonotonicBernoulliMCMutualInformation",
"MonotonicMCLSE",
"MonotonicRejectionGP",
"monotonic_mean_covar_factory",
]
class Config(configparser.ConfigParser):
    # Names registered here can be referred to by string name in config files
registered_names: ClassVar[Dict[str, object]] = {}
def __init__(
self,
config_dict: Optional[Mapping[str, Any]] = None,
config_fnames: Optional[Sequence[str]] = None,
config_str: Optional[str] = None,
) -> None:
"""Initialize the AEPsych config object. This can be used to instantiate most
objects in AEPsych by calling object.from_config(config).
Args:
config_dict (Mapping[str, str], optional): Mapping to build configuration from.
Keys are section names, values are dictionaries with keys and values that
should be present in the section. Defaults to None.
config_fnames (Sequence[str], optional): List of INI filenames to load
configuration from. Defaults to None.
config_str (str, optional): String formatted as an INI file to load configuration
from. Defaults to None.
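
        Example:
            A minimal sketch; the parameter name ``signal_intensity`` is purely
            illustrative::

                config_str = '''
                [common]
                parnames = [signal_intensity]

                [signal_intensity]
                par_type = continuous
                lower_bound = 0
                upper_bound = 1
                '''
                config = Config(config_str=config_str)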
"""
super().__init__(
            inline_comment_prefixes=("#",),
empty_lines_in_values=False,
default_section="common",
interpolation=configparser.ExtendedInterpolation(),
converters={
"list": self._str_to_list,
"tensor": self._str_to_tensor,
"obj": self._str_to_obj,
"array": self._str_to_array,
},
allow_no_value=True,
)
self.update(
config_dict=config_dict,
config_fnames=config_fnames,
config_str=config_str,
)
def _get(
self,
section: str,
conv: _T,
option: str,
*,
raw: bool = False,
vars: Optional[Dict[str, Any]] = None,
fallback: _T = configparser._UNSET,
**kwargs,
):
"""
Override configparser to:
1. Return from common if a section doesn't exist. This comes
up any time we have a module fully configured from the
common/default section.
2. Pass extra **kwargs to the converter.
Args:
section (str): Section to get the option from.
conv (_T): Converter to use.
option (str): Option to get.
raw (bool): Whether to return the raw value. Defaults to False.
vars (Dict[str, Any], optional): Optional dictionary to use for interpolation. Defaults to None.
fallback (_T): Value to return if the option is not found. Defaults to configparser._UNSET.
Returns:
_T: Converted value of the option.
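
        Example:
            As a sketch (section and option names are illustrative): if no
            ``[MyStrategy]`` section exists, ``config.getfloat("MyStrategy", "samps")``
            falls back to the ``samps`` option defined under ``[common]``.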
"""
try:
return conv(
self.get(
section=section,
option=option,
raw=raw,
vars=vars,
fallback=fallback,
),
**kwargs,
)
except configparser.NoSectionError:
return conv(
self.get(
section="common",
option=option,
raw=raw,
vars=vars,
fallback=fallback,
),
**kwargs,
)
    # Convert config into a dictionary (eliminating duplicates from the defaulted 'common' section).
    def to_dict(self, deduplicate: bool = True) -> Dict[str, Any]:
"""Convert the config into a dictionary.
Args:
deduplicate (bool): Whether to deduplicate the 'common' section. Defaults to True.
Returns:
dict: Dictionary representation of the config.
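
        Example:
            All values remain raw strings; as a sketch (``x`` is illustrative), a
            ``parnames = [x]`` entry appears as ``{"common": {"parnames": "[x]", ...}, ...}``.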
"""
_dict: Dict[str, Any] = {}
for section in self:
_dict[section] = {}
for setting in self[section]:
if deduplicate and section != "common" and setting in self["common"]:
continue
_dict[section][setting] = self[section][setting]
return _dict
    # Turn the entire config into JSON format.
    def jsonifyAll(self) -> str:
"""Turn the entire config into JSON format.
Returns:
str: JSON representation of the entire config.
"""
configdict = self.to_dict()
return json.dumps(configdict)
    def update(
self,
config_dict: Optional[Mapping[str, str]] = None,
        config_fnames: Optional[Sequence[str]] = None,
        config_str: Optional[str] = None,
) -> None:
"""Update this object with a new configuration.
Args:
config_dict (Mapping[str, str], optional): Mapping to build configuration from.
Keys are section names, values are dictionaries with keys and values that
should be present in the section. Defaults to None.
            config_fnames (Sequence[str], optional): List of INI filenames to load
                configuration from. Defaults to None.
            config_str (str, optional): String formatted as an INI file to load configuration
                from. Defaults to None.
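
        Example:
            A sketch of how per-parameter bounds are collected into ``[common]``
            (the parameter name ``x`` is illustrative)::

                config = Config()
                config.update(config_str='''
                [common]
                parnames = [x]

                [x]
                par_type = continuous
                lower_bound = 0
                upper_bound = 10
                ''')
                config["common"]["lb"]  # "[0]"
                config["common"]["ub"]  # "[10]"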
"""
if config_dict is not None:
self.read_dict(config_dict)
if config_fnames is not None:
read_ok = self.read(config_fnames)
if len(read_ok) < 1:
raise FileNotFoundError
if config_str is not None:
self.read_string(config_str)
if "parnames" in self["common"]: # it's possible to pass no parnames
try:
par_names = self.getlist(
"common", "parnames", element_type=str, fallback=[]
)
lb = [None] * len(par_names)
ub = [None] * len(par_names)
for i, par_name in enumerate(par_names):
# Validate the parameter-specific block
self._check_param_settings(par_name)
lb[i] = self[par_name].get("lower_bound", fallback="0")
ub[i] = self[par_name].get("upper_bound", fallback="1")
self["common"]["lb"] = f"[{', '.join(lb)}]"
self["common"]["ub"] = f"[{', '.join(ub)}]"
except ValueError:
# Check if ub/lb exists in common
if "ub" in self["common"] and "lb" in self["common"]:
logging.warning(
"Parameter-specific bounds are incomplete, falling back to ub/lb in [common]"
)
else:
raise ValueError(
"Missing ub or lb in [common] with incomplete parameter-specific bounds, cannot fallback!"
)
    def _str_to_list(
        self, v: str, element_type: Callable[[str], _ET] = float
    ) -> List[_ET]:
"""Convert a string to a list.
Args:
v (str): String to convert.
            element_type (Callable[[str], _ET]): Callable used to convert each element. Defaults to float.
Returns:
            List[_ET]: List of converted elements.
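
        Example:
            ``"[1, 2, 3]"`` becomes ``[1.0, 2.0, 3.0]`` with the default
            ``element_type``, ``"[]"`` becomes ``[]``, and a bare value such as
            ``"a"`` becomes ``["a"]``.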
"""
v = re.sub(r",]", "]", v)
if re.search(r"^\[.*\]$", v, flags=re.DOTALL):
if v == "[]": # empty list
return []
else:
return [element_type(i.strip()) for i in v[1:-1].split(",")]
else:
return [v.strip()]
def _str_to_array(self, v: str) -> np.ndarray:
"""Convert a string to a numpy array.
Args:
v (str): String to convert.
Returns:
np.ndarray: Numpy array representation of the string.
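
        Example:
            ``"[[1, 2], [3, 4]]"`` becomes a 2x2 float array; the string is parsed
            with ``ast.literal_eval`` and passed to ``np.array`` with ``dtype=float``.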
"""
v = ast.literal_eval(v)
return np.array(v, dtype=float)
def _str_to_tensor(self, v: str) -> torch.Tensor:
"""Convert a string to a torch tensor.
Args:
v (str): String to convert.
Returns:
torch.Tensor: Tensor representation of the string.
"""
return torch.Tensor(self._str_to_array(v)).to(torch.float64)
def _str_to_obj(self, v: str, fallback_type: _T = str, warn: bool = True) -> object:
"""Convert a string to an object.
Args:
v (str): String to convert.
fallback_type (_T): Type to fallback to if the object is not found. Defaults to str.
warn (bool): Whether to warn if the object is not found. Defaults to True.
Returns:
object: Object representation of the string.
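
        Example:
            ``"RBFKernel"`` resolves to ``gpytorch.kernels.RBFKernel`` because
            ``gpytorch.kernels`` is registered at the bottom of this module; an
            unknown string falls back to ``fallback_type(v)`` after a warning.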
"""
try:
return self.registered_names[v]
except KeyError:
if warn:
if v in DEPRECATED_OBJS:
raise TypeError(
f"Object {v} is deprecated and no longer supported!"
)
else:
warnings.warn(f'No known object "{v}"!')
return fallback_type(v)
def _check_param_settings(self, param_name: str) -> None:
"""Check parameter-specific blocks have the correct settings, raises a ValueError if not.
Args:
param_name (str): Parameter block to check.
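
        Example:
            As a sketch (the parameter name ``x`` is illustrative), a block like
            the following passes the ``continuous`` checks::

                [x]
                par_type = continuous
                lower_bound = 0
                upper_bound = 1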
"""
# Check if the config block exists at all
if param_name not in self:
raise ValueError(f"Parameter {param_name} is missing its own config block.")
param_block = self[param_name]
        # Check that par_type is set
if "par_type" not in param_block:
raise ValueError(f"Parameter {param_name} is missing the par_type setting.")
# Each parameter type has a different set of required settings
if param_block["par_type"] == "continuous":
# Check if bounds exist
if "lower_bound" not in param_block:
raise ValueError(
f"Parameter {param_name} is missing the lower_bound setting."
)
if "upper_bound" not in param_block:
raise ValueError(
f"Parameter {param_name} is missing the upper_bound setting."
)
elif param_block["par_type"] == "integer":
        # Check that bounds exist and are actually integers
if "lower_bound" not in param_block:
raise ValueError(
f"Parameter {param_name} is missing the lower_bound setting."
)
if "upper_bound" not in param_block:
raise ValueError(
f"Parameter {param_name} is missing the upper_bound setting."
)
try:
if not (
self.getint(param_name, "lower_bound") % 1 == 0
and self.getint(param_name, "upper_bound") % 1 == 0
):
raise ParameterConfigError(
f"Parameter {param_name} has non-integer bounds."
)
except ValueError:
raise ParameterConfigError(
f"Parameter {param_name} has non-discrete bounds."
)
elif param_block["par_type"] == "binary":
if "lower_bound" in param_block or "upper_bound" in param_block:
raise ParameterConfigError(
f"Parameter {param_name} is binary and shouldn't have bounds."
)
elif param_block["par_type"] == "fixed":
if "value" not in param_block:
raise ParameterConfigError(
f"Parameter {param_name} is fixed and needs to have value set."
)
else:
raise ParameterConfigError(
f"Parameter {param_name} has an unsupported parameter type {param_block['par_type']}."
)
def __repr__(self) -> str:
"""Return a string representation of the config.
Returns:
str: String representation of the config.
"""
return f"Config at {hex(id(self))}: \n {str(self)}"
    @classmethod
def register_module(cls: _T, module: ModuleType) -> None:
"""Register a module with Config so that objects in it can
be referred to by their string name in config files.
Args:
module (ModuleType): Module to register.
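
        Example:
            See the module-level calls at the bottom of this file, e.g.
            ``Config.register_module(gpytorch.kernels)``, which makes every name in
            ``gpytorch.kernels.__all__`` (except submodules) available by string
            name in config files.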
"""
cls.registered_names.update(
{
name: getattr(module, name)
for name in module.__all__
if not isinstance(getattr(module, name), ModuleType)
}
)
    @classmethod
def register_object(cls: _T, obj: object) -> None:
"""Register an object with Config so that it can be
referred to by its string name in config files.
Args:
obj (object): Object to register.
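
        Example:
            A sketch with a hypothetical class::

                class MyNewObject:
                    pass

                Config.register_object(MyNewObject)
                # "MyNewObject" can now be referenced by name in config files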
"""
if obj.__name__ in cls.registered_names.keys():
warnings.warn(
f"Registering {obj.__name__} but already"
+ f"have {cls.registered_names[obj.__name__]}"
+ "registered under that name!"
)
cls.registered_names.update({obj.__name__: obj})
    def get_section(self, section: str) -> Dict[str, Any]:
"""Get a section of the config.
Args:
section (str): Section to get.
Returns:
Dict[str, Any]: Dictionary representation of the section.
"""
sec = {}
for setting in self[section]:
if section != "common" and setting in self["common"]:
continue
sec[setting] = self[section][setting]
return sec
    def __str__(self) -> str:
"""Return a string representation of the config."""
_str = ""
for section in self:
sec = self.get_section(section)
_str += f"[{section}]\n"
for setting in sec:
_str += f"{setting} = {self[section][setting]}\n"
return _str
class ConfigurableMixin(abc.ABC):
    """Mixin that provides ``get_config_options`` and ``from_config`` so that a class
    can be initialized directly from a Config."""
    @classmethod
def get_config_options(
cls,
config: Config,
name: Optional[str] = None,
options: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
        Return a dictionary of the relevant options to initialize this class from the
        config, even if it is outside of the named section. By default, this looks for
        options in the section given by ``name``, based on the arguments and defaults
        of the class's ``__init__``.
Args:
config (Config): Config to look for options in.
            name (str, optional): Primary section to look for options for this class in;
                also used to infer options from other sections of the config. Defaults
                to the class name.
options (Dict[str, Any], optional): Options to override from the config,
defaults to None.
        Returns:
Dict[str, Any]: A dictionary of options to initialize this class.
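
        Example:
            A minimal sketch with a hypothetical configurable class::

                class MyObject(ConfigurableMixin):
                    def __init__(self, radius: float = 1.0, label: str = "a"):
                        self.radius = radius
                        self.label = label

                config = Config(config_str='''
                [MyObject]
                radius = 2.5
                ''')
                MyObject.get_config_options(config)
                # {"radius": 2.5, "label": "a"}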
"""
if name is None:
name = cls.__name__
def _sort_types(annotations):
            # Rebuild the annotations, preferring float, then int, then str, then the rest
reordered = []
if float in annotations:
reordered += [float]
if int in annotations:
reordered += [int]
if str in annotations:
reordered += [str]
reordered += [elem for elem in annotations if elem not in [float, int, str]]
return tuple(reordered)
        args = inspect.signature(cls, eval_str=True).parameters
        # Explicitly passed options act as overrides on top of config-derived values
        overrides = options or {}
        options = {}
for key, signature in args.items():
# Used as fallback
value = signature.default
if (
typing.get_origin(signature.annotation) is typing.Union
): # Includes Optional
annotations = typing.get_args(signature.annotation)
annotations = _sort_types(annotations)
else:
annotations = (signature.annotation,)
for annotation in annotations:
try:
# Tensor
if annotation is torch.Tensor:
value = config.gettensor(name, key)
# Numpy array
elif annotation is np.ndarray:
value = config.getarray(name, key)
# Default list
elif annotation is list:
try:
value = config.getlist(name, key, element_type=float)
except ValueError:
value = config.getlist(name, key, element_type=str)
# Generic List[...]
elif typing.get_origin(annotation) is list:
element_types = typing.get_args(annotation)
element_type = _sort_types(element_types)[0]
value = config.getlist(
name,
key,
element_type=element_type,
)
# String
elif annotation is str:
value = config.get(name, key)
# Int
elif annotation is int:
value = config.getint(name, key)
# Float
elif annotation is float:
value = config.getfloat(name, key)
# Bool
elif annotation is bool:
value = config.getboolean(name, key)
# Object
elif inspect.isclass(annotation):
object_cls = config.getobj(name, key)
if ConfigurableMixin in object_cls.__bases__:
value = object_cls.from_config(config, object_cls.__name__)
else:
value = object_cls
# None type
elif annotation is NoneType:
value = None
                    # The conversion succeeded; stop trying other annotations
                    break
except (ValueError, configparser.NoOptionError):
pass
            options[key] = value

        # Apply explicitly passed options on top of the config-derived values
        options.update(overrides)
        return options
    @classmethod
def from_config(
cls,
config: Config,
name: Optional[str] = None,
options: Optional[Dict[str, Any]] = None,
) -> "ConfigurableMixin":
"""
        Return an initialized instance of this class using the config and the name.
Args:
config (Config): Config to use to initialize this class.
name (str, optional): Name of section to look in first for this class.
options (Dict[str, Any], optional): Options to override from the config,
defaults to None.
        Returns:
ConfigurableMixin: Initialized class based on config and name.
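
        Example:
            Continuing the sketch from ``get_config_options`` above,
            ``MyObject.from_config(config)`` is equivalent to calling
            ``MyObject(radius=2.5, label="a")``.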
"""
return cls(**cls.get_config_options(config=config, name=name, options=options))
Config.register_module(gpytorch.likelihoods)
Config.register_module(gpytorch.kernels)
Config.register_module(botorch.acquisition)
Config.register_module(botorch.acquisition.multi_objective)
Config.registered_names["None"] = None
class ParameterConfigError(Exception):
    """Raised when a parameter block in the config is missing required settings or is otherwise invalid."""