# Source code for aepsych.generators.epsilon_greedy_generator

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from aepsych.config import Config

from ..models.base import ModelProtocol
from .base import AEPsychGenerator
from .optimize_acqf_generator import OptimizeAcqfGenerator


class EpsilonGreedyGenerator(AEPsychGenerator):
    """Epsilon-greedy wrapper around another generator.

    On each call to :meth:`gen`, with probability ``epsilon`` a single point is
    drawn uniformly at random between ``model.lb`` and ``model.ub``; otherwise
    the call is delegated to ``subgenerator``.
    """

    def __init__(self, subgenerator: AEPsychGenerator, epsilon: float = 0.1):
        """
        Args:
            subgenerator: Generator used whenever the random-exploration branch
                is not taken.
            epsilon: Probability of drawing a uniform random point instead of
                delegating to ``subgenerator``. Must lie in [0, 1].

        Raises:
            ValueError: If ``epsilon`` is outside [0, 1] (NaN included).
        """
        # Written as a negated chained comparison so that NaN is rejected too
        # (every comparison with NaN is False).
        if not 0.0 <= epsilon <= 1.0:
            raise ValueError(f"epsilon must be in [0, 1], got {epsilon}")
        self.subgenerator = subgenerator
        self.epsilon = epsilon

    @classmethod
    def from_config(cls, config: Config) -> "EpsilonGreedyGenerator":
        """Build an instance from the config section named ``cls.__name__``.

        Reads two options from that section:

        - ``subgenerator``: a generator class, default
          :class:`OptimizeAcqfGenerator`. It is constructed through its own
          ``from_config`` using the same ``config``.
        - ``epsilon``: a float, default 0.1.
        """
        classname = cls.__name__
        subgen_cls = config.getobj(
            classname, "subgenerator", fallback=OptimizeAcqfGenerator
        )
        subgen = subgen_cls.from_config(config)
        epsilon = config.getfloat(classname, "epsilon", fallback=0.1)
        return cls(subgenerator=subgen, epsilon=epsilon)

    def gen(self, num_points: int, model: ModelProtocol) -> torch.Tensor:
        """Generate the next point to query.

        Args:
            num_points: Number of points to generate. Only 1 is supported.
            model: Model whose ``lb``/``ub`` attributes bound the random draw.
                It is also passed through to ``subgenerator.gen``.

        Returns:
            On the exploration branch, a tensor of shape ``(1, k)``, where k is
            the number of elements in the bounds, sampled uniformly from
            ``[model.lb, model.ub)`` (float64, via numpy). Otherwise, whatever
            ``subgenerator.gen`` returns.

        Raises:
            NotImplementedError: If ``num_points > 1``.
            ValueError: If ``num_points < 1``.
        """
        if num_points > 1:
            raise NotImplementedError("Epsilon-greedy batched gen is not implemented!")
        if num_points < 1:
            # Without this check the exploration branch would return one point
            # even though zero (or a negative number) were requested.
            raise ValueError(f"num_points must be at least 1, got {num_points}")
        if np.random.uniform() < self.epsilon:
            # Explore: draw each bound element uniformly between lb and ub.
            # NOTE(review): assumes model.lb/ub are CPU tensors or arrays that
            # numpy can convert; confirm for GPU-resident bounds.
            sample = np.random.uniform(low=model.lb, high=model.ub)
            return torch.tensor(sample).reshape(1, -1)
        return self.subgenerator.gen(num_points, model)