Skip to content

Commit

Permalink
raise ValueError on batch observation
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Sep 4, 2024
1 parent a2895b8 commit 6406139
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions sbi/inference/posteriors/direct_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,19 @@ def sample(
`sbi` v0.17.2 or older. If it is set, we instantly raise an error.
show_progress_bars: Whether to show sampling progress monitor.
"""

num_samples = torch.Size(sample_shape).numel()
x = self._x_else_default_x(x)
x = reshape_to_batch_event(
x, event_shape=self.posterior_estimator.condition_shape
)
if x.shape[0] > 1:
raise ValueError(
".sample() supports only `batchsize == 1`. If you intend "
"to sample multiple observations, use `.sample_batched()`. "
"If you intend to sample i.i.d. observations, set up the "
"posterior density estimator with an appropriate permutation "
"invariant embedding net."
)

max_sampling_batch_size = (
self.max_sampling_batch_size
Expand All @@ -132,7 +139,7 @@ def sample(
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs={"condition": x},
alternative_method="build_posterior(..., sample_with='mcmc')",
)[0]
)[0] # [0] to return only samples, not acceptance probabilities.

return samples[:, 0] # Remove batch dimension.

Expand Down Expand Up @@ -221,9 +228,14 @@ def log_prob(
x_density_estimator = reshape_to_batch_event(
x, event_shape=self.posterior_estimator.condition_shape
)
assert (
x_density_estimator.shape[0] == 1
), ".log_prob() supports only `batchsize == 1`."
if x_density_estimator.shape[0] > 1:
raise ValueError(
".log_prob() supports only `batchsize == 1`. If you intend "
"to evaluate given multiple observations, use `.log_prob_batched()`. "
"If you intend to evaluate given i.i.d. observations, set up the "
"posterior density estimator with an appropriate permutation "
"invariant embedding net."
)

self.posterior_estimator.eval()

Expand Down

0 comments on commit 6406139

Please sign in to comment.