Source code for aepsych.generators.optimize_acqf_generator

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import inspect
import time
from typing import Any, Dict, Optional

import torch
from aepsych.config import Config
from aepsych.generators.base import AEPsychGenerator
from aepsych.models.base import ModelProtocol
from aepsych.utils_logging import getLogger
from botorch.acquisition import (
    AcquisitionFunction,
    LogNoisyExpectedImprovement,
    NoisyExpectedImprovement,
    qLogNoisyExpectedImprovement,
    qNoisyExpectedImprovement,
)
from botorch.acquisition.preference import AnalyticExpectedUtilityOfBestOption
from botorch.optim import optimize_acqf

logger = getLogger()


[docs]class OptimizeAcqfGenerator(AEPsychGenerator): """Generator that chooses points by minimizing an acquisition function.""" baseline_requiring_acqfs = [ NoisyExpectedImprovement, LogNoisyExpectedImprovement, qNoisyExpectedImprovement, qLogNoisyExpectedImprovement, ] def __init__( self, lb: torch.Tensor, ub: torch.Tensor, acqf: AcquisitionFunction, acqf_kwargs: Optional[Dict[str, Any]] = None, restarts: int = 10, samps: int = 1000, max_gen_time: Optional[float] = None, stimuli_per_trial: int = 1, ) -> None: """Initialize OptimizeAcqfGenerator. Args: lb (torch.Tensor): Lower bounds for the optimization. ub (torch.Tensor): Upper bounds for the optimization. acqf (AcquisitionFunction): Acquisition function to use. acqf_kwargs (Dict[str, object], optional): Extra arguments to pass to acquisition function. Defaults to no arguments. restarts (int): Number of restarts for acquisition function optimization. Defaults to 10. samps (int): Number of samples for quasi-random initialization of the acquisition function optimizer. Defaults to 1000. max_gen_time (float, optional): Maximum time (in seconds) to optimize the acquisition function. Defaults to None. stimuli_per_trial (int): Number of stimuli per trial. Defaults to 1. """ if acqf_kwargs is None: acqf_kwargs = {} self.acqf = acqf self.acqf_kwargs = acqf_kwargs self.restarts = restarts self.samps = samps self.max_gen_time = max_gen_time self.stimuli_per_trial = stimuli_per_trial self.lb = lb self.ub = ub def _instantiate_acquisition_fn(self, model: ModelProtocol) -> AcquisitionFunction: """ Instantiates the acquisition function with the specified model and additional arguments. Args: model (ModelProtocol): The model to use with the acquisition function. Returns: AcquisitionFunction: Configured acquisition function. """ if self.acqf == AnalyticExpectedUtilityOfBestOption: return self.acqf(pref_model=model) if hasattr(model, "device"): if "lb" in self.acqf_kwargs: if not isinstance(self.acqf_kwargs["lb"], torch.Tensor): self.acqf_kwargs["lb"] = torch.tensor(self.acqf_kwargs["lb"]) self.acqf_kwargs["lb"] = self.acqf_kwargs["lb"].to(model.device) if "ub" in self.acqf_kwargs: if not isinstance(self.acqf_kwargs["ub"], torch.Tensor): self.acqf_kwargs["ub"] = torch.tensor(self.acqf_kwargs["ub"]) self.acqf_kwargs["ub"] = self.acqf_kwargs["ub"].to(model.device) if self.acqf in self.baseline_requiring_acqfs: return self.acqf(model, model.train_inputs[0], **self.acqf_kwargs) else: return self.acqf(model=model, **self.acqf_kwargs)
[docs] def gen(self, num_points: int, model: ModelProtocol, **gen_options) -> torch.Tensor: """Query next point(s) to run by optimizing the acquisition function. Args: num_points (int): Number of points to query. model (ModelProtocol): Fitted model of the data. Returns: torch.Tensor: Next set of point(s) to evaluate, [num_points x dim]. """ if self.stimuli_per_trial == 2: qbatch_points = self._gen( num_points=num_points * 2, model=model, **gen_options ) # output of super() is (q, dim) but the contract is (num_points, dim, 2) # so we need to split q into q and pairs and then move the pair dim to the end return qbatch_points.reshape(num_points, 2, -1).swapaxes(-1, -2) else: return self._gen(num_points=num_points, model=model, **gen_options)
def _gen( self, num_points: int, model: ModelProtocol, **gen_options: Dict[str, Any] ) -> torch.Tensor: """ Generates the next query points by optimizing the acquisition function. Args: num_points (int): Number of points to query. model (ModelProtocol): Fitted model of the data. gen_options (Dict[str, Any]): Additional options for generating points, such as custom configurations. Returns: torch.Tensor: Next set of points to evaluate, with shape [num_points x dim]. """ # eval should be inherited from superclass model.eval() # type: ignore acqf = self._instantiate_acquisition_fn(model) if hasattr(model, "device"): self.lb = self.lb.to(model.device) self.ub = self.ub.to(model.device) logger.info("Starting gen...") starttime = time.time() new_candidate, _ = optimize_acqf( acq_function=acqf, bounds=torch.stack([self.lb, self.ub]), q=num_points, num_restarts=self.restarts, raw_samples=self.samps, timeout_sec=self.max_gen_time, **gen_options, ) logger.info(f"Gen done, time={time.time()-starttime}") return new_candidate
[docs] @classmethod def from_config(cls, config: Config) -> "OptimizeAcqfGenerator": """ Creates an instance of OptimizeAcqfGenerator from a configuration object. Args: config (Config): Configuration object containing initialization parameters. Returns: OptimizeAcqfGenerator: A configured instance of OptimizeAcqfGenerator with specified acquisition function, restart and sample parameters, maximum generation time, and stimuli per trial. """ classname = cls.__name__ lb = config.gettensor(classname, "lb") ub = config.gettensor(classname, "ub") acqf = config.getobj(classname, "acqf", fallback=None) extra_acqf_args = cls._get_acqf_options(acqf, config) stimuli_per_trial = config.getint(classname, "stimuli_per_trial") restarts = config.getint(classname, "restarts", fallback=10) samps = config.getint(classname, "samps", fallback=1000) max_gen_time = config.getfloat(classname, "max_gen_time", fallback=None) return cls( lb=lb, ub=ub, acqf=acqf, acqf_kwargs=extra_acqf_args, restarts=restarts, samps=samps, max_gen_time=max_gen_time, stimuli_per_trial=stimuli_per_trial, )