Source code for aepsych.factory.default

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from configparser import NoOptionError
from typing import List, Optional, Tuple

import gpytorch
import torch
from aepsych.config import Config

from scipy.stats import norm

from .utils import __default_invgamma_concentration, __default_invgamma_rate

# The gamma lengthscale prior is taken from
# https://betanalpha.github.io/assets/case_studies/gaussian_processes.html#323_Informative_Prior_Model

# The lognormal lengthscale prior is taken from
# https://arxiv.org/html/2402.02229v3
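
# As a quick numeric illustration of the dimension scaling used below (a
# sketch, not part of the factory): a LogNormal(loc, scale) prior has mode
# exp(loc - scale**2), so with loc = sqrt(2) + log(dim)/2 the prior mass
# shifts toward longer lengthscales as dimensionality grows:
#
#   >>> import math
#   >>> loc, scale = math.sqrt(2.0), math.sqrt(3.0)
#   >>> mode = lambda d: math.exp(loc + math.log(d) / 2 - scale**2)
#   >>> round(mode(1), 3), round(mode(8), 3)
#   (0.205, 0.579)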


def default_mean_covar_factory(
    config: Optional[Config] = None,
    dim: Optional[int] = None,
    stimuli_per_trial: int = 1,
) -> Tuple[gpytorch.means.ConstantMean, gpytorch.kernels.Kernel]:
    """Default factory for generic GP models.

    Args:
        config (Config, optional): Object containing bounds (and potentially
            other config details).
        dim (int, optional): Dimensionality of the parameter space. Must be
            provided if config is None.
        stimuli_per_trial (int): Number of stimuli shown per trial; must be
            1 or 2. Defaults to 1.

    Returns:
        Tuple[gpytorch.means.Mean, gpytorch.kernels.Kernel]: Instantiated
            ConstantMean and kernel with priors based on the bounds. The
            kernel is wrapped in a ScaleKernel only when the kernel amplitude
            is not fixed.
    """
    assert (config is not None) or (
        dim is not None
    ), "Either config or dim must be provided!"
    assert stimuli_per_trial in (1, 2), "stimuli_per_trial must be 1 or 2!"

    mean = _get_default_mean_function(config)

    if config is not None:
        lb = config.gettensor("default_mean_covar_factory", "lb")
        ub = config.gettensor("default_mean_covar_factory", "ub")
        assert lb.shape[0] == ub.shape[0], "bounds shape mismatch!"
        config_dim: int = lb.shape[0]

        if dim is not None:
            assert dim == config_dim, "Provided config does not match provided dim!"
        else:
            dim = config_dim

    covar = _get_default_cov_function(config, dim, stimuli_per_trial)  # type: ignore

    return mean, covar
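
# A minimal usage sketch (an illustration, not part of the module): without a
# Config, only ``dim`` is required, and the default stimuli_per_trial=1 yields
# an un-scaled RBF kernel with a lognormal lengthscale prior.
#
#   >>> mean, covar = default_mean_covar_factory(dim=2)
#   >>> isinstance(mean, gpytorch.means.ConstantMean)
#   True
#   >>> covar.ard_num_dims
#   2
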
def _get_default_mean_function(
    config: Optional[Config] = None,
) -> gpytorch.means.ConstantMean:
    # default priors
    fixed_mean = False
    mean = gpytorch.means.ConstantMean()

    if config is not None:
        fixed_mean = config.getboolean(
            "default_mean_covar_factory", "fixed_mean", fallback=fixed_mean
        )
        if fixed_mean:
            try:
                target = config.getfloat("default_mean_covar_factory", "target")
                mean.constant.requires_grad_(False)
                mean.constant.copy_(torch.tensor(norm.ppf(target)))
            except NoOptionError:
                raise RuntimeError("Config got fixed_mean=True but no target included!")

    return mean


def _get_default_cov_function(
    config: Optional[Config],
    dim: int,
    stimuli_per_trial: int,
    active_dims: Optional[List[int]] = None,
) -> gpytorch.kernels.Kernel:
    # default priors
    lengthscale_prior = "lognormal" if stimuli_per_trial == 1 else "gamma"
    ls_loc = torch.tensor(math.sqrt(2.0), dtype=torch.float64)
    ls_scale = torch.tensor(math.sqrt(3.0), dtype=torch.float64)
    fixed_kernel_amplitude = stimuli_per_trial == 1
    outputscale_prior = "box"
    kernel = gpytorch.kernels.RBFKernel

    if config is not None:
        lengthscale_prior = config.get(
            "default_mean_covar_factory",
            "lengthscale_prior",
            fallback=lengthscale_prior,
        )
        if lengthscale_prior == "lognormal":
            ls_loc = config.gettensor(
                "default_mean_covar_factory", "ls_loc", fallback=ls_loc
            )
            ls_scale = config.gettensor(
                "default_mean_covar_factory", "ls_scale", fallback=ls_scale
            )
        fixed_kernel_amplitude = config.getboolean(
            "default_mean_covar_factory",
            "fixed_kernel_amplitude",
            fallback=fixed_kernel_amplitude,
        )
        outputscale_prior = config.get(
            "default_mean_covar_factory",
            "outputscale_prior",
            fallback=outputscale_prior,
        )
        kernel = config.getobj("default_mean_covar_factory", "kernel", fallback=kernel)

    if lengthscale_prior == "invgamma":
        # An inverse-gamma prior, expressed as a gamma prior on 1/lengthscale.
        ls_prior = gpytorch.priors.GammaPrior(
            concentration=__default_invgamma_concentration,
            rate=__default_invgamma_rate,
            transform=lambda x: 1 / x,
        )
        # Mode of InvGamma(a, b) is b / (a + 1).
        ls_prior_mode = ls_prior.rate / (ls_prior.concentration + 1)
    elif lengthscale_prior == "gamma":
        ls_prior = gpytorch.priors.GammaPrior(concentration=3.0, rate=6.0)
        # Mode of Gamma(a, b) is (a - 1) / b.
        ls_prior_mode = (ls_prior.concentration - 1) / ls_prior.rate
    elif lengthscale_prior == "lognormal":
        if not isinstance(ls_loc, torch.Tensor):
            ls_loc = torch.tensor(ls_loc, dtype=torch.float64)
        if not isinstance(ls_scale, torch.Tensor):
            ls_scale = torch.tensor(ls_scale, dtype=torch.float64)
        # Dimension-scaled prior from arXiv:2402.02229.
        ls_prior = gpytorch.priors.LogNormalPrior(ls_loc + math.log(dim) / 2, ls_scale)
        # Mode of LogNormal(mu, sigma) is exp(mu - sigma**2), including the
        # dimension-scaling term in mu.
        ls_prior_mode = torch.exp(ls_loc + math.log(dim) / 2 - ls_scale**2)
    else:
        raise RuntimeError(
            f"lengthscale_prior should be invgamma, gamma, or lognormal, got {lengthscale_prior}"
        )

    # Initialize the lengthscale at the prior mode.
    ls_constraint = gpytorch.constraints.GreaterThan(
        lower_bound=1e-4, transform=None, initial_value=ls_prior_mode
    )

    covar = kernel(
        lengthscale_prior=ls_prior,
        lengthscale_constraint=ls_constraint,
        ard_num_dims=dim,
        active_dims=active_dims,
    )

    if not fixed_kernel_amplitude:
        if outputscale_prior == "gamma":
            os_prior = gpytorch.priors.GammaPrior(concentration=2.0, rate=0.15)
        elif outputscale_prior == "box":
            os_prior = gpytorch.priors.SmoothedBoxPrior(a=1, b=4)
        else:
            raise RuntimeError(
                f"outputscale_prior should be gamma or box, got {outputscale_prior}"
            )
        covar = gpytorch.kernels.ScaleKernel(
            covar,
            outputscale_prior=os_prior,
            outputscale_constraint=gpytorch.constraints.GreaterThan(1e-4),
        )

    return covar
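
# A sketch of the config-driven path (hedged: this assumes aepsych's
# Config(config_str=...) constructor, and in real configs the bounds often
# live in a [common] section instead):
#
#   >>> config_str = """
#   ... [default_mean_covar_factory]
#   ... lb = [0, 0]
#   ... ub = [1, 1]
#   ... lengthscale_prior = gamma
#   ... fixed_kernel_amplitude = False
#   ... outputscale_prior = gamma
#   ... """
#   >>> mean, covar = default_mean_covar_factory(config=Config(config_str=config_str))
#   >>> type(covar).__name__
#   'ScaleKernel'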