Source code for aepsych.kernels.pairwisekernel

import torch
from gpytorch.kernels import Kernel
from linear_operator import to_linear_operator


class PairwiseKernel(Kernel):
    """
    Wrapper to convert a kernel K on R^k to a kernel K' on R^{2k}, modeling
    functions of the form g(a, b) = f(a) - f(b), where f ~ GP(mu, K).

    Since g is a linear combination of Gaussians, it follows that
    g ~ GP(0, K') where

        K'((a, b), (c, d)) = K(a, c) - K(a, d) - K(b, c) + K(b, d).
    """

    def __init__(self, latent_kernel, is_partial_obs=False, **kwargs):
        super(PairwiseKernel, self).__init__(**kwargs)

        self.latent_kernel = latent_kernel
        self.is_partial_obs = is_partial_obs
    def forward(self, x1, x2, diag=False, **params):
        r"""
        TODO: make last_batch_dim work properly

        d must be 2*k for integer k, where k is the dimension of the latent space.

        Args:
            :attr:`x1` (Tensor `n x d` or `b x n x d`):
                First set of data
            :attr:`x2` (Tensor `m x d` or `b x m x d`):
                Second set of data
            :attr:`diag` (bool):
                Should the kernel compute the whole kernel, or just the diag?

        Returns:
            :class:`Tensor` or :class:`~linear_operator.LinearOperator`.
            The exact size depends on the kernel's evaluation mode:

            * `full_covar`: `n x m` or `b x n x m`
            * `diag`: `n` or `b x n`
        """
        if self.is_partial_obs:
            # The last column carries a derivative index, not a latent
            # coordinate, so it is excluded from the dimension check.
            d = x1.shape[-1] - 1
            assert d == x2.shape[-1] - 1, "tensors not the same dimension"
            assert d % 2 == 0, "dimension must be even"
            k = int(d / 2)

            # special handling for kernels that (also) do funky
            # things with the input dimension: split each paired input into
            # its two latent halves and reattach the derivative index to each
            deriv_idx_1 = x1[..., -1][:, None]
            deriv_idx_2 = x2[..., -1][:, None]

            a = torch.cat((x1[..., :k], deriv_idx_1), dim=1)
            b = torch.cat((x1[..., k:-1], deriv_idx_1), dim=1)
            c = torch.cat((x2[..., :k], deriv_idx_2), dim=1)
            d = torch.cat((x2[..., k:-1], deriv_idx_2), dim=1)
        else:
            d = x1.shape[-1]
            assert d == x2.shape[-1], "tensors not the same dimension"
            assert d % 2 == 0, "dimension must be even"
            k = int(d / 2)

            # split each paired input (a, b) resp. (c, d) into its halves
            a = x1[..., :k]
            b = x1[..., k:]
            c = x2[..., :k]
            d = x2[..., k:]

        # K'((a, b), (c, d)) = K(a, c) + K(b, d) - K(b, c) - K(a, d)
        if not diag:
            return (
                to_linear_operator(self.latent_kernel(a, c, diag=diag, **params))
                + to_linear_operator(self.latent_kernel(b, d, diag=diag, **params))
                - to_linear_operator(self.latent_kernel(b, c, diag=diag, **params))
                - to_linear_operator(self.latent_kernel(a, d, diag=diag, **params))
            )
        else:
            return (
                self.latent_kernel(a, c, diag=diag, **params)
                + self.latent_kernel(b, d, diag=diag, **params)
                - self.latent_kernel(b, c, diag=diag, **params)
                - self.latent_kernel(a, d, diag=diag, **params)
            )
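
# A quick usage sketch (illustrative, not part of the module): wrap a stock
# GPyTorch kernel and confirm the four-term identity from the class docstring.
# Assumes RBFKernel as the latent kernel and a recent gpytorch/linear_operator
# where lazy kernel tensors expose ``to_dense()``.
if __name__ == "__main__":
    from gpytorch.kernels import RBFKernel

    # Latent space is R^2 (k = 2), so each paired input lives in R^4 (d = 2k).
    latent = RBFKernel(ard_num_dims=2)
    pairwise = PairwiseKernel(latent)

    x1 = torch.rand(5, 4)  # each row is (a, b) with a, b in R^2
    x2 = torch.rand(3, 4)  # each row is (c, d) with c, d in R^2

    full = pairwise(x1, x2).to_dense()  # 5 x 3 covariance matrix

    # Check against K'((a, b), (c, d)) = K(a, c) - K(a, d) - K(b, c) + K(b, d)
    a, b = x1[:, :2], x1[:, 2:]
    c, d = x2[:, :2], x2[:, 2:]
    expected = (
        latent(a, c).to_dense()
        - latent(a, d).to_dense()
        - latent(b, c).to_dense()
        + latent(b, d).to_dense()
    )
    assert torch.allclose(full, expected, atol=1e-5)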