Source code for aepsych.generators.epsilon_greedy_generator

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from aepsych.config import Config

from ..models.base import ModelProtocol
from .base import AEPsychGenerator
from .optimize_acqf_generator import OptimizeAcqfGenerator


class EpsilonGreedyGenerator(AEPsychGenerator):
    """Epsilon-greedy wrapper around another generator.

    With probability ``epsilon`` a point is drawn uniformly at random between
    ``lb`` and ``ub`` (exploration); otherwise generation is delegated to
    ``subgenerator`` (exploitation).
    """

    def __init__(
        self,
        lb: torch.Tensor,
        ub: torch.Tensor,
        subgenerator: AEPsychGenerator,
        epsilon: float = 0.1,
    ) -> None:
        """Initialize EpsilonGreedyGenerator.

        Args:
            lb (torch.Tensor): Lower bounds for the optimization.
            ub (torch.Tensor): Upper bounds for the optimization.
            subgenerator (AEPsychGenerator): The generator to use when not
                exploring, i.e. with probability ``1 - epsilon``.
            epsilon (float): The probability of exploration. Defaults to 0.1.
        """
        self.subgenerator = subgenerator
        self.epsilon = epsilon
        self.lb = lb
        self.ub = ub

    @classmethod
    def from_config(cls, config: Config) -> "EpsilonGreedyGenerator":
        """Create an EpsilonGreedyGenerator from a Config object.

        ``lb``, ``ub``, ``subgenerator`` and ``epsilon`` are read from the
        config section named after this class. ``subgenerator`` falls back to
        ``OptimizeAcqfGenerator`` and ``epsilon`` falls back to 0.1.

        Args:
            config (Config): Configuration object containing initialization
                parameters.

        Returns:
            EpsilonGreedyGenerator: The generator.
        """
        classname = cls.__name__
        lb = torch.tensor(config.getlist(classname, "lb"))
        ub = torch.tensor(config.getlist(classname, "ub"))
        subgen_cls = config.getobj(
            classname, "subgenerator", fallback=OptimizeAcqfGenerator
        )
        # The subgenerator builds itself from the same Config object.
        subgen = subgen_cls.from_config(config)
        epsilon = config.getfloat(classname, "epsilon", fallback=0.1)
        return cls(lb=lb, ub=ub, subgenerator=subgen, epsilon=epsilon)

    def gen(self, num_points: int, model: ModelProtocol) -> torch.Tensor:
        """Query next point(s) to run by sampling from the subgenerator with
        probability 1-epsilon, and randomly otherwise.

        Args:
            num_points (int): Number of points to query. Only 1 is supported.
            model (ModelProtocol): Model to use for generating points. It is
                only forwarded to the subgenerator.

        Returns:
            torch.Tensor: Either the subgenerator's output, or a uniformly
                random point between the bounds reshaped to ``(1, -1)``.

        Raises:
            NotImplementedError: If ``num_points`` is greater than 1.
        """
        if num_points > 1:
            raise NotImplementedError("Epsilon-greedy batched gen is not implemented!")

        # Exploit: defer to the wrapped generator with probability 1 - epsilon.
        if np.random.uniform() >= self.epsilon:
            return self.subgenerator.gen(num_points, model)

        # Explore: draw uniformly between the bounds. The bounds are converted
        # to NumPy explicitly so that tensors which NumPy cannot consume
        # directly (e.g. non-CPU or grad-tracking tensors) are still handled.
        lb = torch.as_tensor(self.lb)
        ub = torch.as_tensor(self.ub)
        sample = np.random.uniform(
            low=lb.detach().cpu().numpy(),
            high=ub.detach().cpu().numpy(),
        )

        # NumPy samples in float64; hand the point back in the bounds' own
        # floating dtype and device so it does not silently differ from them.
        # Non-floating bounds keep the float64 result to avoid truncation.
        dtype = lb.dtype if lb.is_floating_point() else torch.float64
        return torch.tensor(sample, dtype=dtype, device=lb.device).reshape(1, -1)