Source code for aepsych.generators.acqf_thompson_sampler_generator

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import time
from typing import Dict, Optional

import numpy as np
import torch
from aepsych.models.base import ModelProtocol
from aepsych.utils_logging import getLogger
from numpy.random import choice

from .grid_eval_acqf_generator import GridEvalAcqfGenerator

logger = getLogger()


class AcqfThompsonSamplerGenerator(GridEvalAcqfGenerator):
    """Generator that samples points in a grid with probability proportional to an acquisition function value."""

    def _gen(
        self,
        num_points: int,
        model: ModelProtocol,
        fixed_features: Optional[Dict[int, float]] = None,
        **gen_options,
    ) -> torch.Tensor:
        """Sample the next query points from a grid, weighted by acquisition value.

        The acquisition function is evaluated on a grid via ``_eval_acqf``
        (``self.samps`` is forwarded to it; presumably the grid size, defined on
        the base class). The acquisition values are shifted so that their
        minimum is zero and then normalized into a probability distribution.
        ``num_points`` grid indices are drawn from that distribution with
        replacement, so duplicates are possible.

        Because of the shift, the grid point(s) with the lowest acquisition
        value receive zero probability. The exception is when all values are
        identical: the distribution is then degenerate, and points are drawn
        uniformly instead.

        Args:
            num_points (int): The number of points to query.
            model (ModelProtocol): The fitted model used to evaluate the acquisition function.
            fixed_features: (Dict[int, float], optional): Parameters that are fixed to specific values.
            gen_options (dict): Additional options forwarded to ``_eval_acqf``, including:
                - "seed": Random seed for reproducibility. Note that the sampling
                  step itself uses NumPy's global random state
                  (``numpy.random.choice``).

        Returns:
            torch.Tensor: Next set of points to evaluate, with shape [num_points x dim].

        Raises:
            ValueError: Raised by ``numpy.random.choice`` if the acquisition
                values contain NaN.
        """
        logger.info("Starting gen...")
        starttime = time.time()

        grid, acqf_vals = self._eval_acqf(
            self.samps, model, fixed_features, **gen_options
        )

        # Work on a detached float64 NumPy copy. This avoids mutating the tensor
        # returned by _eval_acqf, works for tensors on any device, and gives
        # numpy's choice() a well-normalized float64 probability vector.
        weights = acqf_vals.detach().cpu().double().numpy()
        weights = weights - weights.min()
        total = weights.sum()

        # If every acquisition value is identical, the shifted weights are all
        # zero and normalizing would produce NaNs (0/0), making choice() raise.
        # Passing p=None makes choice() sample uniformly instead.
        probability_dist = None if total == 0 else weights / total

        candidate_idx = choice(
            np.arange(weights.shape[0]),
            size=num_points,
            p=probability_dist,
        )
        new_candidate = grid[candidate_idx]

        logger.info(f"Gen done, time={time.time()-starttime}")
        return new_candidate