Source code for aepsych.transforms.ops.fixed

#!/usr/bin/env python3
# Copyright (c) Meta, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Optional, Union

import torch
from aepsych.config import Config
from aepsych.transforms.ops.base import StringParameterMixin, Transform


class Fixed(Transform, StringParameterMixin, torch.nn.Module):
    """Transform that strips fixed parameters out of tensors and re-inserts them.

    The forward transform drops the entries of the last dimension listed in
    ``indices``; the untransform puts the stored ``values`` back at those
    positions.
    """

    def __init__(
        self,
        indices: List[int],
        values: List[Union[float, int]],
        string_map: Optional[Dict[int, List[str]]] = None,
        transform_on_train: bool = True,
        transform_on_eval: bool = True,
        transform_on_fantasize: bool = True,
        reverse: bool = False,
        **kwargs,
    ) -> None:
        """Initialize a fixed transform. It will add and remove fixed values
        from tensors.

        Args:
            indices (List[int]): The indices of the parameters to be fixed.
            values (List[Union[float, int]]): The values of the fixed
                parameters, given in the same order as `indices`.
            string_map (Dict[int, List[str]], optional): A dictionary to allow
                some fixed elements to represent one element of a categorical
                parameter.
            transform_on_train (bool): A boolean indicating whether to apply
                the transforms in train() mode. Default: True.
            transform_on_eval (bool): A boolean indicating whether to apply the
                transform in eval() mode. Default: True.
            transform_on_fantasize (bool): A boolean indicating whether to
                apply the transform when called from within a `fantasize`
                call. Default: True.
            reverse (bool): A boolean indicating whether the forward pass
                should untransform the inputs. Default: False.
            **kwargs: Accepted to conform to API; not used here.
        """
        indices_ = torch.tensor(indices, dtype=torch.long)
        values_ = torch.tensor(values, dtype=torch.float64)

        # Keep indices ascending (with values reordered to match):
        # `_untransform` inserts the fixed columns one at a time, left to
        # right, so each index must refer to a position in the full tensor.
        sort_idx = torch.argsort(indices_)
        indices_ = indices_[sort_idx]
        values_ = values_[sort_idx]

        super().__init__()
        self.register_buffer("indices", indices_)
        self.register_buffer("values", values_)
        self.transform_on_train = transform_on_train
        self.transform_on_eval = transform_on_eval
        self.transform_on_fantasize = transform_on_fantasize
        self.reverse = reverse
        self.string_map = string_map

    def _transform(self, X: torch.Tensor) -> torch.Tensor:
        r"""Transform the input Tensor by popping out the fixed parameters at
        the specified indices.

        Args:
            X (torch.Tensor): A `batch_shape x n x d`-dim tensor of inputs.
                The fixed parameters are removed from the last dimension.

        Returns:
            torch.Tensor: A new tensor with the fixed parameters removed.
        """
        # Build the positions on the same device as the `indices` buffer so
        # `torch.isin` does not fail once the module has been moved off CPU.
        positions = torch.arange(X.shape[-1], device=self.indices.device)
        keep = ~torch.isin(positions, self.indices)

        # Boolean-mask indexing always returns a copy, so the caller's tensor
        # is never aliased. Indexing the last dimension (rather than dim 1)
        # keeps this correct for inputs with leading batch dimensions.
        return X[..., keep.to(X.device)]

    def _untransform(self, X: torch.Tensor) -> torch.Tensor:
        r"""Transform the input tensor by adding back in the fixed parameters
        at the specified indices.

        Args:
            X (torch.Tensor): A `batch_shape x n x d`-dim tensor of
                transformed inputs.

        Returns:
            torch.Tensor: A new tensor equal to the input with the fixed
                parameters inserted back into the last dimension.
        """
        # Clone so a fresh tensor is returned even when no parameter is fixed
        # and the loop below never runs.
        X = X.clone()

        for idx, value in zip(self.indices.tolist(), self.values):
            pre_fixed = X[..., :idx]
            post_fixed = X[..., idx:]
            # One column holding the fixed value, repeated over every leading
            # dimension of X (for 2-D input this is the shape (n, 1)).
            fixed = torch.tile(value, (*X.shape[:-1], 1))
            X = torch.cat((pre_fixed, fixed, post_fixed), dim=-1)

        return X

    @classmethod
    def get_config_options(
        cls,
        config: Config,
        name: Optional[str] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Return a dictionary of the relevant options to initialize a Fixed
        parameter transform for the named parameter within the config.

        Args:
            config (Config): Config to look for options in.
            name (str, optional): Parameter to find options for. Despite the
                default, it must be provided.
            options (Dict[str, Any], optional): Options to override from the
                config.

        Returns:
            Dict[str, Any]: A dictionary of options to initialize this class
                with. If "values" was not already present, it is filled in
                from the `value` entry of the parameter's config section; a
                non-numeric value additionally sets "string_map".

        Raises:
            ValueError: If `name` is None, or if "values" is not supplied and
                the parameter's config section has no `value` entry.
        """
        options = super().get_config_options(config=config, name=name, options=options)

        if name is None:
            raise ValueError(f"{name} must be set to initialize a transform.")

        if "values" not in options:
            value = config[name].get("value")
            if value is None:
                raise ValueError(f"Value option not found in {name} section.")

            try:
                numeric_value = float(value)
            except ValueError:
                # Probably a string, so we treat it as a fixed categorical
                # parameter: the string becomes the only category and the
                # numeric fixed value is its position, 0.
                options["string_map"] = {options["indices"][0]: [value]}
                options["values"] = [0]
            else:
                options["values"] = [numeric_value]

        return options