Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make transform optional for DirectPosterior. #714

Merged
merged 1 commit into from
Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion sbi/inference/posteriors/base_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Union
from warnings import warn

import torch
import torch.distributions.transforms as torch_tf
Expand Down
11 changes: 9 additions & 2 deletions sbi/inference/posteriors/direct_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
posterior_estimator_based_potential,
)
from sbi.samplers.rejection.rejection import rejection_sample_posterior_within_prior
from sbi.types import Shape
from sbi.types import Shape, TorchTransform
from sbi.utils import check_prior, match_theta_and_x_batch_shapes, within_support
from sbi.utils.torchutils import ensure_theta_batched

Expand All @@ -35,6 +35,7 @@ def __init__(
self,
posterior_estimator: flows.Flow,
prior: Distribution,
theta_transform: Optional[TorchTransform] = None,
max_sampling_batch_size: int = 10_000,
device: Optional[str] = None,
x_shape: Optional[torch.Size] = None,
Expand All @@ -43,6 +44,12 @@ def __init__(
Args:
prior: Prior distribution with `.log_prob()` and `.sample()`.
posterior_estimator: The trained neural posterior.
theta_transform: Custom transform to perform MAP optimization in
unconstrained space. If None (default), a suitable transform is
built from the prior support. In order to not use a transform at all,
pass an identity transform, e.g., `theta_transform=torch.distributions.
transforms.identity_transform()`.
max_sampling_batch_size: Batchsize of samples being drawn from
the proposal at every iteration.
device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
Expand All @@ -55,7 +62,7 @@ def __init__(
# obtaining the MAP.
check_prior(prior)
potential_fn, theta_transform = posterior_estimator_based_potential(
posterior_estimator, prior, None
posterior_estimator, prior, x_o=None, theta_transform=theta_transform
)

super().__init__(
Expand Down
11 changes: 10 additions & 1 deletion sbi/inference/potentials/posterior_based_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Callable, Optional, Tuple

import torch
import torch.distributions.transforms as torch_tf
from pyknos.nflows import flows
from torch import Tensor, nn
from torch.distributions import Distribution
Expand All @@ -19,6 +20,7 @@ def posterior_estimator_based_potential(
posterior_estimator: nn.Module,
prior: Distribution,
x_o: Optional[Tensor],
theta_transform: Optional[TorchTransform] = None,
) -> Tuple[Callable, TorchTransform]:
r"""Returns the potential for posterior-based methods.

Expand All @@ -32,6 +34,11 @@ def posterior_estimator_based_potential(
posterior_estimator: The neural network modelling the posterior.
prior: The prior distribution.
x_o: The observed data at which to evaluate the posterior.
theta_transform: Transform to map the parameters to an
unconstrained space. If None (default), a suitable transform is
built from the prior support. In order to not use a transform at all,
pass an identity transform, e.g., `theta_transform=torch.distributions.
transforms.identity_transform()`.

Returns:
The potential function and a transformation that maps
Expand All @@ -43,7 +50,9 @@ def posterior_estimator_based_potential(
potential_fn = PosteriorBasedPotential(
posterior_estimator, prior, x_o, device=device
)
theta_transform = mcmc_transform(prior, device=device)

if theta_transform is None:
theta_transform = mcmc_transform(prior, device=device)

return potential_fn, theta_transform

Expand Down