Skip to content

Commit

Permalink
bugfix for conditional sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Jul 14, 2022
1 parent 7ba1b9d commit d72fc6d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
5 changes: 1 addition & 4 deletions sbi/analysis/conditional_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,8 @@ def conditional_potential(

condition = atleast_2d_float32_tensor(condition)

# Transform the `condition` to unconstrained space.
transformed_condition = theta_transform(condition)

conditioned_potential_fn = ConditionedPotential(
potential_fn, transformed_condition, dims_to_sample # type: ignore
potential_fn, condition, dims_to_sample # type: ignore
)

restricted_prior = RestrictedPriorForConditional(prior, dims_to_sample)
Expand Down
8 changes: 6 additions & 2 deletions tests/linearGaussian_snpe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from sbi import analysis as analysis
from sbi import utils as utils
from sbi.analysis import ConditionedMDN, conditonal_potential
from sbi.analysis import ConditionedMDN, conditional_potential
from sbi.inference import (
SNPE_A,
SNPE_B,
Expand Down Expand Up @@ -468,7 +468,11 @@ def simulator(theta):
potential_fn, theta_transform = posterior_estimator_based_potential(
posterior_estimator, prior=prior, x_o=x_o
)
(conditioned_potential_fn, restricted_tf, restricted_prior,) = conditonal_potential(
(
conditioned_potential_fn,
restricted_tf,
restricted_prior,
) = conditional_potential(
potential_fn=potential_fn,
theta_transform=theta_transform,
prior=prior,
Expand Down

0 comments on commit d72fc6d

Please sign in to comment.