Source code for aepsych.transforms.ops.log10_plus

#!/usr/bin/env python3
# Copyright (c) Meta, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Optional

import numpy as np
import torch
from aepsych.config import Config
from aepsych.transforms.ops.base import Transform
from aepsych.utils import get_bounds
from botorch.models.transforms.input import Log10, subset_transform


class Log10Plus(Log10, Transform):
    """Base-10 log transform that adds a constant to the values before
    transforming, so parameters whose lower bound is at or below zero can
    still be log-scaled."""

    def __init__(
        self,
        indices: List[int],
        constant: float = 0.0,
        transform_on_train: bool = True,
        transform_on_eval: bool = True,
        transform_on_fantasize: bool = True,
        reverse: bool = False,
        **kwargs,
    ) -> None:
        """Initialize transform

        Args:
            indices (List[int]): The indices of the parameters to log transform.
            constant (float): The constant to add to inputs before log
                transforming. Defaults to 0.0.
            transform_on_train (bool): A boolean indicating whether to apply the
                transforms in train() mode. Default: True.
            transform_on_eval (bool): A boolean indicating whether to apply the
                transform in eval() mode. Default: True.
            transform_on_fantasize (bool): A boolean indicating whether to apply
                the transform when called from within a `fantasize` call.
                Default: True.
            reverse (bool): A boolean indicating whether the forward pass should
                untransform the inputs. Default: False.
            **kwargs: Accepted to conform to API.
        """
        super().__init__(
            indices=indices,
            transform_on_train=transform_on_train,
            transform_on_eval=transform_on_eval,
            transform_on_fantasize=transform_on_fantasize,
            reverse=reverse,
        )
        # BUG FIX: previously stored with dtype=torch.long, which silently
        # truncated fractional constants (get_config_options produces
        # abs(lb) + 1.0, fractional whenever lb is). A truncated constant can
        # leave X + constant <= 0 at the lower bound, so log10 yields
        # -inf/NaN. Store at the default floating dtype instead.
        self.register_buffer(
            "constant", torch.tensor(constant, dtype=torch.get_default_dtype())
        )

    @subset_transform
    def _transform(self, X: torch.Tensor) -> torch.Tensor:
        r"""Add the constant then log transform the inputs.

        Args:
            X (torch.Tensor): A `batch_shape x n x d`-dim tensor of inputs.

        Returns:
            torch.Tensor: A `batch_shape x n x d`-dim tensor of transformed
                inputs.
        """
        # Broadcasting adds the scalar constant elementwise; equivalent to
        # the previous X + ones_like(X) * constant without the extra alloc.
        return (X + self.constant).log10()

    @subset_transform
    def _untransform(self, X: torch.Tensor) -> torch.Tensor:
        r"""Reverse the log transformation then subtract the constant.

        Args:
            X (torch.Tensor): A `batch_shape x n x d`-dim tensor of
                transformed inputs.

        Returns:
            torch.Tensor: A `batch_shape x n x d`-dim tensor of untransformed
                inputs.
        """
        return 10.0**X - self.constant

    @classmethod
    def get_config_options(
        cls,
        config: Config,
        name: Optional[str] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Return a dictionary of the relevant options to initialize a
        Log10Plus transform for the named parameter within the config.

        Args:
            config (Config): Config to look for options in.
            name (str, optional): Parameter to find options for.
            options (Dict[str, Any], optional): Options to override from the
                config.

        Returns:
            Dict[str, Any]: A dictionary of options to initialize this class
                with, including the transformed bounds.
        """
        options = super().get_config_options(config=config, name=name, options=options)

        # Make sure we have bounds ready
        if "bounds" not in options:
            options["bounds"] = get_bounds(config)

        if "constant" not in options:
            # NOTE(review): this assumes options["indices"] selects a single
            # parameter, so lb is a scalar tensor; a multi-element lb would
            # make these comparisons ambiguous — confirm against callers.
            lb = options["bounds"][0, options["indices"]]
            if lb < 0.0:
                # Shift so the shifted lower bound is exactly 1.0.
                constant = np.abs(lb) + 1.0
            elif lb < 1.0:
                constant = 1.0
            else:
                constant = 0.0
            options["constant"] = constant

        return options