from typing import Any, Optional, Union
import torch
from gpytorch.kernels import Kernel
from linear_operator import to_linear_operator
[docs]class PairwiseKernel(Kernel):
"""
Wrapper to convert a kernel K on R^k to a kernel K' on R^{2k}, modeling
functions of the form g(a, b) = f(a) - f(b), where f ~ GP(mu, K).
Since g is a linear combination of Gaussians, it follows that g ~ GP(0, K')
where K'((a,b), (c,d)) = K(a,c) - K(a, d) - K(b, c) + K(b, d).
"""
def __init__(
self, latent_kernel: Kernel, is_partial_obs: bool = False, **kwargs
) -> None:
"""
Args:
latent_kernel (Kernel): The underlying kernel used to compute the covariance for the GP.
is_partial_obs (bool): If the kernel should handle partial observations. Defaults to False.
"""
super(PairwiseKernel, self).__init__(**kwargs)
self.latent_kernel = latent_kernel
self.is_partial_obs = is_partial_obs
[docs] def forward(
self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, **params
) -> Optional[torch.Tensor]:
r"""
TODO: make last_batch_dim work properly
d must be 2*k for integer k, k is the dimension of the latent space
Args:
x1 (torch.Tensor): A `b x n x d` or `n x d` tensor, where `d = 2k` and `k` is the dimension of the latent space.
x2 (torch.Tensor): A `b x m x d` or `m x d` tensor, where `d = 2k` and `k` is the dimension of the latent space.
diag (bool): Should the Kernel compute the whole covariance matrix or just the diagonal? Defaults to False.
Returns:
torch.Tensor (or :class:`gpytorch.lazy.LazyTensor`) : A `b x n x m` or `n x m` tensor representing
the covariance matrix between `x1` and `x2`.
The exact size depends on the kernel's evaluation mode:
* `full_covar`: `n x m` or `b x n x m`
* `diag`: `n` or `b x n`
"""
if self.is_partial_obs:
d: Union[torch.Tensor, int] = x1.shape[-1] - 1
assert d == x2.shape[-1] - 1, "tensors not the same dimension"
assert d % 2 == 0, "dimension must be even"
k = int(d / 2)
# special handling for kernels that (also) do funky
# things with the input dimension
deriv_idx_1 = x1[..., -1][:, None]
deriv_idx_2 = x2[..., -1][:, None]
a = torch.cat((x1[..., :k], deriv_idx_1), dim=1)
b = torch.cat((x1[..., k:-1], deriv_idx_1), dim=1)
c = torch.cat((x2[..., :k], deriv_idx_2), dim=1)
d = torch.cat((x2[..., k:-1], deriv_idx_2), dim=1)
else:
d = x1.shape[-1]
assert d == x2.shape[-1], "tensors not the same dimension"
assert d % 2 == 0, "dimension must be even"
k = int(d / 2)
a = x1[..., :k]
b = x1[..., k:]
c = x2[..., :k]
d = x2[..., k:]
if not diag:
return (
to_linear_operator(self.latent_kernel(a, c, diag=diag, **params))
+ to_linear_operator(self.latent_kernel(b, d, diag=diag, **params))
- to_linear_operator(self.latent_kernel(b, c, diag=diag, **params))
- to_linear_operator(self.latent_kernel(a, d, diag=diag, **params))
)
else:
return (
self.latent_kernel(a, c, diag=diag, **params)
+ self.latent_kernel(b, d, diag=diag, **params)
- self.latent_kernel(b, c, diag=diag, **params)
- self.latent_kernel(a, d, diag=diag, **params)
)