diff --git a/sbi/analysis/conditional_density.py b/sbi/analysis/conditional_density.py index fcf97a17f..293271df5 100644 --- a/sbi/analysis/conditional_density.py +++ b/sbi/analysis/conditional_density.py @@ -293,11 +293,8 @@ def conditional_potential( condition = atleast_2d_float32_tensor(condition) - # Transform the `condition` to unconstrained space. - transformed_condition = theta_transform(condition) - conditioned_potential_fn = ConditionedPotential( - potential_fn, transformed_condition, dims_to_sample # type: ignore + potential_fn, condition, dims_to_sample # type: ignore ) restricted_prior = RestrictedPriorForConditional(prior, dims_to_sample) diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py index a619de8b1..f0623e269 100644 --- a/tests/linearGaussian_snpe_test.py +++ b/tests/linearGaussian_snpe_test.py @@ -12,7 +12,7 @@ from sbi import analysis as analysis from sbi import utils as utils -from sbi.analysis import ConditionedMDN, conditonal_potential +from sbi.analysis import ConditionedMDN, conditional_potential from sbi.inference import ( SNPE_A, SNPE_B, @@ -468,7 +468,11 @@ def simulator(theta): potential_fn, theta_transform = posterior_estimator_based_potential( posterior_estimator, prior=prior, x_o=x_o ) - (conditioned_potential_fn, restricted_tf, restricted_prior,) = conditonal_potential( + ( + conditioned_potential_fn, + restricted_tf, + restricted_prior, + ) = conditional_potential( potential_fn=potential_fn, theta_transform=theta_transform, prior=prior,