Source code for aepsych.generators.epsilon_greedy_generator

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

import numpy as np
import torch

from ..models.model_protocol import ModelProtocol
from .base import AEPsychGenerator
from .optimize_acqf_generator import OptimizeAcqfGenerator


class EpsilonGreedyGenerator(AEPsychGenerator):
    """Epsilon-greedy wrapper around another generator.

    On each call to ``gen``, with probability ``epsilon`` a point is drawn
    uniformly at random between ``lb`` and ``ub`` (exploration). Otherwise
    generation is delegated to ``subgenerator`` (exploitation).
    """

    def __init__(
        self,
        lb: torch.Tensor,
        ub: torch.Tensor,
        subgenerator: AEPsychGenerator,
        epsilon: float = 0.1,
    ) -> None:
        """Initialize EpsilonGreedyGenerator.

        Args:
            lb (torch.Tensor): Lower bounds of the search space; used as the
                lower limit of the uniform exploration sample.
            ub (torch.Tensor): Upper bounds of the search space; used as the
                upper limit of the uniform exploration sample.
            subgenerator (AEPsychGenerator): The generator used when
                exploiting, i.e. with probability ``1 - epsilon``.
            epsilon (float): The probability of exploring (returning a uniform
                random point instead of calling ``subgenerator``).
                Defaults to 0.1.
        """
        self.subgenerator = subgenerator
        self.epsilon = epsilon
        self.lb = lb
        self.ub = ub
        # Number of parameters; not read by the methods in this class.
        self.dim = len(lb)

    @classmethod
    def get_config_options(
        cls,
        config,
        name: Optional[str] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Find the config options for the generator.

        Args:
            config (Config): Config to look for options in.
            name (str, optional): Passed through to the base class's
                ``get_config_options``.
            options (Dict[str, Any], optional): Existing options; any key in
                options will be ignored from the config.

        Returns:
            Dict[str, Any]: A dictionary of options to initialize the
            generator. If no ``subgenerator`` was found, one is built with
            ``OptimizeAcqfGenerator.from_config(config)``.
        """
        options = super().get_config_options(config, name, options)

        if "subgenerator" not in options:
            # No subgenerator configured: fall back to acquisition-function
            # optimization built from the same config.
            options["subgenerator"] = OptimizeAcqfGenerator.from_config(config)

        return options

    def gen(
        self,
        num_points: int,
        model: ModelProtocol,
        fixed_features: Optional[Dict[int, float]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Query the next point(s) to run.

        Samples from the subgenerator with probability ``1 - epsilon``, and
        uniformly at random between ``lb`` and ``ub`` otherwise.

        Args:
            num_points (int): Number of points to query. Only 1 is supported.
            model (ModelProtocol): Model to use for generating points (only
                forwarded to the subgenerator).
            fixed_features (Dict[int, float], optional): Parameters that are
                fixed to specific values, keyed by column index.
            **kwargs: Passed to the subgenerator if not exploring.

        Returns:
            torch.Tensor: When exploring, a tensor of shape ``(1, -1)`` built
            from numpy's float64 uniform sample, with any ``fixed_features``
            columns overwritten. When exploiting, whatever
            ``subgenerator.gen`` returns.

        Raises:
            NotImplementedError: If ``num_points > 1``.
        """
        if num_points > 1:
            raise NotImplementedError("Epsilon-greedy batched gen is not implemented!")

        # Biased coin flip using numpy's global RNG: explore with probability
        # epsilon.
        if np.random.uniform() < self.epsilon:
            sample_ = np.random.uniform(low=self.lb, high=self.ub)
            sample = torch.tensor(sample_).reshape(1, -1)
            # The random draw ignores fixed features, so pin them afterwards.
            if fixed_features is not None:
                for key, value in fixed_features.items():
                    sample[:, key] = value
            return sample
        else:
            return self.subgenerator.gen(
                num_points, model, fixed_features=fixed_features, **kwargs
            )