Integrate new density estimator interface into SBI #979

Merged · 39 commits from SNLE_density_estimators into main · Mar 16, 2024

Changes from 36 commits
5440446
build functions return density estimators
gmoss13 Mar 7, 2024
fdd6c05
snle_a uses DensityEstimator instead of nn.Module
gmoss13 Mar 7, 2024
d8c126c
build_nn context shape should be of unembedded context
gmoss13 Mar 7, 2024
537737c
zukoFlow density estimator and some tests
gmoss13 Mar 7, 2024
fc00b5b
formatting
gmoss13 Mar 7, 2024
c2c1b47
revert estimator of base NeuralInference back to self._neural_net
gmoss13 Mar 8, 2024
d240d40
NPE integration first approach
manuelgloeckler Mar 11, 2024
066fc6e
Merge branch 'SNLE_density_estimators' of https://github.com/sbi-dev/…
manuelgloeckler Mar 11, 2024
d57030a
Fix rejection bug, ignore pyright for mnle for the moment
manuelgloeckler Mar 11, 2024
0f8541d
Fix mnle bug context -> condition
manuelgloeckler Mar 11, 2024
5934062
Temp fix of SNPE_A and C ... Need to normalize interface MOG...
manuelgloeckler Mar 11, 2024
73fa99d
Fix multi round bugs
manuelgloeckler Mar 11, 2024
b1cb88e
Fixing hopefully last tests that fail
manuelgloeckler Mar 11, 2024
a5e6fc3
If else MAF possibly unbound fix
manuelgloeckler Mar 11, 2024
1ca7ab2
Reverting change to notebook :/
manuelgloeckler Mar 11, 2024
4fbb3ff
add embedding net to ZukoFlow
gmoss13 Mar 11, 2024
5f25941
add zuko_maf_builder and test for SNLE
gmoss13 Mar 11, 2024
eb2c301
formatting
gmoss13 Mar 11, 2024
f67f0b8
Merge branch 'SNLE_density_estimators' of https://github.com/sbi-dev/…
manuelgloeckler Mar 11, 2024
6c2060c
Using the loss of the density estimator for single round
manuelgloeckler Mar 11, 2024
3cf6acd
Formatting
manuelgloeckler Mar 11, 2024
587321e
expose embedding_net property of density estimators
gmoss13 Mar 12, 2024
e19c96e
add test for ZukoFlow for SNPE
gmoss13 Mar 12, 2024
05b78cf
remove sign snpe bug
manuelgloeckler Mar 11, 2024
2614761
Merge branch 'main' into SNLE_density_estimators
manuelgloeckler Mar 12, 2024
ecefd85
New ruff formatting
manuelgloeckler Mar 12, 2024
a3325b6
ruff-format
manuelgloeckler Mar 12, 2024
63ab9a1
Ruff suggestions
manuelgloeckler Mar 12, 2024
5511704
pyright fix
manuelgloeckler Mar 12, 2024
b3d8ba4
Final formatting
manuelgloeckler Mar 12, 2024
56260e1
add RatioEstimator
gmoss13 Mar 12, 2024
ab7d0f8
Merge branch 'main' into SNLE_density_estimators
manuelgloeckler Mar 13, 2024
787e24b
Recommended changes
manuelgloeckler Mar 13, 2024
1dee434
remaining changes
gmoss13 Mar 13, 2024
c9b1816
remove RatioEstimator from this PR
gmoss13 Mar 13, 2024
763ded1
fix test failure in likelihood_potential
gmoss13 Mar 13, 2024
4cbdaae
device check in likelihood_based_potential
gmoss13 Mar 14, 2024
65af0f4
Recommended changes
manuelgloeckler Mar 14, 2024
b4b00ba
Integrated last recommendations
manuelgloeckler Mar 15, 2024
1 change: 1 addition & 0 deletions pyproject.toml
@@ -42,6 +42,7 @@ dependencies = [
"tensorboard",
"torch>=1.8.0",
"tqdm",
"zuko>=1.0.0",
]

[project.optional-dependencies]
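The new zuko dependency backs the ZukoFlow estimator added in commits 537737c, 4fbb3ff, and 5f25941. As a rough illustration — a hypothetical wrapper, not the PR's actual ZukoFlow implementation — a Zuko flow slots behind the interface sketched above like this:

import torch
import zuko
from torch import Tensor, nn


class ZukoFlowSketch(nn.Module):
    """Hypothetical wrapper of a Zuko flow behind the DensityEstimator-style
    API; names and defaults are illustrative assumptions."""

    def __init__(self, theta_dim: int, x_dim: int,
                 embedding_net: nn.Module = nn.Identity()) -> None:
        super().__init__()
        # Exposed as a property in the real class (commit 587321e).
        self.embedding_net = embedding_net
        # A conditional flow over theta; with an identity embedding the
        # context size is simply x_dim.
        self.net = zuko.flows.MAF(features=theta_dim, context=x_dim)

    def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
        # Calling a Zuko flow on a context returns a distribution object,
        # which is why log_prob and sample are one-liners.
        return self.net(self.embedding_net(condition)).log_prob(input)

    def loss(self, input: Tensor, condition: Tensor) -> Tensor:
        return -self.log_prob(input, condition)

    def sample(self, sample_shape: torch.Size, condition: Tensor) -> Tensor:
        return self.net(self.embedding_net(condition)).sample(sample_shape)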
9 changes: 6 additions & 3 deletions sbi/analysis/conditional_density.py
@@ -7,9 +7,10 @@
import torch
import torch.distributions.transforms as torch_tf
from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn
from torch import Tensor, nn
from torch import Tensor
from torch.distributions import Distribution

from sbi.neural_nets.density_estimators.base import DensityEstimator
from sbi.types import Shape, TorchTransform
from sbi.utils.conditional_density_utils import (
ConditionedPotential,
@@ -185,7 +186,7 @@ def conditional_corrcoeff(
class ConditionedMDN:
def __init__(
self,
net: nn.Module,
gaussianMDN: DensityEstimator,
x_o: Tensor,
condition: Tensor,
dims_to_sample: List[int],
@@ -205,7 +206,9 @@ def __init__(
"""
condition = atleast_2d_float32_tensor(condition)

logits, means, precfs, _ = extract_and_transform_mog(net=net, context=x_o)
logits, means, precfs, _ = extract_and_transform_mog(
net=gaussianMDN.net, context=x_o
)
self.logits, self.means, self.precfs, self.sumlogdiag = condition_mog(
condition, dims_to_sample, logits, means, precfs
)
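With this signature, callers hand ConditionedMDN the DensityEstimator wrapper, and the underlying pyknos MDN is reached through its .net attribute. A hedged usage sketch — the estimator and tensor names are illustrative, not from the PR:

# `mdn_estimator` is assumed to be a DensityEstimator wrapping a pyknos
# MultivariateGaussianMDN; `x_o` and `condition` are user-supplied tensors.
conditioned = ConditionedMDN(
    gaussianMDN=mdn_estimator,  # DensityEstimator, no longer a bare nn.Module
    x_o=x_o,
    condition=condition,        # fixed values for the conditioned dimensions
    dims_to_sample=[0, 1],      # parameter dimensions left free
)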
38 changes: 28 additions & 10 deletions sbi/inference/posteriors/direct_posterior.py
@@ -3,17 +3,17 @@
from typing import Optional, Union

import torch
from pyknos.nflows import flows
from torch import Tensor, log
from torch.distributions import Distribution

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.potentials.posterior_based_potential import (
posterior_estimator_based_potential,
)
from sbi.neural_nets.density_estimators.base import DensityEstimator
from sbi.samplers.rejection.rejection import accept_reject_sample
from sbi.types import Shape
from sbi.utils import check_prior, match_theta_and_x_batch_shapes, within_support
from sbi.utils import check_prior, within_support
from sbi.utils.torchutils import ensure_theta_batched


@@ -33,7 +33,7 @@

def __init__(
self,
posterior_estimator: flows.Flow,
posterior_estimator: DensityEstimator,
prior: Distribution,
max_sampling_batch_size: int = 10_000,
device: Optional[str] = None,
@@ -101,7 +101,18 @@
"""

num_samples = torch.Size(sample_shape).numel()
condition_shape = self.posterior_estimator._condition_shape
x = self._x_else_default_x(x)

try:
x = x.reshape(*condition_shape)
except RuntimeError as err:
raise ValueError(
f"Expected a single `x` which should broadcastable to shape \
{condition_shape}, but got {x.shape}. For batched eval \
see issue #990"
) from err

max_sampling_batch_size = (
self.max_sampling_batch_size
if max_sampling_batch_size is None
@@ -121,9 +132,10 @@
num_samples=num_samples,
show_progress_bars=show_progress_bars,
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs={"context": x},
proposal_sampling_kwargs={"condition": x},
alternative_method="build_posterior(..., sample_with='mcmc')",
)[0]

return samples

def log_prob(
@@ -159,21 +171,27 @@
support of the prior, -∞ (corresponding to 0 probability) outside.
"""
x = self._x_else_default_x(x)
condition_shape = self.posterior_estimator._condition_shape
try:
x = x.reshape(*condition_shape)
except RuntimeError as err:
raise ValueError(
f"Expected a single `x` which should broadcastable to shape \
{condition_shape}, but got {x.shape}. For batched eval \
see issue #990"
) from err

# TODO Train exited here, entered after sampling?
self.posterior_estimator.eval()

theta = ensure_theta_batched(torch.as_tensor(theta))
theta_repeated, x_repeated = match_theta_and_x_batch_shapes(theta, x)

with torch.set_grad_enabled(track_gradients):
# Evaluate on device, move back to cpu for comparison with prior.
unnorm_log_prob = self.posterior_estimator.log_prob(
theta_repeated, context=x_repeated
)
unnorm_log_prob = self.posterior_estimator.log_prob(theta, condition=x)

# Force probability to be zero outside prior support.
in_prior_support = within_support(self.prior, theta_repeated)
in_prior_support = within_support(self.prior, theta)

masked_log_prob = torch.where(
in_prior_support,
@@ -227,7 +245,7 @@
show_progress_bars=show_progress_bars,
sample_for_correction_factor=True,
max_sampling_batch_size=rejection_sampling_batch_size,
proposal_sampling_kwargs={"context": x},
proposal_sampling_kwargs={"condition": x},
)[1]

# Check if the provided x matches the default x (short-circuit on identity).
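The practical upshot of the reshape guards above: DirectPosterior.sample and .log_prob now expect a single observation that broadcasts to the estimator's _condition_shape, with batched evaluation deferred to issue #990. A usage sketch, assuming `posterior` is a trained DirectPosterior and `x_o` one matching observation:

posterior.set_default_x(x_o)           # or pass x=x_o per call
samples = posterior.sample((1_000,))   # x_o is reshaped to _condition_shape
log_probs = posterior.log_prob(samples, x=x_o)

# A batch of observations now raises instead of silently broadcasting:
# posterior.sample((1_000,), x=torch.stack([x_o, x_o]))  # ValueError, see #990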
54 changes: 21 additions & 33 deletions sbi/inference/potentials/likelihood_based_potential.py
@@ -1,22 +1,21 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from typing import Any, Callable, Optional, Tuple
from typing import Callable, Optional, Tuple

import torch
from torch import Tensor, nn
from torch import Tensor
from torch.distributions import Distribution

from sbi.inference.potentials.base_potential import BasePotential
from sbi.neural_nets.density_estimators import DensityEstimator
from sbi.neural_nets.mnle import MixedDensityEstimator
from sbi.types import TorchTransform
from sbi.utils import mcmc_transform
from sbi.utils.sbiutils import match_theta_and_x_batch_shapes
from sbi.utils.torchutils import atleast_2d


def likelihood_estimator_based_potential(
likelihood_estimator: nn.Module,
likelihood_estimator: DensityEstimator,
prior: Distribution,
x_o: Optional[Tensor],
enable_transform: bool = True,
@@ -27,7 +26,7 @@ def likelihood_estimator_based_potential(
unconstrained space.

Args:
likelihood_estimator: The neural network modelling the likelihood.
likelihood_estimator: The density estimator modelling the likelihood.
prior: The prior distribution.
x_o: The observed data at which to evaluate the likelihood.
enable_transform: Whether to transform parameters to unconstrained space.
@@ -55,15 +54,15 @@ class LikelihoodBasedPotential(BasePotential):

def __init__(
self,
likelihood_estimator: nn.Module,
likelihood_estimator: DensityEstimator,
prior: Distribution,
x_o: Optional[Tensor],
device: str = "cpu",
):
r"""Returns the potential function for likelihood-based methods.

Args:
likelihood_estimator: The neural network modelling the likelihood.
likelihood_estimator: The density estimator modelling the likelihood.
prior: The prior distribution.
x_o: The observed data at which to evaluate the likelihood.
device: The device to which parameters and data are moved before evaluating
@@ -92,15 +91,15 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
log_likelihood_trial_sum = _log_likelihoods_over_trials(
x=self.x_o,
theta=theta.to(self.device),
net=self.likelihood_estimator,
estimator=self.likelihood_estimator,
track_gradients=track_gradients,
)

return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore


def _log_likelihoods_over_trials(
x: Tensor, theta: Tensor, net: Any, track_gradients: bool = False
x: Tensor, theta: Tensor, estimator: DensityEstimator, track_gradients: bool = False
) -> Tensor:
r"""Return log likelihoods summed over iid trials of `x`.

@@ -112,36 +111,23 @@ def _log_likelihoods_over_trials(

Args:
x: batch of iid data.
theta: batch of parameters
net: neural net with .log_prob()
theta: batch of parameters.
estimator: DensityEstimator.
track_gradients: Whether to track gradients.

Returns:
log_likelihood_trial_sum: log likelihood for each parameter, summed over all
batch entries (iid trials) in `x`.
"""

# Repeat `x` in case of evaluation on multiple `theta`. This is needed below in
# when calling nflows in order to have matching shapes of theta and context x
# at neural network evaluation time.
theta_repeated, x_repeated = match_theta_and_x_batch_shapes(
theta=atleast_2d(theta), x=atleast_2d(x)
)
assert (
x_repeated.shape[0] == theta_repeated.shape[0]
), "x and theta must match in batch shape."
assert (
next(net.parameters()).device == x.device and x.device == theta.device
), f"""device mismatch: net, x, theta: {next(net.parameters()).device}, {x.device},
{theta.device}."""
# unsqueeze to ensure that the x-batch dimension is the first dimension for the
# broadcasting of the density estimator.
x = torch.as_tensor(x).reshape(-1, x.shape[-1]).unsqueeze(1)

# Calculate likelihood in one batch.
with torch.set_grad_enabled(track_gradients):
log_likelihood_trial_batch = net.log_prob(x_repeated, theta_repeated)
# Reshape to (x-trials x parameters), sum over trial-log likelihoods.
log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
x.shape[0], -1
).sum(0)
log_likelihood_trial_batch = estimator.log_prob(x, condition=theta)
# Reshape to (-1, theta_batch_size), sum over trial-log likelihoods.
log_likelihood_trial_sum = log_likelihood_trial_batch.sum(0)

return log_likelihood_trial_sum

@@ -179,12 +165,14 @@ def mixed_likelihood_estimator_based_potential(
class MixedLikelihoodBasedPotential(LikelihoodBasedPotential):
def __init__(
self,
likelihood_estimator: MixedDensityEstimator,
likelihood_estimator: MixedDensityEstimator, # type: ignore TODO fix pyright
prior: Distribution,
x_o: Optional[Tensor],
device: str = "cpu",
):
super().__init__(likelihood_estimator, prior, x_o, device)
# TODO Fix pyright issue by making MixedDensityEstimator a subclass
# of DensityEstimator
super().__init__(likelihood_estimator, prior, x_o, device) # type: ignore

def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
# Calculate likelihood in one batch.
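The rewritten _log_likelihoods_over_trials replaces the old repeat-and-match bookkeeping with plain broadcasting: x is reshaped to (num_trials, 1, x_dim) so the estimator's log_prob broadcasts each trial against the whole theta batch, and summing over dim 0 yields one summed log likelihood per parameter set. A minimal, standalone check of that shape logic, with a toy log_prob standing in for a real estimator:

import torch

num_trials, theta_batch, x_dim, theta_dim = 10, 5, 3, 2
x = torch.randn(num_trials, x_dim)
theta = torch.randn(theta_batch, theta_dim)

# Toy stand-in for estimator.log_prob(input, condition): broadcasts the
# trial axis of `input` against the batch axis of `condition`.
def toy_log_prob(input, condition):
    return -(input.sum(-1) - condition.sum(-1)) ** 2

x_reshaped = x.reshape(-1, x.shape[-1]).unsqueeze(1)  # (num_trials, 1, x_dim)
log_l = toy_log_prob(x_reshaped, theta)               # (num_trials, theta_batch)
summed = log_l.sum(0)                                 # (theta_batch,)
assert summed.shape == (theta_batch,)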
23 changes: 13 additions & 10 deletions sbi/inference/potentials/posterior_based_potential.py
@@ -4,19 +4,19 @@
from typing import Callable, Optional, Tuple

import torch
from pyknos.nflows import flows
from torch import Tensor, nn
from torch import Tensor
from torch.distributions import Distribution

from sbi.inference.potentials.base_potential import BasePotential
from sbi.neural_nets.density_estimators import DensityEstimator
from sbi.types import TorchTransform
from sbi.utils import mcmc_transform
from sbi.utils.sbiutils import match_theta_and_x_batch_shapes, within_support
from sbi.utils.sbiutils import within_support
from sbi.utils.torchutils import ensure_theta_batched


def posterior_estimator_based_potential(
posterior_estimator: nn.Module,
posterior_estimator: DensityEstimator,
prior: Distribution,
x_o: Optional[Tensor],
enable_transform: bool = True,
@@ -59,7 +59,7 @@

def __init__(
self,
posterior_estimator: flows.Flow,
posterior_estimator: DensityEstimator,
prior: Distribution,
x_o: Optional[Tensor],
device: str = "cpu",
@@ -92,14 +92,17 @@
The potential.
"""

if self._x_o is None:
raise ValueError(
"No observed data x_o is available. Please reinitialize \
the potential or manually set self._x_o."
)

theta = ensure_theta_batched(torch.as_tensor(theta))
theta, x_repeated = match_theta_and_x_batch_shapes(theta, self.x_o)
theta, x_repeated = theta.to(self.device), x_repeated.to(self.device)
theta, x = theta.to(self.device), self.x_o.to(self.device)

with torch.set_grad_enabled(track_gradients):
posterior_log_prob = self.posterior_estimator.log_prob(
theta, context=x_repeated
)
posterior_log_prob = self.posterior_estimator.log_prob(theta, condition=x)

# Force probability to be zero outside prior support.
in_prior_support = within_support(self.prior, theta)
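The posterior potential mirrors the same pattern: no more repeat-and-match of theta against x, and a fast failure when no observation is set. A hedged usage sketch, where estimator, prior, x_o, and theta are placeholders:

potential_fn, theta_transform = posterior_estimator_based_potential(
    posterior_estimator=estimator,  # a DensityEstimator
    prior=prior,
    x_o=x_o,  # must be provided, otherwise __call__ raises ValueError
)
potentials = potential_fn(theta)  # log q(theta | x_o); -inf outside prior support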
12 changes: 12 additions & 0 deletions sbi/inference/snle/mnle.py
@@ -5,6 +5,7 @@
from copy import deepcopy
from typing import Any, Callable, Dict, Optional, Union

from torch import Tensor
from torch.distributions import Distribution

from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
@@ -194,3 +195,14 @@ def build_posterior(
self._model_bank.append(deepcopy(self._posterior))

return deepcopy(self._posterior)

# Temporary: need to rewrite mixed likelihood estimators as DensityEstimator
# objects.
# TODO: Fix and merge issue #968
def _loss(self, theta: Tensor, x: Tensor) -> Tensor:
r"""Return loss for SNLE, which is the likelihood of $-\log q(x_i | \theta_i)$.

Returns:
Negative log prob.
"""
return -self._neural_net.log_prob(x, context=theta)
10 changes: 5 additions & 5 deletions sbi/inference/snle/snle_base.py
@@ -7,8 +7,7 @@
from typing import Any, Callable, Dict, Optional, Union

import torch
from pyknos.nflows import flows
from torch import Tensor, nn, optim
from torch import Tensor, optim
from torch.distributions import Distribution
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils.tensorboard.writer import SummaryWriter
Expand All @@ -17,6 +16,7 @@
from sbi.inference import NeuralInference
from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
from sbi.inference.potentials import likelihood_estimator_based_potential
from sbi.neural_nets.density_estimators import DensityEstimator
from sbi.utils import check_estimator_arg, check_prior, x_shape_from_simulation


@@ -126,7 +126,7 @@ def train(
retrain_from_scratch: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
) -> flows.Flow:
) -> DensityEstimator:
r"""Train the density estimator to learn the distribution $p(x|\theta)$.

Args:
@@ -262,7 +262,7 @@ def train(

def build_posterior(
self,
density_estimator: Optional[nn.Module] = None,
density_estimator: Optional[DensityEstimator] = None,
prior: Optional[Distribution] = None,
sample_with: str = "mcmc",
mcmc_method: str = "slice_np",
@@ -367,4 +367,4 @@ def _loss(self, theta: Tensor, x: Tensor) -> Tensor:
Returns:
Negative log prob.
"""
return -self._neural_net.log_prob(x, context=theta)
return self._neural_net.loss(x, condition=theta)
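This final hunk is the heart of the training-loop change: the trainer no longer hard-codes the negative log likelihood but defers to the estimator's own loss, so estimators with other objectives can plug in. Schematically, a hypothetical training step under the new interface (a sketch, not the actual train() body):

import torch

def training_step(estimator, optimizer, theta_batch, x_batch):
    # SNLE learns p(x | theta): x is the input, theta the condition.
    optimizer.zero_grad()
    loss = estimator.loss(x_batch, condition=theta_batch).mean()
    loss.backward()
    optimizer.step()
    return loss.item()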