# Source code for aepsych.acquisition.rejection_sampler

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import torch
from botorch.posteriors import Posterior
from botorch.sampling.base import MCSampler


class RejectionSampler(MCSampler):
    """
    Samples from a posterior subject to the constraint that samples in
    constrained_idx should be >= 0.

    If not enough feasible samples are generated, will return the least violating
    samples.
    """

    def __init__(
        self, num_samples: int, num_rejection_samples: int, constrained_idx: torch.Tensor
    ):
        """Initialize RejectionSampler

        Args:
            num_samples (int): Number of samples to return. Note that if fewer samples
                than this number are non-negative in all constrained entries, the
                remaining samples returned will be the "least violating", i.e. those
                whose most negative constrained entry is closest to 0.
            num_rejection_samples (int): Number of samples to draw before rejecting.
                If this is smaller than num_samples, forward returns only
                num_rejection_samples samples.
            constrained_idx (torch.Tensor): Indices, along dimension 1 of the drawn
                samples, of the entries that should be constrained non-negative.
                A scalar index is also accepted.
        """
        # Initialize the Module/MCSampler machinery before assigning attributes.
        super().__init__(sample_shape=torch.Size([num_samples]))
        self.num_samples = num_samples
        self.num_rejection_samples = num_rejection_samples
        self.constrained_idx = constrained_idx

    def forward(self, posterior: Posterior) -> torch.Tensor:
        """Run the rejection sampler.

        Args:
            posterior (Posterior): The unconstrained GP posterior object to perform
                rejection samples on.

        Returns:
            torch.Tensor: Kept samples of shape ``(k, q, 1)``, where ``q`` is the
                size of dimension 1 of the posterior samples and
                ``k = min(num_samples, num_rejection_samples)``.
        """
        samples = posterior.rsample(
            sample_shape=torch.Size([self.num_rejection_samples])
        )
        # Only unbatched, single-output posteriors are supported, i.e. samples of
        # shape (num_rejection_samples, q, 1). Checking dim() as well as the last
        # axis catches batched posteriors, which would otherwise pass the check and
        # be silently mis-indexed below.
        assert (
            samples.dim() == 3 and samples.shape[-1] == 1
        ), "Batches not supported"  # TODO T68656582 handle batches later

        constrained_samps = samples[:, self.constrained_idx, 0]
        if constrained_samps.dim() == 1:
            # A scalar index drops the constrained axis; restore it so the
            # reductions over dim=1 below are well-defined.
            constrained_samps = constrained_samps.unsqueeze(-1)

        # A draw is feasible only if every constrained entry is >= 0.
        valid = (constrained_samps >= 0).all(dim=1)
        if valid.all() or valid.sum() >= self.num_samples:
            # Enough feasible draws (or every draw is feasible, which also covers
            # an empty constrained_idx): keep the first num_samples of them.
            keep = torch.where(valid)[0][: self.num_samples]
        else:
            # Not enough feasible draws: rank draws by their worst (smallest)
            # constrained value and keep the least violating ones. Feasible draws
            # have a non-negative worst value, so they are ranked first.
            worst_violation = constrained_samps.min(dim=1).values
            keep = torch.argsort(worst_violation, descending=True)[: self.num_samples]
        return samples[keep, :, :]