
sampling without inferring the prior transform #713

Closed
Patrickens opened this issue Aug 1, 2022 · 5 comments · Fixed by #714

Patrickens commented Aug 1, 2022

Dear SBI team,

I'm absolutely loving the package! It's very well written and documented.

I need to fit a distribution with convex polytope support, meaning that every sample theta must satisfy A @ theta <= b. I wrote the following support class for my prior:

import torch
from torch.distributions.constraints import _Dependent

class _CannonicalPolytopeSupport(_Dependent):
    """Support of a convex polytope {theta : A @ theta <= b}."""

    def __init__(
            self,
            A: torch.Tensor,
            b: torch.Tensor,
            validation_tol: float = 1e-13,
    ):
        self._A = A
        # Relax b slightly so boundary points pass despite floating-point error.
        self._b = b + validation_tol
        # theta is a vector, so the constraint acts on one (the rightmost) event dim.
        super().__init__(is_discrete=False, event_dim=1)

    def check(self, value):
        # value has shape (..., D); A @ theta <= b becomes value @ A.T <= b,
        # which broadcasts to shape (..., num_constraints).
        valid = value @ self._A.T <= self._b
        # A sample is in the support only if it satisfies every inequality.
        return valid.all(dim=-1)
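As a quick standalone sanity check of the membership test (not part of the original issue), the unit square [0, 1]^2 can be written in this canonical form with four inequalities:

```python
import torch

# Unit square [0, 1]^2 as a polytope A @ theta <= b:
A = torch.tensor([[1., 0.], [0., 1.], [-1., 0.], [0., -1.]])
b = torch.tensor([1., 1., 0., 0.])

theta = torch.tensor([[0.5, 0.5],   # inside
                      [1.5, 0.5]])  # outside (first coordinate > 1)

# Batched check: (theta @ A.T) has shape (batch, num_constraints),
# and broadcasting against b keeps everything vectorized.
inside = (theta @ A.T <= b).all(dim=-1)
print(inside)  # tensor([ True, False])
```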

Now I run into the following issue. Once a round of inference is done and I try to build a posterior, sbi assumes that if the prior has a bounded, dependent support, there must be a bijective transform to an unconstrained space:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Input In [32], in <module>
----> 1 posterior = snpec.build_posterior(density_estimator)

File C:\Miniconda3\envs\pta\lib\site-packages\sbi\inference\snpe\snpe_base.py:420, in PosteriorEstimator.build_posterior(self, density_estimator, prior, sample_with, mcmc_method, vi_method, mcmc_parameters, vi_parameters, rejection_sampling_parameters)
    417     # Otherwise, infer it from the device of the net parameters.
    418     device = next(density_estimator.parameters()).device.type
--> 420 potential_fn, theta_transform = posterior_estimator_based_potential(
    421     posterior_estimator=posterior_estimator, prior=prior, x_o=None
    422 )
    424 if sample_with == "rejection":
    425     if "proposal" in rejection_sampling_parameters.keys():

File C:\Miniconda3\envs\pta\lib\site-packages\sbi\inference\potentials\posterior_based_potential.py:46, in posterior_estimator_based_potential(posterior_estimator, prior, x_o)
     41 device = str(next(posterior_estimator.parameters()).device)
     43 potential_fn = PosteriorBasedPotential(
     44     posterior_estimator, prior, x_o, device=device
     45 )
---> 46 theta_transform = mcmc_transform(prior, device=device)
     48 return potential_fn, theta_transform

File C:\Miniconda3\envs\pta\lib\site-packages\sbi\utils\sbiutils.py:615, in mcmc_transform(prior, num_prior_samples_for_zscoring, enable_transform, device, **kwargs)
    613 # Prior with bounded support, e.g., uniform priors.
    614 if has_support and support_is_bounded:
--> 615     transform = biject_to(prior.support)
    616 # For all other cases build affine transform with mean and std.
    617 else:
    618     if hasattr(prior, "mean") and hasattr(prior, "stddev"):

File C:\Miniconda3\envs\pta\lib\site-packages\torch\distributions\constraint_registry.py:142, in ConstraintRegistry.__call__(self, constraint)
    140     factory = self._registry[type(constraint)]
    141 except KeyError:
--> 142     raise NotImplementedError(
    143         f'Cannot transform {type(constraint).__name__} constraints') from None
    144 return factory(constraint)

NotImplementedError: Cannot transform _CannonicalPolytopeSupport constraints

Is there any way around this issue? I've thought about such a transform (I have not implemented it yet, since it's a real headache), but for now I would just like to check whether a sample is within the support without transforming it to unconstrained space.

Thank you for your efforts!


Patrickens commented Aug 1, 2022

Perhaps I could just use rejection sampling via .build_posterior(density_estimator, sample_with='rejection'); then the theta_transform is not used at all, so only the potential_fn in posterior_estimator_based_potential has to be built. Or perhaps the enable_transform=False kwarg could be passed through to mcmc_transform from build_posterior somehow.
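The rejection idea can be sketched in plain torch (a toy illustration, not sbi's actual rejection sampler): draw proposals and keep only those satisfying A @ theta <= b.

```python
import torch

torch.manual_seed(0)

# Triangle {theta : theta_1 >= 0, theta_2 >= 0, theta_1 + theta_2 <= 1}.
A = torch.tensor([[-1., 0.], [0., -1.], [1., 1.]])
b = torch.tensor([0., 0., 1.])

# Propose from the unit box and reject anything outside the polytope.
proposals = torch.rand(10_000, 2)
accepted = proposals[(proposals @ A.T <= b).all(dim=-1)]

# Every kept sample is in the polytope; the acceptance rate is roughly
# the triangle's fraction of the box (≈ 0.5 here).
assert bool((accepted @ A.T <= b).all())
print(f"acceptance rate: {len(accepted) / 10_000:.2f}")
```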


michaeldeistler commented Aug 1, 2022

Hi there! Thanks for the compliments, we're happy that you're enjoying the package :)

If you want to turn off the transformation, you should probably use the sampler interface.

Could you try this (I did not try it yet):

from sbi.inference import SNPE, DirectPosterior

inference = SNPE()
posterior_estimator = inference.append_simulations(theta, x).train()
posterior = DirectPosterior(posterior_estimator, prior)

@Patrickens
Author

So at the instantiation of DirectPosterior there is also a call to posterior_estimator_based_potential, and thus to mcmc_transform. For now, I inserted an enable_theta_transform=False keyword into the constructor of DirectPosterior, which is passed on further, and this works! Thanks for thinking along.

@michaeldeistler
Contributor

Ah, great that you found this fix! I'll leave this issue open because we should probably make this easier...

@michaeldeistler michaeldeistler changed the title inference for distributions with polytope support sampling without inferring the prior transform Aug 1, 2022

janfb commented Aug 3, 2022

Thanks @Patrickens, interesting problem!

But a plain identity transform should work on your custom support, right?

I added a theta_transform=None kwarg to DirectPosterior, which is passed on to posterior_estimator_based_potential to either build an identity transform or construct one from the support. See #714.

Could you please paste your custom prior class here for testing?
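For reference, torch ships a ready-made identity transform (an empty ComposeTransform), which is the kind of no-op bijection that theta_transform=None would amount to: forward and inverse are both the identity, and the log-Jacobian is zero everywhere. A minimal sketch, independent of sbi:

```python
import torch
from torch.distributions.transforms import identity_transform

theta = torch.randn(3, 2)

# The identity transform maps theta to itself with zero log-Jacobian,
# so "MCMC in unconstrained space" is just MCMC in the original space.
assert torch.equal(identity_transform(theta), theta)
assert torch.equal(identity_transform.inv(theta), theta)
assert bool((identity_transform.log_abs_det_jacobian(theta, theta) == 0).all())
```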
