# Source code for aepsych.strategy.sequential

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import warnings
from typing import Any, List, Optional, Union

import numpy as np
import torch
from aepsych.config import Config
from aepsych.utils_logging import getLogger

from .strategy import Strategy

# Module-level logger, obtained from aepsych.utils_logging.getLogger.
logger = getLogger()


class SequentialStrategy:
    """Run a sequence of strategies, one after another.

    The container keeps an index into ``strat_list`` that identifies the
    current strategy. Any attribute that is not defined on the container
    itself is looked up on the current strategy (see ``__getattr__``), so
    getters such as ``x`` and ``y`` always reflect the current strategy.

    Args:
        strat_list (List[Strategy]): The strategies to run, in the order in
            which they should be used.
    """

    def __init__(self, strat_list: List[Strategy]) -> None:
        """Initialize the SequentialStrategy object.

        Args:
            strat_list (List[Strategy]): The strategies to run, in order.
        """
        self.strat_list = strat_list
        self._strat_idx = 0
        self._suggest_count = 0
        # Annotation-only declarations: no instance attribute is created here,
        # so reading `x` / `y` falls through to __getattr__ and returns the
        # current strategy's value.
        self.x: Optional[torch.Tensor]
        self.y: Optional[torch.Tensor]

    @property
    def _strat(self) -> Strategy:
        """Get the current strategy.

        Returns:
            Strategy: The entry of ``strat_list`` at the current index.
        """
        return self.strat_list[self._strat_idx]

    def __getattr__(self, name: str) -> Any:
        """Look up an attribute on the current strategy.

        Python only calls this hook after normal attribute lookup on the
        container has failed, so container attributes always take precedence.

        Args:
            name (str): The name of the attribute.

        Returns:
            Any: The attribute of the current strategy.

        Raises:
            AttributeError: If the container has no ``strat_list`` yet.
        """
        # Without this guard, touching `self._strat` on an instance that has no
        # `strat_list` would re-enter __getattr__ and recurse without bound.
        if "strat_list" not in vars(self):
            raise AttributeError("Have no strategies in container, what happened?")

        # Not a container attribute: defer to the current strategy.
        return getattr(self._strat, name)

    def _make_next_strat(self) -> None:
        """Switch to the next strategy, handing it the current strategy's data.

        If the current strategy is already the last one, a ``RuntimeWarning``
        is issued and the container stays on it.

        Raises:
            AssertionError: If the current strategy has no data (``x`` or ``y``
                is ``None``) to pass on to the next strategy.
        """
        next_idx = self._strat_idx + 1
        if next_idx >= len(self.strat_list):
            warnings.warn(
                "Ran out of generators, staying on final generator!", RuntimeWarning
            )
            return

        # Populate the new strategy with the final data from the previous one.
        # Raised explicitly (rather than via `assert`) so the check still runs
        # when Python is started with -O.
        if self.x is None or self.y is None:
            raise AssertionError(
                "Cannot initialize next strategy; no data has been given!"
            )
        self.strat_list[next_idx].add_data(self.x, self.y)

        self._suggest_count = 0
        self._strat_idx = next_idx

    def gen(self, num_points: int = 1, **kwargs) -> torch.Tensor:
        """Generate the next set of points to evaluate.

        If the current strategy reports that it is finished, the container
        first tries to advance to the next strategy.

        Args:
            num_points (int): The number of points to generate. Defaults to 1.
            **kwargs: Forwarded unchanged to the current strategy's ``gen``.

        Returns:
            torch.Tensor: The next set of points to evaluate.
        """
        if self._strat.finished:
            self._make_next_strat()

        self._suggest_count += num_points
        return self._strat.gen(num_points=num_points, **kwargs)

    def finish(self) -> None:
        """Finish the current strategy."""
        self._strat.finish()

    @property
    def finished(self) -> bool:
        """Check if the whole sequence is finished.

        Returns:
            bool: True only when the current strategy is the last one in
                ``strat_list`` and it reports itself as finished.
        """
        on_last_strat = self._strat_idx == (len(self.strat_list) - 1)
        return on_last_strat and self._strat.finished

    def add_data(
        self, x: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor]
    ) -> None:
        """Add new data points to the current strategy.

        Args:
            x (Union[np.ndarray, torch.Tensor]): The input data points.
            y (Union[np.ndarray, torch.Tensor]): The output data points.
        """
        self._strat.add_data(x, y)

    @classmethod
    def from_config(cls, config: Config) -> SequentialStrategy:
        """Create a SequentialStrategy object from a configuration object.

        The strategy names are read from the ``strategy_names`` option of the
        ``common`` section, and one ``Strategy`` is built per name.

        Args:
            config (Config): The configuration object.

        Returns:
            SequentialStrategy: The SequentialStrategy object.

        Raises:
            AssertionError: If the configured strategy names are not unique.
        """
        strat_names = config.getlist("common", "strategy_names", element_type=str)

        # Ensure strat_names are unique. Raised explicitly (rather than via
        # `assert`) so the check still runs when Python is started with -O.
        if len(strat_names) != len(set(strat_names)):
            raise AssertionError(
                f"Strategy names {strat_names} are not all unique!"
            )

        strats = [Strategy.from_config(config, str(name)) for name in strat_names]
        return cls(strat_list=strats)