#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, Optional, Sequence
import torch
from aepsych.acquisition.monotonic_rejection import MonotonicMCAcquisition
from aepsych.config import Config
from aepsych.generators.base import AEPsychGenerator
from aepsych.models.monotonic_rejection_gp import MonotonicRejectionGP
from botorch.logging import logger
from botorch.optim.initializers import gen_batch_initial_conditions
from botorch.optim.utils import columnwise_clamp, fix_features


def default_loss_constraint_fun(
    loss: torch.Tensor, candidates: torch.Tensor
) -> torch.Tensor:
    """Identity transform for constrained optimization.
    This simply returns loss as-is. Write your own versions of this
    for constrained optimization by e.g. interior point method.
    Args:
        loss (torch.Tensor): Value of loss at candidate points.
        candidates (torch.Tensor): Location of candidate points.
    Returns:
        torch.Tensor: New loss (unchanged)
    """
    return loss
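

# A minimal sketch of a custom loss_constraint_fun, as suggested in the
# docstring above: an illustrative log-barrier (interior-point style) penalty
# that discourages candidates from leaving an assumed [0, 1] box. The function
# name, the box bounds, and the barrier weight are hypothetical, not part of
# the AEPsych API.
def example_log_barrier_loss_constraint_fun(
    loss: torch.Tensor, candidates: torch.Tensor
) -> torch.Tensor:
    """Example loss transform: log barrier keeping candidates inside [0, 1]."""
    # Slack to each face of the assumed box, clamped away from zero so the
    # logarithm stays finite even for infeasible points.
    lower_slack = candidates.clamp(min=1e-6)
    upper_slack = (1.0 - candidates).clamp(min=1e-6)
    barrier = -(torch.log(lower_slack) + torch.log(upper_slack)).sum()
    # The barrier weight (0.1) is an arbitrary illustrative choice.
    return loss + 0.1 * barrier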


class MonotonicRejectionGenerator(AEPsychGenerator[MonotonicRejectionGP]):
    """Generator for use with MonotonicRejectionGP. Generates new points to
    sample by minimizing an acquisition function via stochastic gradient
    descent."""

    def __init__(
        self,
        acqf: MonotonicMCAcquisition,
        acqf_kwargs: Optional[Dict[str, Any]] = None,
        model_gen_options: Optional[Dict[str, Any]] = None,
        explore_features: Optional[Sequence[int]] = None,
    ) -> None:
        """Initialize MonotonicRejectionGenerator.
        Args:
            acqf (AcquisitionFunction): Acquisition function to use.
            acqf_kwargs (Dict[str, object], optional): Extra arguments to
                pass to acquisition function. Defaults to no arguments.
            model_gen_options: Dictionary with options for generating candidate, such as
                SGD parameters. See code for all options and their defaults.
            explore_features: List of features that will be selected randomly and then
                fixed for acquisition fn optimization.
        """
        if acqf_kwargs is None:
            acqf_kwargs = {}
        self.acqf = acqf
        self.acqf_kwargs = acqf_kwargs
        self.model_gen_options = model_gen_options
        self.explore_features = explore_features
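
    # Keys recognized in model_gen_options, with the defaults applied in gen()
    # below: num_restarts=10, raw_samples=1000, verbosity_freq=-1, lr=0.01,
    # momentum=0.9, nesterov=True, epochs=50, milestones=[25, 40], gamma=0.1,
    # and loss_constraint_fun=default_loss_constraint_fun.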

    def _instantiate_acquisition_fn(
        self, model: MonotonicRejectionGP
    ) -> MonotonicMCAcquisition:
        return self.acqf(
            model=model,
            deriv_constraint_points=model._get_deriv_constraint_points(),
            **self.acqf_kwargs,
        )

    def gen(
        self,
        num_points: int,  # Current implementation only generates 1 point at a time
        model: MonotonicRejectionGP,
    ) -> torch.Tensor:
        """Query next point(s) to run by optimizing the acquisition function.
        Args:
            num_points (int, optional): Number of points to query.
            model (AEPsychMixin): Fitted model of the data.
        Returns:
            np.ndarray: Next set of point(s) to evaluate, [num_points x dim].
        """
        options = self.model_gen_options or {}
        num_restarts = options.get("num_restarts", 10)
        raw_samples = options.get("raw_samples", 1000)
        verbosity_freq = options.get("verbosity_freq", -1)
        lr = options.get("lr", 0.01)
        momentum = options.get("momentum", 0.9)
        nesterov = options.get("nesterov", True)
        epochs = options.get("epochs", 50)
        milestones = options.get("milestones", [25, 40])
        gamma = options.get("gamma", 0.1)
        loss_constraint_fun = options.get(
            "loss_constraint_fun", default_loss_constraint_fun
        )
        # Augment bounds with deriv indicator
        bounds = torch.cat((model.bounds_, torch.zeros(2, 1)), dim=1)
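        # bounds is now 2 x (dim + 1); the final column is the derivative
        # indicator expected by the monotonic model's input representation.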
        # Fix deriv indicator to 0 during optimization
        fixed_features = {(bounds.shape[1] - 1): 0.0}
        # Fix explore features to random values
        if self.explore_features is not None:
            for idx in self.explore_features:
                val = (
                    bounds[0, idx]
                    + torch.rand(1, dtype=bounds.dtype)
                    * (bounds[1, idx] - bounds[0, idx])
                ).item()
                fixed_features[idx] = val
                bounds[0, idx] = val
                bounds[1, idx] = val
        acqf = self._instantiate_acquisition_fn(model)
        # Initialize
        batch_initial_conditions = gen_batch_initial_conditions(
            acq_function=acqf,
            bounds=bounds,
            q=1,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
        )
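        # batch_initial_conditions is num_restarts x q x (dim + 1); all restarts
        # are optimized in parallel by the SGD loop below.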
        clamped_candidates = columnwise_clamp(
            X=batch_initial_conditions, lower=bounds[0], upper=bounds[1]
        ).requires_grad_(True)
        candidates = fix_features(clamped_candidates, fixed_features)
        optimizer = torch.optim.SGD(
            params=[clamped_candidates], lr=lr, momentum=momentum, nesterov=nesterov
        )
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=milestones, gamma=gamma
        )
        # Optimize
        for epoch in range(epochs):
            loss = -acqf(candidates).sum()
            # adjust loss based on constraints on candidates
            loss = loss_constraint_fun(loss, candidates)
            if verbosity_freq > 0 and epoch % verbosity_freq == 0:
                logger.info("Epoch: {} - Value: {:.3f}".format(epoch, -(loss.item())))

            def closure():
                optimizer.zero_grad()
                loss.backward(
                    retain_graph=True
                )  # Variational model requires retain_graph
                return loss
            optimizer.step(closure)
            clamped_candidates.data = columnwise_clamp(
                X=clamped_candidates, lower=bounds[0], upper=bounds[1]
            )
            candidates = fix_features(clamped_candidates, fixed_features)
            lr_scheduler.step()
        # Extract best point
        with torch.no_grad():
            batch_acquisition = acqf(candidates)
        best = torch.argmax(batch_acquisition.view(-1), dim=0)
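        # Drop the appended derivative-indicator column before returning.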
        Xopt = candidates[best][:, :-1].detach()
        return Xopt

    @classmethod
    def from_config(cls, config: Config) -> "MonotonicRejectionGenerator":
        """Instantiate a MonotonicRejectionGenerator from an AEPsych Config."""
        classname = cls.__name__
        acqf = config.getobj("common", "acqf", fallback=None)
        extra_acqf_args = cls._get_acqf_options(acqf, config)
        options = {}
        options["num_restarts"] = config.getint(classname, "restarts", fallback=10)
        options["raw_samples"] = config.getint(classname, "samps", fallback=1000)
        options["verbosity_freq"] = config.getint(
            classname, "verbosity_freq", fallback=-1
        )
        options["lr"] = config.getfloat(classname, "lr", fallback=0.01)  # type: ignore
        options["momentum"] = config.getfloat(classname, "momentum", fallback=0.9)  # type: ignore
        options["nesterov"] = config.getboolean(classname, "nesterov", fallback=True)
        options["epochs"] = config.getint(classname, "epochs", fallback=50)
        options["milestones"] = config.getlist(
            classname, "milestones", fallback=[25, 40]  # type: ignore
        )
        options["gamma"] = config.getfloat(classname, "gamma", fallback=0.1)  # type: ignore
        options["loss_constraint_fun"] = config.getobj(
            classname, "loss_constraint_fun", fallback=default_loss_constraint_fun
        )
        explore_features = config.getlist(classname, "explore_idxs", fallback=None)  # type: ignore
        return cls(
            acqf=acqf,
            acqf_kwargs=extra_acqf_args,
            model_gen_options=options,
            explore_features=explore_features,
        )
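

# A hedged usage sketch (comment only, not executed): the option names under
# [MonotonicRejectionGenerator] below are exactly those read in from_config()
# above; MonotonicMCLSE stands in for any MonotonicMCAcquisition subclass known
# to the Config.
#
#     [common]
#     acqf = MonotonicMCLSE
#
#     [MonotonicRejectionGenerator]
#     restarts = 10
#     samps = 1000
#     lr = 0.01
#     momentum = 0.9
#     nesterov = True
#     epochs = 50
#     milestones = [25, 40]
#     gamma = 0.1
#     explore_idxs = [0]
#
# A generator would then be built with
# MonotonicRejectionGenerator.from_config(Config(config_str=...)).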