Source code for aepsych.likelihoods.ordinal
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Optional
import gpytorch
import torch
from gpytorch.likelihoods import Likelihood
from torch.distributions import Categorical, Normal
class OrdinalLikelihood(Likelihood):
    """
    Ordinal likelihood, suitable for rating models (e.g. likert scales). Formally,

    .. math:: z_k(x\\mid f) := p(d_k < f(x) \\le d_{k+1}) = \\sigma(d_{k+1}-f(x)) - \\sigma(d_{k}-f(x)),

    where :math:`\\sigma()` is the link function (equivalent to the perceptual noise
    distribution in psychophysics terms), :math:`f(x)` is the latent GP evaluated at x,
    and :math:`d_k` is a learned cutpoint parameter for each level.
    """

    def __init__(self, n_levels: int, link: Optional[Callable] = None):
        """Initialize the likelihood.

        Args:
            n_levels: Number of ordinal response levels; must be at least 2.
            link: Callable applied elementwise to ``cutpoint - f`` to obtain
                cumulative probabilities. Defaults to the standard normal CDF.

        Raises:
            ValueError: If ``n_levels`` is smaller than 2.
        """
        if n_levels < 2:
            raise ValueError(
                f"OrdinalLikelihood needs at least 2 levels, got n_levels={n_levels}"
            )
        super().__init__()
        self.n_levels = n_levels
        # n_levels - 1 cutpoints separate n_levels categories; the first one is
        # pinned at 0 (see `cutpoints`), so only n_levels - 2 gaps are learned.
        self.register_parameter(
            name="raw_cutpoint_deltas",
            parameter=torch.nn.Parameter(torch.abs(torch.randn(n_levels - 2))),
        )
        # The gaps are constrained positive so the cutpoints stay ordered.
        self.register_constraint("raw_cutpoint_deltas", gpytorch.constraints.Positive())
        # Explicit None check: do not rely on the truthiness of a callable.
        self.link = link if link is not None else Normal(0, 1).cdf

    @property
    def cutpoints(self) -> torch.Tensor:
        """Ordered cutpoints, a 1-D tensor with ``n_levels - 1`` entries."""
        cutpoint_deltas = self.raw_cutpoint_deltas_constraint.transform(
            self.raw_cutpoint_deltas
        )
        # for identification, the first cutpoint is 0; build it from the deltas
        # so it shares their dtype and device.
        first_cutpoint = cutpoint_deltas.new_zeros(1)
        return torch.cat((first_cutpoint, torch.cumsum(cutpoint_deltas, 0)))

    def forward(self, function_samples: torch.Tensor, *params, **kwargs) -> Categorical:
        """Map latent function samples to a distribution over response levels.

        Args:
            function_samples: Samples of the latent function, any shape.
            *params: Unused; accepted for interface compatibility.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            A ``Categorical`` whose ``probs`` have shape
            ``(*function_samples.shape, n_levels)``.
        """
        # Evaluate the property once: each access redoes the constraint
        # transform and the cumulative sum.
        cutpoints = self.cutpoints
        # Cumulative probability at every cutpoint in one batched call, shape
        # (*function_samples.shape, n_levels - 1). The result inherits its
        # dtype/device from the inputs rather than from torch defaults.
        cum_probs = self.link(cutpoints - function_samples.unsqueeze(-1))
        probs = torch.cat(
            (
                cum_probs[..., :1],  # lowest level: mass below the first cutpoint
                cum_probs[..., 1:] - cum_probs[..., :-1],  # interior levels
                1 - cum_probs[..., -1:],  # highest level: mass above the last cutpoint
            ),
            dim=-1,
        )
        return Categorical(probs=probs)

    @classmethod
    def from_config(cls, config):
        """Alternate constructor reading options from a config object.

        Reads ``n_levels`` (via ``config.getint``) and the optional ``link``
        (via ``config.getobj``, falling back to ``None``) from the section
        named after this class.
        """
        classname = cls.__name__
        n_levels = config.getint(classname, "n_levels")
        link = config.getobj(classname, "link", fallback=None)
        return cls(n_levels=n_levels, link=link)