diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py
index 76b9fdf48..e6b684fab 100644
--- a/sbi/inference/posteriors/direct_posterior.py
+++ b/sbi/inference/posteriors/direct_posterior.py
@@ -104,12 +104,19 @@ def sample(
                 `sbi` v0.17.2 or older. If it is set, we instantly raise an error.
             show_progress_bars: Whether to show sampling progress monitor.
         """
-        num_samples = torch.Size(sample_shape).numel()
 
         x = self._x_else_default_x(x)
         x = reshape_to_batch_event(
             x, event_shape=self.posterior_estimator.condition_shape
         )
+        if x.shape[0] > 1:
+            raise ValueError(
+                ".sample() supports only `batchsize == 1`. If you intend "
+                "to sample multiple observations, use `.sample_batched()`. "
+                "If you intend to sample i.i.d. observations, set up the "
+                "posterior density estimator with an appropriate permutation "
+                "invariant embedding net."
+            )
 
         max_sampling_batch_size = (
             self.max_sampling_batch_size
@@ -132,7 +139,7 @@ def sample(
             max_sampling_batch_size=max_sampling_batch_size,
             proposal_sampling_kwargs={"condition": x},
             alternative_method="build_posterior(..., sample_with='mcmc')",
-        )[0]
+        )[0]  # [0] to return only samples, not acceptance probabilities.
 
         return samples[:, 0]  # Remove batch dimension.
 
@@ -221,9 +228,14 @@ def log_prob(
         x_density_estimator = reshape_to_batch_event(
             x, event_shape=self.posterior_estimator.condition_shape
         )
-        assert (
-            x_density_estimator.shape[0] == 1
-        ), ".log_prob() supports only `batchsize == 1`."
+        if x_density_estimator.shape[0] > 1:
+            raise ValueError(
+                ".log_prob() supports only `batchsize == 1`. If you intend "
+                "to evaluate given multiple observations, use `.log_prob_batched()`. "
+                "If you intend to evaluate given i.i.d. observations, set up the "
+                "posterior density estimator with an appropriate permutation "
+                "invariant embedding net."
+            )
 
         self.posterior_estimator.eval()
 
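For context (not part of the diff), below is a minimal usage sketch of how the new check behaves from the user's side: a batched `x` passed to `.sample()` now raises a `ValueError` instead of tripping an assert, and `.sample_batched()` is the intended path for multiple independent observations. The toy prior, simulator, and variable names (`x_single`, `x_batch`) are illustrative assumptions, not part of this PR, and the exact output shape of `.sample_batched()` may differ between `sbi` versions.

```python
import torch
from sbi.inference import SNPE
from sbi.utils import BoxUniform

# Toy setup (illustrative only): 2-d parameters, identity simulator with noise.
prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
theta = prior.sample((1000,))
x = theta + 0.1 * torch.randn_like(theta)

inference = SNPE(prior=prior)
inference.append_simulations(theta, x).train(max_num_epochs=5)
posterior = inference.build_posterior()  # DirectPosterior

x_single = torch.randn(1, 2)  # batch size 1: supported by .sample()
x_batch = torch.randn(3, 2)   # batch size 3: now raises ValueError

samples = posterior.sample((100,), x=x_single)  # shape (100, 2)

try:
    posterior.sample((100,), x=x_batch)
except ValueError as err:
    print(err)  # message points to .sample_batched() or an i.i.d. embedding net

# Batched sampling for several independent observations.
batched_samples = posterior.sample_batched((100,), x=x_batch)
print(batched_samples.shape)  # roughly (100, 3, 2), depending on the sbi version
```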