diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index e6b684fab..9a5217cef 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -132,7 +132,7 @@ def sample( ) samples = rejection.accept_reject_sample( - proposal=self.posterior_estimator, + proposal=self.posterior_estimator.sample, accept_reject_fn=lambda theta: within_support(self.prior, theta), num_samples=num_samples, show_progress_bars=show_progress_bars, @@ -176,7 +176,7 @@ def sample_batched( ) samples = rejection.accept_reject_sample( - proposal=self.posterior_estimator, + proposal=self.posterior_estimator.sample, accept_reject_fn=lambda theta: within_support(self.prior, theta), num_samples=num_samples, show_progress_bars=show_progress_bars, @@ -373,7 +373,7 @@ def leakage_correction( def acceptance_at(x: Tensor) -> Tensor: # [1:] to remove batch-dimension for `reshape_to_batch_event`. return rejection.accept_reject_sample( - proposal=self.posterior_estimator, + proposal=self.posterior_estimator.sample, accept_reject_fn=lambda theta: within_support(self.prior, theta), num_samples=num_rejection_samples, show_progress_bars=show_progress_bars, diff --git a/sbi/inference/posteriors/score_posterior.py b/sbi/inference/posteriors/score_posterior.py index eb1a9fa09..0d0c60d4d 100644 --- a/sbi/inference/posteriors/score_posterior.py +++ b/sbi/inference/posteriors/score_posterior.py @@ -9,6 +9,7 @@ from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.inference.potentials.score_based_potential import ( + CallableDifferentiablePotentialFunction, PosteriorScoreBasedPotential, score_estimator_based_potential, ) @@ -16,11 +17,13 @@ from sbi.neural_nets.estimators.shape_handling import ( reshape_to_batch_event, ) +from sbi.samplers.rejection import rejection from sbi.samplers.score.correctors import Corrector from sbi.samplers.score.diffuser import Diffuser from sbi.samplers.score.predictors import Predictor from sbi.sbi_types import Shape from sbi.utils import check_prior +from sbi.utils.sbiutils import gradient_ascent, within_support from sbi.utils.torchutils import ensure_theta_batched @@ -46,7 +49,7 @@ def __init__( prior: Distribution, max_sampling_batch_size: int = 10_000, device: Optional[str] = None, - enable_transform: bool = False, + enable_transform: bool = True, sample_with: str = "sde", ): """ @@ -110,7 +113,6 @@ def sample( Args: sample_shape: Shape of the samples to be drawn. - x: Deprecated - use `.set_default_x()` prior to `.sample()`. predictor: The predictor for the diffusion-based sampler. Can be a string or a custom predictor following the API in `sbi.samplers.score.predictors`. Currently, only `euler_maruyama` is implemented. 
@@ -136,23 +138,39 @@ def sample( x = self._x_else_default_x(x) x = reshape_to_batch_event(x, self.score_estimator.condition_shape) - self.potential_fn.set_x(x) + self.potential_fn.set_x(x, x_is_iid=True) + + num_samples = torch.Size(sample_shape).numel() if self.sample_with == "ode": - samples = self.sample_via_zuko(sample_shape=sample_shape, x=x) - elif self.sample_with == "sde": - samples = self._sample_via_diffusion( - sample_shape=sample_shape, - predictor=predictor, - corrector=corrector, - predictor_params=predictor_params, - corrector_params=corrector_params, - steps=steps, - ts=ts, + samples = rejection.accept_reject_sample( + proposal=self.sample_via_ode, + accept_reject_fn=lambda theta: within_support(self.prior, theta), + num_samples=num_samples, + show_progress_bars=show_progress_bars, max_sampling_batch_size=max_sampling_batch_size, + )[0] + elif self.sample_with == "sde": + proposal_sampling_kwargs = { + "predictor": predictor, + "corrector": corrector, + "predictor_params": predictor_params, + "corrector_params": corrector_params, + "steps": steps, + "ts": ts, + "max_sampling_batch_size": max_sampling_batch_size, + "show_progress_bars": show_progress_bars, + } + samples = rejection.accept_reject_sample( + proposal=self._sample_via_diffusion, + accept_reject_fn=lambda theta: within_support(self.prior, theta), + num_samples=num_samples, show_progress_bars=show_progress_bars, - ) + max_sampling_batch_size=max_sampling_batch_size, + proposal_sampling_kwargs=proposal_sampling_kwargs, + )[0] + samples = samples.reshape(sample_shape + self.score_estimator.input_shape) return samples def _sample_via_diffusion( @@ -171,7 +189,6 @@ def _sample_via_diffusion( Args: sample_shape: Shape of the samples to be drawn. - x: Deprecated - use `.set_default_x()` prior to `.sample()`. predictor: The predictor for the diffusion-based sampler. Can be a string or a custom predictor following the API in `sbi.samplers.score.predictors`. Currently, only `euler_maruyama` is implemented. @@ -222,11 +239,10 @@ def _sample_via_diffusion( ) samples = torch.cat(samples, dim=0)[:num_samples] - return samples.reshape(sample_shape + self.score_estimator.input_shape) + return samples - def sample_via_zuko( + def sample_via_ode( self, - x: Tensor, sample_shape: Shape = torch.Size(), ) -> Tensor: r"""Return samples from posterior distribution with probability flow ODE. @@ -243,10 +259,12 @@ def sample_via_zuko( """ num_samples = torch.Size(sample_shape).numel() - flow = self.potential_fn.get_continuous_normalizing_flow(condition=x) + flow = self.potential_fn.get_continuous_normalizing_flow( + condition=self.potential_fn.x_o + ) samples = flow.sample(torch.Size((num_samples,))) - return samples.reshape(sample_shape + self.score_estimator.input_shape) + return samples def log_prob( self, @@ -291,19 +309,73 @@ def sample_batched( self, sample_shape: torch.Size, x: Tensor, + predictor: Union[str, Predictor] = "euler_maruyama", + corrector: Optional[Union[str, Corrector]] = None, + predictor_params: Optional[Dict] = None, + corrector_params: Optional[Dict] = None, + steps: int = 500, + ts: Optional[Tensor] = None, max_sampling_batch_size: int = 10000, show_progress_bars: bool = True, ) -> Tensor: - raise NotImplementedError( - "Batched sampling is not implemented for ScorePosterior." 
+ num_samples = torch.Size(sample_shape).numel() + x = reshape_to_batch_event(x, self.score_estimator.condition_shape) + condition_dim = len(self.score_estimator.condition_shape) + batch_shape = x.shape[:-condition_dim] + batch_size = batch_shape.numel() + self.potential_fn.set_x(x) + + max_sampling_batch_size = ( + self.max_sampling_batch_size + if max_sampling_batch_size is None + else max_sampling_batch_size ) + if self.sample_with == "ode": + samples = rejection.accept_reject_sample( + proposal=self.sample_via_ode, + accept_reject_fn=lambda theta: within_support(self.prior, theta), + num_samples=num_samples, + num_xos=batch_size, + show_progress_bars=show_progress_bars, + max_sampling_batch_size=max_sampling_batch_size, + proposal_sampling_kwargs={"x": x}, + )[0] + samples = samples.reshape( + sample_shape + batch_shape + self.score_estimator.input_shape + ) + elif self.sample_with == "sde": + proposal_sampling_kwargs = { + "predictor": predictor, + "corrector": corrector, + "predictor_params": predictor_params, + "corrector_params": corrector_params, + "steps": steps, + "ts": ts, + "max_sampling_batch_size": max_sampling_batch_size, + "show_progress_bars": show_progress_bars, + } + samples = rejection.accept_reject_sample( + proposal=self._sample_via_diffusion, + accept_reject_fn=lambda theta: within_support(self.prior, theta), + num_samples=num_samples, + num_xos=batch_size, + show_progress_bars=show_progress_bars, + max_sampling_batch_size=max_sampling_batch_size, + proposal_sampling_kwargs=proposal_sampling_kwargs, + )[0] + samples = samples.reshape( + sample_shape + batch_shape + self.score_estimator.input_shape + ) + + return samples + def map( self, x: Optional[Tensor] = None, num_iter: int = 1000, num_to_optimize: int = 1000, - learning_rate: float = 1e-5, + learning_rate: float = 0.01, init_method: Union[str, Tensor] = "posterior", num_init_samples: int = 1000, save_best_every: int = 1000, @@ -351,17 +423,41 @@ def map( Returns: The MAP estimate. """ - raise NotImplementedError( - "MAP estimation is currently not working accurately for ScorePosterior." - ) - return super().map( - x=x, - num_iter=num_iter, - num_to_optimize=num_to_optimize, - learning_rate=learning_rate, - init_method=init_method, - num_init_samples=num_init_samples, - save_best_every=save_best_every, - show_progress_bars=show_progress_bars, - force_update=force_update, - ) + if x is not None: + raise ValueError( + "Passing `x` directly to `.map()` has been deprecated. " + "Use `.set_default_x()` to set `x`, and then run `.map()`." + ) + + if self.default_x is None: + raise ValueError( + "Default `x` has not been set. " + "To set the default, use the `.set_default_x()` method."
+ ) + + if self._map is None or force_update: + self.potential_fn.set_x(self.default_x) + callable_potential_fn = CallableDifferentiablePotentialFunction( + self.potential_fn + ) + if init_method == "posterior": + inits = self.sample((num_init_samples,)) + elif init_method == "proposal": + inits = self.proposal.sample((num_init_samples,)) # type: ignore + elif isinstance(init_method, Tensor): + inits = init_method + else: + raise ValueError(f"Invalid init_method: {init_method}") + + self._map = gradient_ascent( + potential_fn=callable_potential_fn, + inits=inits, + theta_transform=self.theta_transform, + num_iter=num_iter, + num_to_optimize=num_to_optimize, + learning_rate=learning_rate, + save_best_every=save_best_every, + show_progress_bars=show_progress_bars, + )[0] + + return self._map diff --git a/sbi/inference/potentials/score_based_potential.py b/sbi/inference/potentials/score_based_potential.py index 51c37dbf5..08561eb22 100644 --- a/sbi/inference/potentials/score_based_potential.py +++ b/sbi/inference/potentials/score_based_potential.py @@ -1,6 +1,7 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Apache License Version 2.0, see +from functools import partial from typing import Optional, Tuple import torch @@ -24,7 +25,7 @@ def score_estimator_based_potential( score_estimator: ConditionalScoreEstimator, prior: Optional[Distribution], x_o: Optional[Tensor], - enable_transform: bool = False, + enable_transform: bool = True, ) -> Tuple["PosteriorScoreBasedPotential", TorchTransform]: r"""Returns the potential function gradient for score estimators. @@ -40,10 +41,6 @@ def score_estimator_based_potential( score_estimator, prior, x_o, device=device ) - assert enable_transform is False, ( - "Transforms are not yet supported for score estimators." - ) - if prior is not None: theta_transform = mcmc_transform( prior, device=device, enable_transform=enable_transform ) @@ -73,16 +70,37 @@ def __init__( `iid_bridge` as proposed in Geffner et al. is implemented. device: The device on which to evaluate the potential. """ - - super().__init__(prior, x_o, device=device) self.score_estimator = score_estimator self.score_estimator.eval() self.iid_method = iid_method + super().__init__(prior, x_o, device=device) + + def set_x( + self, + x_o: Optional[Tensor], + x_is_iid: Optional[bool] = False, + rebuild_flow: Optional[bool] = True, + ): + """ + Set the observed data and whether it is IID. + Args: + x_o: The observed data. + x_is_iid: Whether the observed data is IID (if batch_dim>1). + rebuild_flow: Whether to save (overwrite) a high-tolerance (low-precision) flow model, useful if + the flow needs to be evaluated many times (e.g. for MAP calculation). + """ + super().set_x(x_o, x_is_iid) + if rebuild_flow and self._x_o is not None: + # By default, we want a high-tolerance flow. + # This flow will be used mainly for MAP calculations, hence we want to save + # it instead of rebuilding it every time. + self.flow = self.rebuild_flow(atol=1e-2, rtol=1e-3, exact=True) def __call__( self, theta: Tensor, track_gradients: bool = True, + rebuild_flow: bool = True, atol: float = 1e-5, rtol: float = 1e-6, exact: bool = True, @@ -92,6 +110,7 @@ def __call__( Args: theta: The parameters at which to evaluate the potential. track_gradients: Whether to track gradients. + rebuild_flow: Whether to rebuild the CNF for accurate log_prob evaluation. atol: Absolute tolerance for the ODE solver. rtol: Relative tolerance for the ODE solver. exact: Whether to use the exact ODE solver. 
@@ -103,18 +122,13 @@ def __call__( theta_density_estimator = reshape_to_sample_batch_event( theta, theta.shape[1:], leading_is_sample=True ) - x_density_estimator = reshape_to_batch_event( - self.x_o, event_shape=self.score_estimator.condition_shape - ) - assert x_density_estimator.shape[0] == 1, ( - "PosteriorScoreBasedPotential supports only x batchsize of 1`." - ) - self.score_estimator.eval() - - flow = self.get_continuous_normalizing_flow( - condition=x_density_estimator, atol=atol, rtol=rtol, exact=exact - ) + # use rebuild_flow to evaluate log_prob with better precision, without + # overwriting self.flow + if rebuild_flow or self.flow is None: + flow = self.rebuild_flow(atol=atol, rtol=rtol, exact=exact) + else: + flow = self.flow with torch.set_grad_enabled(track_gradients): log_probs = flow.log_prob(theta_density_estimator).squeeze(-1) @@ -134,7 +148,7 @@ def gradient( r"""Returns the potential function gradient for score-based methods. Args: - theta: The parameters at which to evaluate the potential. + theta: The parameters at which to evaluate the potential gradient. time: The diffusion time. If None, then `t_min` of the self.score_estimator is used (i.e. we evaluate the gradient of the actual data distribution). @@ -188,12 +202,36 @@ def get_continuous_normalizing_flow( # Use zuko to build the normalizing flow. return NormalizingFlow(transform, base=base_density) + def rebuild_flow( + self, atol: float = 1e-5, rtol: float = 1e-6, exact: bool = True + ) -> NormalizingFlow: + """ + Rebuilds the continuous normalizing flow. This is used when + a new default x is set, or to evaluate the log probs at higher precision. + """ + if self._x_o is None: + raise ValueError( + "No observed data x_o is available. Please reinitialize \ the potential or manually set self._x_o." + ) + x_density_estimator = reshape_to_batch_event( + self.x_o, event_shape=self.score_estimator.condition_shape + ) + assert x_density_estimator.shape[0] == 1, ( + "PosteriorScoreBasedPotential supports only x batch size of 1." + ) + + flow = self.get_continuous_normalizing_flow( + condition=x_density_estimator, atol=atol, rtol=rtol, exact=exact + ) + return flow + def build_freeform_jacobian_transform( score_estimator: ConditionalScoreEstimator, x_o: Tensor, - atol: float = 1e-5, - rtol: float = 1e-6, + atol: float = 1e-6, + rtol: float = 1e-5, exact: bool = True, ) -> FreeFormJacobianTransform: """Builds the free-form Jacobian for the probability flow ODE, used for log-prob. @@ -228,3 +266,55 @@ def f(t, x): exact=exact, ) return transform + + +class DifferentiablePotentialFunction(torch.autograd.Function): + """ + A wrapper of PosteriorScoreBasedPotential with a custom autograd function to compute + the gradient of log_prob with respect to theta. Instead of backpropagating through + the continuous normalizing flow, we use the gradient of the score estimator. + + """ + + @staticmethod + def forward(ctx, input, call_function, gradient_function): + """ + Computes the potential normally. 
""" + # Save the methods as callables + ctx.call_function = call_function + ctx.gradient_function = gradient_function + ctx.save_for_backward(input) + + # Perform the forward computation + output = call_function(input) + return output + + @staticmethod + def backward(ctx, grad_output): + (input,) = ctx.saved_tensors + grad = ctx.gradient_function(input) + # Match dims + while len(grad_output.shape) < len(grad.shape): + grad_output = grad_output.unsqueeze(-1) + grad_input = grad_output * grad + return grad_input, None, None + + +class CallableDifferentiablePotentialFunction: + """ + This class handles the forward and backward functions from the potential function + that can be passed to DifferentiablePotentialFunction, as torch.autograd.Function + only supports static methods, and so it can't be given the potential class directly. + """ + + def __init__(self, posterior_score_based_potential): + self.posterior_score_based_potential = posterior_score_based_potential + + def __call__(self, input): + prepared_potential = partial( + self.posterior_score_based_potential.__call__, rebuild_flow=False + ) + return DifferentiablePotentialFunction.apply( + input, prepared_potential, self.posterior_score_based_potential.gradient + ) diff --git a/sbi/inference/trainers/npse/npse.py b/sbi/inference/trainers/npse/npse.py index 17a8181e0..c8b647e2e 100644 --- a/sbi/inference/trainers/npse/npse.py +++ b/sbi/inference/trainers/npse/npse.py @@ -179,11 +179,12 @@ def train( training_batch_size: int = 200, learning_rate: float = 5e-4, validation_fraction: float = 0.1, - stop_after_epochs: int = 200, + stop_after_epochs: int = 50, max_num_epochs: int = 2**31 - 1, clip_max_norm: Optional[float] = 5.0, calibration_kernel: Optional[Callable] = None, ema_loss_decay: float = 0.1, + validation_times: Union[Tensor, int] = 20, resume_training: bool = False, force_first_round_loss: bool = False, discard_prior_samples: bool = False, @@ -192,7 +193,11 @@ def train( dataloader_kwargs: Optional[dict] = None, ) -> ConditionalScoreEstimator: r"""Returns a score estimator that approximates the score - $\nabla_\theta \log p(\theta|x)$. + $\nabla_\theta \log p(\theta|x)$. The denoising score matching loss has a high + variance, which makes it more difficult to detect convergence. To reduce this + variance, we evaluate the validation loss at a fixed set of times. We also use + the exponential moving average of the training and validation losses, as opposed + to the other `trainer` classes, which track the loss directly. Args: training_batch_size: Training batch size. @@ -208,6 +213,10 @@ def train( calibration_kernel: A function to calibrate the loss with respect to the simulations `x` (optional). See Lueckmann, Gonçalves et al., NeurIPS 2017. If `None`, no calibration is used. + ema_loss_decay: Loss decay strength for exponential moving average of + training and validation losses. + validation_times: Diffusion times at which to evaluate the validation loss, + to reduce the variance of the validation loss. resume_training: Can be used in case training time is limited, e.g. on a cluster. If `True`, the split between train and validation set, the optimizer, the number of epochs, and the best validation log-prob will @@ -294,6 +303,14 @@ def default_calibration_kernel(x): # Move entire net to device for training. 
self._neural_net.to(self._device) + if isinstance(validation_times, int): + validation_times = torch.linspace( + self._neural_net.t_min, self._neural_net.t_max, validation_times + ) + assert isinstance( + validation_times, Tensor + ) # let pyright know validation_times is a Tensor. + if not resume_training: self.optimizer = Adam(list(self._neural_net.parameters()), lr=learning_rate) @@ -316,11 +333,11 @@ def default_calibration_kernel(x): ) train_losses = self._loss( - theta_batch, - x_batch, - masks_batch, - proposal, - calibration_kernel, + theta=theta_batch, + x=x_batch, + masks=masks_batch, + proposal=proposal, + calibration_kernel=calibration_kernel, force_first_round_loss=force_first_round_loss, ) @@ -363,24 +380,43 @@ def default_calibration_kernel(x): batch[1].to(self._device), batch[2].to(self._device), ) + + # For validation loss, we evaluate at a fixed set of times to reduce + # the variance in the validation loss, for improved convergence + # checks. We evaluate the entire validation batch at all times, so + # we repeat the batches here to match. + val_batch_size = theta_batch.shape[0] + times_batch = validation_times.shape[0] + theta_batch = theta_batch.repeat( + times_batch, *([1] * (theta_batch.ndim - 1)) + ) + x_batch = x_batch.repeat(times_batch, *([1] * (x_batch.ndim - 1))) + masks_batch = masks_batch.repeat( + times_batch, *([1] * (masks_batch.ndim - 1)) + ) + + validation_times_rep = validation_times.repeat_interleave( + val_batch_size, dim=0 + ) + # Take negative loss here to get validation log_prob. val_losses = self._loss( - theta_batch, - x_batch, - masks_batch, - proposal, - calibration_kernel, + theta=theta_batch, + x=x_batch, + masks=masks_batch, + proposal=proposal, + calibration_kernel=calibration_kernel, + times=validation_times_rep, force_first_round_loss=force_first_round_loss, ) + val_loss_sum += val_losses.sum().item() # Take mean over all validation samples. val_loss = val_loss_sum / ( - len(val_loader) * val_loader.batch_size # type: ignore + len(val_loader) * val_loader.batch_size * times_batch # type: ignore ) - # NOTE: Due to the inherently noisy nature we do instead log a exponential - # moving average of the validation loss. if len(self._summary["validation_loss"]) == 0: val_loss_ema = val_loss else: @@ -489,6 +525,7 @@ def _loss( masks: Tensor, proposal: Optional[Any], calibration_kernel: Callable, + times: Optional[Tensor] = None, force_first_round_loss: bool = False, ) -> Tensor: """Return loss from score estimator. Currently only single-round NPSE @@ -505,7 +542,7 @@ def _loss( """ if self._round == 0 or force_first_round_loss: # First round loss. - loss = self._neural_net.loss(theta, x) + loss = self._neural_net.loss(theta, x, times) else: raise NotImplementedError( "Multi-round NPSE with arbitrary proposals is not implemented" @@ -513,38 +550,3 @@ def _loss( assert_all_finite(loss, "NPSE loss") return calibration_kernel(x) * loss - - def _converged(self, epoch: int, stop_after_epochs: int) -> bool: - """Check if training has converged. - - Unlike the `._converged` method in base.py, this method does not reset to the - best model. We noticed that this improves performance. Deleting this method - will make C2ST tests fail. This is because the loss is very stochastic, so - resetting might reset to an underfitted model. Ideally, we would write a - custom `._converged()` method which checks whether the loss is still going - down **for all t**. - - Args: - epoch: Current epoch. 
- stop_after_epochs: Number of epochs to wait for improvement on the - validation set before terminating training. - - Returns: - Whether training has converged. - """ - converged = False - - # No checkpointing, just check if the validation loss has improved. - - # (Re)-start the epoch count with the first epoch or any improvement. - if epoch == 0 or self._val_loss < self._best_val_loss: - self._best_val_loss = self._val_loss - self._epochs_since_last_improvement = 0 - else: - self._epochs_since_last_improvement += 1 - - # If no validation improvement over many epochs, stop training. - if self._epochs_since_last_improvement > stop_after_epochs - 1: - converged = True - - return converged diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 26ac254a3..ae25a69c6 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -1,11 +1,10 @@ import logging import warnings -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple import torch import torch.distributions.transforms as torch_tf -from torch import Tensor, as_tensor, nn -from torch.distributions import Distribution +from torch import Tensor, as_tensor from tqdm.auto import tqdm from sbi.utils.sbiutils import gradient_ascent @@ -188,9 +187,10 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor: @torch.no_grad() def accept_reject_sample( - proposal: Union[nn.Module, Distribution], + proposal: Callable, accept_reject_fn: Callable, num_samples: int, + num_xos: int = 1, show_progress_bars: bool = False, warn_acceptance: float = 0.01, sample_for_correction_factor: bool = False, @@ -214,11 +214,15 @@ def accept_reject_sample( density during evaluation of the posterior. Args: - posterior_nn: Neural net representing the posterior. - accept_reject_fn: Function that evaluatuates which samples are accepted or + proposal: A callable that takes `sample_shape` as arguments (and kwargs as + needed). Returns samples from the proposal distribution with shape + (*sample_shape, event_dim). + accept_reject_fn: Function that evaluates which samples are accepted or rejected. Must take a batch of parameters and return a boolean tensor which indicates which parameters get accepted. num_samples: Desired number of samples. + num_xos: Number of conditions for batched_sampling (currently only accepting + one batch dimension for the condition). show_progress_bars: Whether to show a progressbar during sampling. warn_acceptance: A minimum acceptance rate under which to warn about slowness. sample_for_correction_factor: True if this function was called by @@ -264,8 +268,6 @@ def accept_reject_sample( # But this would require giving the method the condition_shape explicitly... if "condition" in proposal_sampling_kwargs: num_xos = proposal_sampling_kwargs["condition"].shape[0] - else: - num_xos = 1 accepted = [[] for _ in range(num_xos)] acceptance_rate = torch.full((num_xos,), float("Nan")) @@ -278,7 +280,7 @@ def accept_reject_sample( num_samples_possible = 0 while num_remaining > 0: # Sample and reject. - candidates = proposal.sample( + candidates = proposal( (sampling_batch_size,), # type: ignore **proposal_sampling_kwargs, ) diff --git a/sbi/samplers/score/diffuser.py b/sbi/samplers/score/diffuser.py index 1cf7e4cd2..aec4feac8 100644 --- a/sbi/samplers/score/diffuser.py +++ b/sbi/samplers/score/diffuser.py @@ -99,11 +99,8 @@ def initialize(self, num_samples: int) -> Tensor: # batched sampling setting with a flag. 
# TODO: this fixes the iid setting shape problems, but iid inference via # iid_bridge is not accurate. - # num_batch = self.batch_shape.numel() - # init_shape = (num_batch, num_samples) + self.input_shape - init_shape = ( - num_samples, - ) + self.input_shape # just use num_samples, not num_batch + num_batch = self.batch_shape.numel() + init_shape = (num_samples, num_batch) + self.input_shape # NOTE: for the IID setting we might need to scale the noise with iid batch # size, as in equation (7) in the paper. eps = torch.randn(init_shape, device=self.device) diff --git a/sbi/utils/restriction_estimator.py b/sbi/utils/restriction_estimator.py index dd8502cc0..a546d9c8e 100644 --- a/sbi/utils/restriction_estimator.py +++ b/sbi/utils/restriction_estimator.py @@ -685,7 +685,7 @@ def sample( if sample_with == "rejection": samples, acceptance_rate = rejection.accept_reject_sample( - proposal=self._prior, + proposal=self._prior.sample, accept_reject_fn=self._accept_reject_fn, num_samples=num_samples, show_progress_bars=show_progress_bars, diff --git a/tests/linearGaussian_npse_test.py b/tests/linearGaussian_npse_test.py index c75da767d..9bdba1c52 100644 --- a/tests/linearGaussian_npse_test.py +++ b/tests/linearGaussian_npse_test.py @@ -159,7 +159,7 @@ def simulator(theta): @pytest.mark.xfail( reason="iid_bridge not working.", - raises=NotImplementedError, + raises=AssertionError, strict=True, match="Score accumulation*", ) @@ -203,10 +203,6 @@ def test_npse_iid_inference(num_trials): @pytest.mark.slow -@pytest.mark.xfail( - raises=NotImplementedError, - reason="MAP optimization via score not working accurately.", -) def test_npse_map(): num_dim = 2 x_o = zeros(num_dim) @@ -234,4 +230,4 @@ def test_npse_map(): map_ = posterior.map(show_progress_bars=True) - assert torch.allclose(map_, gt_posterior.mean, atol=0.2), "MAP is not close to GT." + assert torch.allclose(map_, gt_posterior.mean, atol=0.4), "MAP is not close to GT." diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py index ca7d0f103..2532a7760 100644 --- a/tests/posterior_nn_test.py +++ b/tests/posterior_nn_test.py @@ -12,6 +12,7 @@ NLE_A, NPE_A, NPE_C, + NPSE, NRE_A, NRE_B, NRE_C, @@ -238,6 +239,76 @@ def test_batched_mcmc_sample_log_prob_with_different_x( ) +@pytest.mark.slow +@pytest.mark.parametrize("npse_method", [NPSE]) +@pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2)) +@pytest.mark.parametrize("sampling_method", ["sde", "ode"]) +@pytest.mark.parametrize( + "sample_shape", + ( + (5,), # less than num_chains + (4, 2), # 2D batch + ), +) +def test_batched_score_sample_with_different_x( + npse_method: type, + x_o_batch_dim: bool, + sampling_method: str, + sample_shape: torch.Size, +): + num_dim = 2 + num_simulations = 100 + + prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim)) + simulator = diagonal_linear_gaussian + + inference = npse_method(prior=prior) + theta = prior.sample((num_simulations,)) + x = simulator(theta) + inference.append_simulations(theta, x).train(max_num_epochs=2) + + x_o = ones(num_dim) if x_o_batch_dim == 0 else ones(x_o_batch_dim, num_dim) + + posterior = inference.build_posterior(sample_with=sampling_method) + + samples = posterior.sample_batched( + sample_shape, + x_o, + ) + + assert ( + samples.shape == (*sample_shape, x_o_batch_dim, num_dim) + if x_o_batch_dim > 0 + else (*sample_shape, num_dim) + ), "Sample shape wrong" + + # test only for 1 sample_shape case to avoid repeating this test. 
+ if x_o_batch_dim > 1 and sample_shape == (5,): + assert samples.shape[1] == x_o_batch_dim, "Batch dimension wrong" + inference = npse_method(prior=prior) + _ = inference.append_simulations(theta, x).train() + posterior = inference.build_posterior(sample_with=sampling_method) + + x_o = torch.stack([0.5 * ones(num_dim), -0.5 * ones(num_dim)], dim=0) + # test with multiple chains to test whether correct chains are + # concatenated. + sample_shape = (1000,) # use enough samples for accuracy comparison + samples = posterior.sample_batched(sample_shape, x_o) + + samples_separate1 = posterior.sample(sample_shape, x_o[0]) + samples_separate2 = posterior.sample(sample_shape, x_o[1]) + + # Check if means are approx. same + samples_m = torch.mean(samples, dim=0, dtype=torch.float32) + samples_separate1_m = torch.mean(samples_separate1, dim=0, dtype=torch.float32) + samples_separate2_m = torch.mean(samples_separate2, dim=0, dtype=torch.float32) + samples_sep_m = torch.stack([samples_separate1_m, samples_separate2_m], dim=0) + + assert torch.allclose(samples_m, samples_sep_m, atol=0.2, rtol=0.2), ( + "Batched sampling is not consistent with separate sampling." + ) + + @pytest.mark.slow @pytest.mark.parametrize("density_estimator", ["mdn", "maf", "zuko_nsf"]) def test_batched_sampling_and_logprob_accuracy(density_estimator: str):
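Minimal usage sketch of the features this patch enables, modeled on the tests above: NPSE training with the new `validation_times` option, batched sampling via `sample_batched`, and MAP estimation via the differentiable potential. The toy simulator, sample counts, and hyperparameters are illustrative assumptions, not part of the patch.

# Sketch only; assumes the usual sbi imports and a toy linear-Gaussian task.
import torch
from torch.distributions import MultivariateNormal

from sbi.inference import NPSE

num_dim = 2
prior = MultivariateNormal(torch.zeros(num_dim), torch.eye(num_dim))
theta = prior.sample((2000,))
x = theta + 0.1 * torch.randn_like(theta)  # toy simulator (assumption)

# Train a score estimator; `validation_times` (added in this patch) evaluates
# the validation loss on a fixed grid of diffusion times to reduce its variance.
inference = NPSE(prior=prior)
inference.append_simulations(theta, x).train(validation_times=20, max_num_epochs=100)

# Batched sampling across a batch of observations via the new `sample_batched`.
posterior = inference.build_posterior(sample_with="sde")
x_o = torch.stack([0.5 * torch.ones(num_dim), -0.5 * torch.ones(num_dim)])
samples = posterior.sample_batched((100,), x_o)  # shape: (100, 2, num_dim)

# MAP estimation via gradient ascent on the differentiable potential.
map_estimate = posterior.set_default_x(x_o[0]).map(num_iter=200)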