#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import io
from typing import Callable, Tuple
import numpy as np
import pandas as pd
import torch
from scipy.interpolate import CubicSpline, interp1d
# Audiometric threshold data manually scraped from fig. 2 of Dubno et al. 2013
# (doi:10.1007/s10162-013-0396-x). `raw` holds CSV text; the functions below read
# its `phenotype`, `freq` and `thresh` columns (see make_songetal_testfun).
raw = """\
dubno_data = pd.read_csv(io.StringIO(raw))
def make_songetal_threshfun(
    x: torch.Tensor, y: torch.Tensor
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Generate a synthetic threshold function by interpolation of real data.

    Real data is from Dubno et al. 2013, and procedure follows Song et al. 2017, 2018.
    See make_songetal_testfun for more detail.

    Inside the range spanned by ``x`` the threshold is a cubic spline through the
    data; outside that range it falls back to linear extrapolation from the
    nearest end segment.

    Args:
        x (torch.Tensor): Frequency. Must be strictly increasing (a requirement
            of ``scipy.interpolate.CubicSpline``).
        y (torch.Tensor): Threshold at each frequency in ``x``.

    Returns:
        Callable[[torch.Tensor], torch.Tensor]: Function that interpolates the given
            frequencies and thresholds and returns threshold as a function
            of frequency. Its output is float64 and lives on the same device
            as its input.
    """
    # detach() so tensors that require grad can still be converted to numpy.
    x_np = x.detach().cpu().numpy()
    y_np = y.detach().cpu().numpy()
    # These are not directly implemented in pytorch, so we use scipy for now
    f_interp = CubicSpline(x_np, y_np, extrapolate=False)
    f_extrap = interp1d(x_np, y_np, fill_value="extrapolate")

    def f_combo(x: torch.Tensor) -> torch.Tensor:
        query = x.detach().cpu().numpy()
        interpolated = f_interp(query)
        # CubicSpline(extrapolate=False) yields NaN outside [x[0], x[-1]];
        # fill exactly those points with the linear extrapolant.
        outside = np.isnan(interpolated)
        if outside.any():
            interpolated[outside] = f_extrap(query[outside])
        # Return on the caller's device so downstream arithmetic does not mix devices.
        return torch.from_numpy(interpolated).to(x.device)

    return f_combo
def make_songetal_testfun(
    phenotype: str = "Metabolic", beta: float = 1
) -> Callable[[torch.Tensor, bool], torch.Tensor]:
    """Make an audiometric test function following Song et al. 2017.

    To do so, we first compute a threshold by interpolation/extrapolation
    from real data, then assume a linear psychometric function in intensity
    with slope beta.

    Args:
        phenotype (str, optional): Audiometric phenotype from Dubno et al. 2013.
            Specifically, one of "Metabolic", "Sensory", "Metabolic+Sensory",
            or "Older-normal". Defaults to "Metabolic".
        beta (float, optional): Psychometric function slope. Defaults to 1.

    Returns:
        Callable[[torch.Tensor, bool], torch.Tensor]: A test function taking a
            [b x 2] tensor of points and returning the psychometric function value
            at those points. ``x[..., 0]`` is exponentiated base 2 before the
            threshold lookup (i.e. treated as log2 frequency) and ``x[..., 1]`` is
            intensity. It returns ``(intensity - threshold) / beta``, or that value
            passed through the standard normal CDF when ``cdf=True``.

    Raises:
        AssertionError: if an invalid phenotype is passed.

    References:
        Song, X. D., Garnett, R., & Barbour, D. L. (2017).
        Psychometric function estimation by probabilistic classification.
        The Journal of the Acoustical Society of America, 141(4), 2513–2525.
    """
    valid_phenotypes = ["Metabolic", "Sensory", "Metabolic+Sensory", "Older-normal"]
    # Raise explicitly rather than via `assert` so the check survives `python -O`,
    # while keeping the documented AssertionError type for existing callers.
    if phenotype not in valid_phenotypes:
        raise AssertionError(f"Phenotype must be one of {valid_phenotypes}")

    # Select this phenotype's rows once and reuse them for both columns.
    rows = dubno_data[dubno_data.phenotype == phenotype]
    freqs = torch.tensor(rows.freq.values, dtype=torch.float64)
    threshs = torch.tensor(rows.thresh.values, dtype=torch.float64)

    # first, make the threshold fun
    threshfun = make_songetal_threshfun(freqs, threshs)

    # now make it into a test function
    def song_testfun(x, cdf=False):
        logfreq = x[..., 0]
        intensity = x[..., 1]
        # The threshold data is indexed by raw frequency, so undo the log2.
        thresh = threshfun(2**logfreq)
        z = (intensity - thresh) / beta
        return torch.distributions.Normal(0, 1).cdf(z) if cdf else z

    return song_testfun
def novel_discrimination_testfun(x: torch.Tensor) -> torch.Tensor:
    """Evaluate the novel discrimination test function from Owen et al.

    The threshold is roughly parabolic with context, and the slope varies with
    the threshold. Adding to the difficulty is the fact that the function is
    minimized at f=0 (or p=0.5), corresponding to discrimination being at
    chance at zero stimulus intensity.

    Args:
        x (torch.Tensor): Points at which to evaluate; ``x[..., 0]`` is read as
            frequency and ``x[..., 1]`` as amplitude.

    Returns:
        torch.Tensor: Value of the function at each point.
    """
    frequency = x[..., 0]
    amplitude = x[..., 1]
    shifted = (-1 + 0.2 * frequency) ** 2
    # Frequency-dependent context term (quartic polynomial in frequency).
    context = 2 * (0.05 + 0.4 * shifted * frequency**2)
    numerator = 2 * (amplitude + 1)
    return numerator / context
def novel_detection_testfun(x: torch.Tensor) -> torch.Tensor:
    """Evaluate the novel detection test function from Owen et al.

    The threshold is roughly parabolic with context, and the slope varies with
    the threshold.

    Args:
        x (torch.Tensor): Points at which to evaluate; ``x[..., 0]`` is read as
            frequency and ``x[..., 1]`` as amplitude.

    Returns:
        torch.Tensor: Value of the function at each point.
    """
    frequency, amplitude = x[..., 0], x[..., 1]
    shifted = (-1 + 0.2 * frequency) ** 2
    # Same frequency-dependent context term as the discrimination variant.
    context = 2 * (0.05 + 0.4 * shifted * frequency**2)
    scaled = 4 * (amplitude + 1) / context
    return scaled - 4
def discrim_highdim(x: torch.Tensor) -> torch.Tensor:
    """Evaluate an 8-dimensional synthetic discrimination test function.

    The first eight entries along the last dimension of ``x`` are read (by the
    names used in the code) as amplitude, frequency, vertical scale, vertical
    shift, variance, asymmetry, phase and period. A sinusoidal "context" surface
    is built from the non-amplitude inputs, and amplitude is standardized against
    it. The standard-normal CDF of that score is rescaled onto [0.5, 1], clamped
    to [0.5, 1 - 1e-5], and mapped back through the standard-normal quantile
    function.

    Args:
        x (torch.Tensor): Points at which to evaluate, with at least 8 entries
            along the last dimension.

    Returns:
        torch.Tensor: Value of the function at each point (non-negative, since
            the clamped probability is at least 0.5).
    """
    amp, freq, vscale, vshift, variance, asym, phase, period = (
        x[..., i] for i in range(8)
    )

    wave = -0.5 * vscale * torch.cos(period * 0.6 * torch.pi * freq + phase)
    envelope = wave + vscale / 2 + vshift
    skew = -1 * asym * torch.sin(period * 0.6 * torch.pi * 0.5 * freq + phase) + (
        2 - asym
    )
    context = envelope * skew - 1

    z = (amp - context) / (variance + variance * (1 + context))

    standard = torch.distributions.Normal(0, 1)
    # Squash onto [0.5, 1]: 0.5 is chance performance for discrimination.
    prob = 0.5 * standard.cdf(z) + 0.5
    # Keep away from 1 so the quantile below stays finite.
    prob = prob.clamp(0.5, 1 - 1e-5)
    return standard.icdf(prob)
def modified_hartmann6(X: torch.Tensor) -> torch.Tensor:
    """The modified Hartmann6 function used in Lyu et al.

    Computes ``-10 * (y - 0.1)`` where
    ``y = sum_i C[i] * exp(-sum_j a_t[i, j] * (X[j] - p_t[i, j]) ** 2)``
    over four terms (i) and six input coordinates (j).

    Args:
        X (torch.Tensor): Point(s) at which to evaluate, with the six input
            coordinates along the last dimension. A 1-D tensor gives a 0-d
            result; a ``(..., 6)`` batch gives a ``(...)`` result. Only the
            first six entries of the last dimension are used.

    Returns:
        torch.Tensor: Function value(s), in float64.
    """
    # Work in float64 throughout (the original already produced float64 output),
    # and only look at the first six coordinates, as the original loop did.
    X = torch.as_tensor(X, dtype=torch.float64)[..., :6]
    device = X.device
    C = torch.tensor([0.2, 0.22, 0.28, 0.3], dtype=torch.float64, device=device)
    a_t = torch.tensor(
        [
            [8, 3, 10, 3.5, 1.7, 6],
            [0.5, 8, 10, 1.0, 6, 9],
            [3, 3.5, 1.7, 8, 10, 6],
            [10, 6, 0.5, 8, 1.0, 9],
        ],
        dtype=torch.float64,
        device=device,
    )
    p_t = 10 ** (-4) * torch.tensor(
        [
            [1312, 1696, 5569, 124, 8283, 5886],
            [2329, 4135, 8307, 3736, 1004, 9991],
            [2348, 1451, 3522, 2883, 3047, 6650],
            [4047, 8828, 8732, 5743, 1091, 381],
        ],
        dtype=torch.float64,
        device=device,
    )
    # Broadcast X (..., 6) against p_t (4, 6): t has shape (..., 4), one entry per
    # term i, so batched inputs are evaluated row-wise instead of indexing X[j].
    t = (a_t * (X.unsqueeze(-2) - p_t) ** 2).sum(dim=-1)
    y = (C * torch.exp(-t)).sum(dim=-1)
    return -10 * (y - 0.1)
def f_1d(x: torch.Tensor, mu: float = 0) -> torch.Tensor:
    """Latent function that is a Gaussian bump centered at ``mu``.

    Args:
        x (torch.Tensor): Points at which to evaluate.
        mu (float, optional): Center of the bump. Defaults to 0.

    Returns:
        torch.Tensor: ``exp(-(x - mu) ** 2)`` elementwise.
    """
    offset = x - mu
    return torch.exp(-torch.square(offset))
def f_2d(x: torch.Tensor) -> torch.Tensor:
    """Latent function with a bump centered at the origin (0, 0).

    The value is ``exp(-||x||)``, where ``||x||`` is the Euclidean norm over the
    last dimension. It decays with the norm itself, not its square, so it is
    not strictly a Gaussian bump.

    Args:
        x (torch.Tensor): Points at which to evaluate, coordinates along the
            last dimension.

    Returns:
        torch.Tensor: Bump value at each point.
    """
    radius = x.norm(dim=-1)
    return radius.neg().exp()
def new_novel_det_params(
    freq: torch.Tensor, scale_factor: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get the loc and scale params for the 2D synthetic novel_det(frequency) function.

    Keyword arguments:
    freq -- 1D tensor of frequencies whose thresholds to return
    scale factor -- scale for the novel_det function, where higher is steeper/lower SD

    Returns a ``(loc, scale)`` tuple, each shaped like ``freq``.
    """
    inner = 0.8 * freq * (0.2 * freq - 1)
    base = 0.66 * torch.pow(inner, 2) + 0.05
    loc = -1 + 2 * base
    scale = 2 * base / (3 * scale_factor)
    return loc, scale
def target_new_novel_det(
    freq: torch.Tensor, scale_factor: float = 1.0, target: float = 0.75
) -> torch.Tensor:
    """Get the target (i.e. threshold) for the 2D synthetic novel_det(frequency) function.

    Returns the quantile at ``target`` of ``Normal(loc, scale)`` for each
    frequency, with ``(loc, scale)`` from ``new_novel_det_params``.

    Keyword arguments:
    freq -- 1D tensor of frequencies whose thresholds to return
    scale factor -- scale for the novel_det function, where higher is steeper/lower SD
    target -- target threshold (a probability in (0, 1))
    """
    locs, scale = new_novel_det_params(freq, scale_factor)
    normal_dist = torch.distributions.Normal(locs, scale)
    # Build the quantile in the parameters' dtype/device. A bare torch.tensor(target)
    # uses the default float dtype (normally float32), which caps icdf precision at
    # float32 even for float64 inputs, and it warns if `target` is already a tensor.
    q = torch.as_tensor(target, dtype=locs.dtype, device=locs.device)
    return normal_dist.icdf(q)
def new_novel_det(x: torch.Tensor, scale_factor: float = 1.0) -> torch.Tensor:
    """Get the latent (z-score) value of the 2D synthetic novel_det(frequency) function.

    Returns ``(intensity - loc) / scale``; this is not a probability. Pass it
    through the standard normal CDF (as cdf_new_novel_det does) to get one.

    Keyword arguments:
    x -- tensor of shape (n,2) of locations to sample;
         x[...,0] is frequency from -1 to 1; x[...,1] is intensity from -1 to 1
    scale factor -- scale for the novel_det function, where higher is steeper/lower SD
    """
    loc, spread = new_novel_det_params(x[..., 0], scale_factor)
    intensity = x[..., 1]
    return (intensity - loc) / spread
def cdf_new_novel_det(x: torch.Tensor, scale_factor: float = 1.0) -> torch.Tensor:
    """Get the cdf for the 2D synthetic novel_det(frequency) function.

    Applies the standard normal CDF to the latent value from ``new_novel_det``.

    Keyword arguments:
    x -- tensor of shape (n,2) of locations to sample;
         x[...,0] is frequency from -1 to 1; x[...,1] is intensity from -1 to 1
    scale factor -- scale for the novel_det function, where higher is steeper/lower SD
    """
    latent = new_novel_det(x, scale_factor)
    standard_normal = torch.distributions.Normal(0, 1)
    return standard_normal.cdf(latent)
def new_novel_det_channels_params(
    channel: torch.Tensor,
    scale_factor: float = 1.0,
    wave_freq: float = 1,
    target: float = 0.75,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get the loc and scale parameters for the 2D synthetic novel_det(channel) function.

    Keyword arguments:
    channel -- 1D tensor of channel locations whose thresholds to return
    scale factor -- scale for the novel_det function, where higher is steeper/lower SD
    wave_freq -- frequency of location waveform on [-1,1]
    target -- accepted for signature parity with the other *_channels helpers;
              not used in this computation

    Returns a ``(locs, scale)`` tuple, each shaped like ``channel``.
    """
    phase = 5 * wave_freq * (channel - 1 / 6) / torch.pi
    locs = -0.3 * torch.sin(phase) ** 2 - 0.5
    amplitude = 1 / (10 * scale_factor)
    modulation = 0.75 + 0.25 * torch.cos(10 * (0.3 + channel) / torch.pi)
    scale = amplitude * modulation
    return locs, scale
def target_new_novel_det_channels(
    channel: torch.Tensor,
    scale_factor: float = 1.0,
    wave_freq: float = 1,
    target: float = 0.75,
) -> torch.Tensor:
    """Get the target (i.e. threshold) for the 2D synthetic novel_det(channel) function.

    Returns the quantile at ``target`` of ``Normal(locs, scale)`` for each channel,
    with ``(locs, scale)`` from ``new_novel_det_channels_params``.

    Keyword arguments:
    channel -- 1D tensor of channel locations whose thresholds to return
    scale factor -- scale for the novel_det function, where higher is steeper/lower SD
    wave_freq -- frequency of location waveform on [-1,1]
    target -- target threshold (a probability in (0, 1))
    """
    locs, scale = new_novel_det_channels_params(
        channel, scale_factor, wave_freq, target
    )
    normal_dist = torch.distributions.Normal(locs, scale)
    # Match the parameters' dtype/device; torch.tensor(target) would be default-dtype
    # (normally float32), limiting precision for float64 inputs.
    q = torch.as_tensor(target, dtype=locs.dtype, device=locs.device)
    return normal_dist.icdf(q)
def new_novel_det_channels(
    x: torch.Tensor,
    channel: torch.Tensor,
    scale_factor: float = 1.0,
    wave_freq: float = 1,
    target: float = 0.75,
) -> torch.Tensor:
    """Get the latent (z-score) value of the 2D synthetic novel_det(channel) function.

    Returns ``(x[..., 1] - locs) / scale`` with ``(locs, scale)`` computed from
    the separate ``channel`` argument.

    NOTE(review): x[..., 0] is not read here; channel positions come only from
    ``channel``. Confirm callers pass the same values in both.

    Keyword arguments:
    x -- tensor of shape (n,2) of locations to sample;
         x[...,0] is channel from -1 to 1; x[...,1] is intensity from -1 to 1
    scale factor -- scale for the novel_det function, where higher is steeper/lower SD
    wave_freq -- frequency of location waveform on [-1,1]
    """
    loc, spread = new_novel_det_channels_params(
        channel, scale_factor, wave_freq, target
    )
    intensity = x[..., 1]
    return (intensity - loc) / spread
def cdf_new_novel_det_channels(
    x: torch.Tensor,
    channel: torch.Tensor,
    scale_factor: float = 1.0,
    wave_freq: float = 1,
    target: float = 0.75,
) -> torch.Tensor:
    """Get the cdf for the 2D synthetic novel_det(channel) function.

    Applies the standard normal CDF to the latent value from
    ``new_novel_det_channels``.

    Keyword arguments:
    x -- tensor of shape (n,2) of locations to sample;
         x[...,0] is channel from -1 to 1; x[...,1] is intensity from -1 to 1
    scale factor -- scale for the novel_det function, where higher is steeper/lower SD
    wave_freq -- frequency of location waveform on [-1,1]
    """
    latent = new_novel_det_channels(x, channel, scale_factor, wave_freq, target)
    standard_normal = torch.distributions.Normal(0, 1)
    return standard_normal.cdf(latent)
def new_novel_det_3D_params(
    x: torch.Tensor, scale_factor: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get the loc and scale params for the 3D synthetic novel_det function.

    Args:
        x (torch.Tensor): Locations; ``x[..., 0]`` is frequency and ``x[..., 1]``
            is channel. Further entries along the last dimension are not read.
        scale_factor (float, optional): Scale for the novel_det function, where
            higher is steeper/lower SD. Defaults to 1.0.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(locs, scale)``, each shaped like
            ``x[..., 0]``.
    """
    freq, chan = x[..., 0], x[..., 1]

    # Frequency contribution: parabolic-like curve in frequency.
    freq_curve = 0.66 * torch.pow(0.8 * freq * (0.2 * freq - 1), 2) + 0.05
    locs_freq = -0.32 + 2 * freq_curve

    # Channel contribution: squared-sine ripple across channels.
    chan_wave = -0.35 * torch.sin(5 * (chan - 1 / 6) / torch.pi) ** 2
    locs = 0.7 * (chan_wave - 0.5) + 0.9 * locs_freq

    loc_term = 0.3 * locs / (3 * scale_factor) / (10 * scale_factor)
    chan_term = 0.15 * (0.75 + 0.25 * torch.cos(10 * (0.6 + chan) / torch.pi))
    scale = loc_term + chan_term
    return locs, scale
def new_novel_det_3D(x: torch.Tensor, scale_factor: float = 1.0) -> torch.Tensor:
    """Get the synthetic 3D novel_det function over freqs, channels, and amplitudes.

    Returns the latent (z-score) value ``(x[..., 2] - locs) / scale``, where
    ``x[..., 0]`` is frequency, ``x[..., 1]`` is channel and ``x[..., 2]`` is
    intensity, with ``(locs, scale)`` from ``new_novel_det_3D_params``.
    """
    loc, spread = new_novel_det_3D_params(x, scale_factor)
    intensity = x[..., 2]
    return (intensity - loc) / spread
def cdf_new_novel_det_3D(x: torch.Tensor, scale_factor: float = 1.0) -> torch.Tensor:
    """Get the cdf for the 3D synthetic novel_det function.

    Applies the standard normal CDF to the latent value from ``new_novel_det_3D``.

    x -- tensor of shape (n,3) of locations to sample
         x[...,0] is frequency, x[...,1] is channel, x[...,2] is intensity
    scale factor -- scale for the novel_det function, where higher is steeper/lower SD
    """
    latent = new_novel_det_3D(x, scale_factor)
    standard_normal = torch.distributions.Normal(0, 1)
    return standard_normal.cdf(latent)
def target_new_novel_det_3D(
    x: torch.Tensor, scale_factor: float = 1.0, target: float = 0.75
) -> torch.Tensor:
    """Get the target (threshold) for the 3D synthetic novel_det function at location x.

    Returns the quantile at ``target`` of ``Normal(locs, scale)``, with
    ``(locs, scale)`` from ``new_novel_det_3D_params``.

    x -- tensor of shape (n,2) (or wider) of locations to sample
         x[...,0] is frequency, x[...,1] is channel
    scale factor -- scale for the novel_det function, where higher is steeper/lower SD
    target -- target threshold (a probability in (0, 1))
    """
    locs, scale = new_novel_det_3D_params(x, scale_factor)
    normal_dist = torch.distributions.Normal(locs, scale)
    # Match the parameters' dtype/device; torch.tensor(target) would be default-dtype
    # (normally float32), limiting precision for float64 inputs.
    q = torch.as_tensor(target, dtype=locs.dtype, device=locs.device)
    return normal_dist.icdf(q)
def f_pairwise(f: Callable, x: torch.Tensor, noise_scale: float = 1) -> torch.Tensor:
    """Evaluate a pairwise-comparison probability from a latent function.

    Computes ``Phi((f(x[..., 1]) - f(x[..., 0])) / (noise_scale * sqrt(2)))``,
    where ``Phi`` is the standard normal CDF. Presumably this is the probability
    that option 1 is preferred over option 0 when each evaluation carries
    independent Gaussian noise of SD ``noise_scale`` (TODO confirm with callers).

    Args:
        f (Callable): Latent function applied to each side of the pair.
        x (torch.Tensor): Pairs; ``x[..., 0]`` and ``x[..., 1]`` are the two options.
        noise_scale (float, optional): Noise SD. Defaults to 1.

    Returns:
        torch.Tensor: Probability for each pair.
    """
    standard_normal = torch.distributions.Normal(0, 1)
    difference = f(x[..., 1]) - f(x[..., 0])
    denominator = noise_scale * torch.sqrt(torch.tensor(2.0))
    return standard_normal.cdf(difference / denominator)