From 8c4d3e95bf47ac9bc822e7bd97cf4ca93791e56b Mon Sep 17 00:00:00 2001 From: janfb Date: Wed, 3 Aug 2022 15:56:14 +0200 Subject: [PATCH] add option to pass theta transform to DirectPosterior. add test use transform for DirectPosterior as default. --- sbi/inference/posteriors/base_posterior.py | 1 - sbi/inference/posteriors/direct_posterior.py | 11 +++++++++-- sbi/inference/potentials/posterior_based_potential.py | 11 ++++++++++- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index 302764a51..bae048ca8 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -3,7 +3,6 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Dict, Optional, Union -from warnings import warn import torch import torch.distributions.transforms as torch_tf diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index bfeeea526..7397b44b6 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -12,7 +12,7 @@ posterior_estimator_based_potential, ) from sbi.samplers.rejection.rejection import rejection_sample_posterior_within_prior -from sbi.types import Shape +from sbi.types import Shape, TorchTransform from sbi.utils import check_prior, match_theta_and_x_batch_shapes, within_support from sbi.utils.torchutils import ensure_theta_batched @@ -35,6 +35,7 @@ def __init__( self, posterior_estimator: flows.Flow, prior: Distribution, + theta_transform: Optional[TorchTransform] = None, max_sampling_batch_size: int = 10_000, device: Optional[str] = None, x_shape: Optional[torch.Size] = None, @@ -43,6 +44,12 @@ def __init__( Args: prior: Prior distribution with `.log_prob()` and `.sample()`. posterior_estimator: The trained neural posterior. + theta_transform: Custom transform to perform MAP optimization in + unconstrained space. 
If None (default), a suitable transform is + built from the prior support. In order to not use a transform at all, + pass an identity transform, e.g., `theta_transform=torch.distributions. + transforms. + identity_transform`. max_sampling_batch_size: Batchsize of samples being drawn from the proposal at every iteration. device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None, @@ -55,7 +62,7 @@ def __init__( # obtaining the MAP. check_prior(prior) potential_fn, theta_transform = posterior_estimator_based_potential( - posterior_estimator, prior, None + posterior_estimator, prior, x_o=None, theta_transform=theta_transform ) super().__init__( diff --git a/sbi/inference/potentials/posterior_based_potential.py b/sbi/inference/potentials/posterior_based_potential.py index 9104ff3d6..5361bab22 100644 --- a/sbi/inference/potentials/posterior_based_potential.py +++ b/sbi/inference/potentials/posterior_based_potential.py @@ -4,6 +4,7 @@ from typing import Callable, Optional, Tuple import torch +import torch.distributions.transforms as torch_tf from pyknos.nflows import flows from torch import Tensor, nn from torch.distributions import Distribution @@ -19,6 +20,7 @@ def posterior_estimator_based_potential( posterior_estimator: nn.Module, prior: Distribution, x_o: Optional[Tensor], + theta_transform: Optional[TorchTransform] = None, ) -> Tuple[Callable, TorchTransform]: r"""Returns the potential for posterior-based methods. @@ -32,6 +34,11 @@ def posterior_estimator_based_potential( posterior_estimator: The neural network modelling the posterior. prior: The prior distribution. x_o: The observed data at which to evaluate the posterior. + theta_transform: Transform to map the parameters to an + unconstrained space. If None (default), a suitable transform is + built from the prior support. In order to not use a transform at all, + pass an identity transform, e.g., `theta_transform=torch.distributions. + transforms.identity_transform`. 
Returns: The potential function and a transformation that maps @@ -43,7 +50,9 @@ def posterior_estimator_based_potential( potential_fn = PosteriorBasedPotential( posterior_estimator, prior, x_o, device=device ) - theta_transform = mcmc_transform(prior, device=device) + + if theta_transform is None: + theta_transform = mcmc_transform(prior, device=device) return potential_fn, theta_transform