diff --git a/pyproject.toml b/pyproject.toml index 53796da17..42f1e415e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ "tensorboard", "torch>=1.8.0", "tqdm", + "zuko>=1.0.0", ] [project.optional-dependencies] diff --git a/sbi/analysis/conditional_density.py b/sbi/analysis/conditional_density.py index a510df919..3a7f9a010 100644 --- a/sbi/analysis/conditional_density.py +++ b/sbi/analysis/conditional_density.py @@ -6,10 +6,11 @@ import torch import torch.distributions.transforms as torch_tf -from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn -from torch import Tensor, nn +from pyknos.mdn.mdn import MultivariateGaussianMDN +from torch import Tensor from torch.distributions import Distribution +from sbi.neural_nets.density_estimators.base import DensityEstimator from sbi.types import Shape, TorchTransform from sbi.utils.conditional_density_utils import ( ConditionedPotential, @@ -185,7 +186,7 @@ def conditional_corrcoeff( class ConditionedMDN: def __init__( self, - net: nn.Module, + mdn: DensityEstimator, x_o: Tensor, condition: Tensor, dims_to_sample: List[int], @@ -205,7 +206,7 @@ def __init__( """ condition = atleast_2d_float32_tensor(condition) - logits, means, precfs, _ = extract_and_transform_mog(net=net, context=x_o) + logits, means, precfs, _ = extract_and_transform_mog(net=mdn.net, context=x_o) self.logits, self.means, self.precfs, self.sumlogdiag = condition_mog( condition, dims_to_sample, logits, means, precfs ) @@ -213,13 +214,15 @@ def __init__( def sample(self, sample_shape: Shape = torch.Size()) -> Tensor: num_samples = torch.Size(sample_shape).numel() - samples = mdn.sample_mog(num_samples, self.logits, self.means, self.precfs) + samples = MultivariateGaussianMDN.sample_mog( + num_samples, self.logits, self.means, self.precfs + ) return samples.detach().reshape((*sample_shape, -1)) def log_prob(self, theta: Tensor) -> Tensor: batch_size, dim = theta.shape - log_prob = mdn.log_prob_mog( + log_prob = MultivariateGaussianMDN.log_prob_mog( theta, self.logits.repeat(batch_size, 1), self.means.repeat(batch_size, 1, 1), diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index e1cf8dd4e..8d454e1c6 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -3,7 +3,6 @@ from typing import Optional, Union import torch -from pyknos.nflows import flows from torch import Tensor, log from torch.distributions import Distribution @@ -11,9 +10,10 @@ from sbi.inference.potentials.posterior_based_potential import ( posterior_estimator_based_potential, ) +from sbi.neural_nets.density_estimators.base import DensityEstimator from sbi.samplers.rejection.rejection import accept_reject_sample from sbi.types import Shape -from sbi.utils import check_prior, match_theta_and_x_batch_shapes, within_support +from sbi.utils import check_prior, within_support from sbi.utils.torchutils import ensure_theta_batched @@ -33,7 +33,7 @@ class DirectPosterior(NeuralPosterior): def __init__( self, - posterior_estimator: flows.Flow, + posterior_estimator: DensityEstimator, prior: Distribution, max_sampling_batch_size: int = 10_000, device: Optional[str] = None, @@ -101,7 +101,18 @@ def sample( """ num_samples = torch.Size(sample_shape).numel() + condition_shape = self.posterior_estimator._condition_shape x = self._x_else_default_x(x) + + try: + x = x.reshape(*condition_shape) + except RuntimeError as err: + raise ValueError( + f"Expected a single `x` which should 
broadcastable to shape \ + {condition_shape}, but got {x.shape}. For batched eval \ + see issue #990" + ) from err + max_sampling_batch_size = ( self.max_sampling_batch_size if max_sampling_batch_size is None @@ -121,9 +132,10 @@ def sample( num_samples=num_samples, show_progress_bars=show_progress_bars, max_sampling_batch_size=max_sampling_batch_size, - proposal_sampling_kwargs={"context": x}, + proposal_sampling_kwargs={"condition": x}, alternative_method="build_posterior(..., sample_with='mcmc')", )[0] + return samples def log_prob( @@ -159,21 +171,27 @@ def log_prob( support of the prior, -∞ (corresponding to 0 probability) outside. """ x = self._x_else_default_x(x) + condition_shape = self.posterior_estimator._condition_shape + try: + x = x.reshape(*condition_shape) + except RuntimeError as err: + raise ValueError( + f"Expected a single `x` which should broadcastable to shape \ + {condition_shape}, but got {x.shape}. For batched eval \ + see issue #990" + ) from err # TODO Train exited here, entered after sampling? self.posterior_estimator.eval() theta = ensure_theta_batched(torch.as_tensor(theta)) - theta_repeated, x_repeated = match_theta_and_x_batch_shapes(theta, x) with torch.set_grad_enabled(track_gradients): # Evaluate on device, move back to cpu for comparison with prior. - unnorm_log_prob = self.posterior_estimator.log_prob( - theta_repeated, context=x_repeated - ) + unnorm_log_prob = self.posterior_estimator.log_prob(theta, condition=x) # Force probability to be zero outside prior support. - in_prior_support = within_support(self.prior, theta_repeated) + in_prior_support = within_support(self.prior, theta) masked_log_prob = torch.where( in_prior_support, @@ -227,7 +245,7 @@ def acceptance_at(x: Tensor) -> Tensor: show_progress_bars=show_progress_bars, sample_for_correction_factor=True, max_sampling_batch_size=rejection_sampling_batch_size, - proposal_sampling_kwargs={"context": x}, + proposal_sampling_kwargs={"condition": x}, )[1] # Check if the provided x matches the default x (short-circuit on identity). diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index d3aa5e51d..cb431fb4a 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -1,22 +1,21 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Affero General Public License v3, see . -from typing import Any, Callable, Optional, Tuple +from typing import Callable, Optional, Tuple import torch -from torch import Tensor, nn +from torch import Tensor from torch.distributions import Distribution from sbi.inference.potentials.base_potential import BasePotential +from sbi.neural_nets.density_estimators import DensityEstimator from sbi.neural_nets.mnle import MixedDensityEstimator from sbi.types import TorchTransform from sbi.utils import mcmc_transform -from sbi.utils.sbiutils import match_theta_and_x_batch_shapes -from sbi.utils.torchutils import atleast_2d def likelihood_estimator_based_potential( - likelihood_estimator: nn.Module, + likelihood_estimator: DensityEstimator, prior: Distribution, x_o: Optional[Tensor], enable_transform: bool = True, @@ -27,7 +26,7 @@ def likelihood_estimator_based_potential( unconstrained space. Args: - likelihood_estimator: The neural network modelling the likelihood. + likelihood_estimator: The density estimator modelling the likelihood. prior: The prior distribution. 
x_o: The observed data at which to evaluate the likelihood. enable_transform: Whether to transform parameters to unconstrained space. @@ -55,7 +54,7 @@ class LikelihoodBasedPotential(BasePotential): def __init__( self, - likelihood_estimator: nn.Module, + likelihood_estimator: DensityEstimator, prior: Distribution, x_o: Optional[Tensor], device: str = "cpu", @@ -63,7 +62,7 @@ def __init__( r"""Returns the potential function for likelihood-based methods. Args: - likelihood_estimator: The neural network modelling the likelihood. + likelihood_estimator: The density estimator modelling the likelihood. prior: The prior distribution. x_o: The observed data at which to evaluate the likelihood. device: The device to which parameters and data are moved before evaluating @@ -92,7 +91,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: log_likelihood_trial_sum = _log_likelihoods_over_trials( x=self.x_o, theta=theta.to(self.device), - net=self.likelihood_estimator, + estimator=self.likelihood_estimator, track_gradients=track_gradients, ) @@ -100,7 +99,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: def _log_likelihoods_over_trials( - x: Tensor, theta: Tensor, net: Any, track_gradients: bool = False + x: Tensor, theta: Tensor, estimator: DensityEstimator, track_gradients: bool = False ) -> Tensor: r"""Return log likelihoods summed over iid trials of `x`. @@ -112,36 +111,28 @@ def _log_likelihoods_over_trials( Args: x: batch of iid data. - theta: batch of parameters - net: neural net with .log_prob() + theta: batch of parameters. + estimator: DensityEstimator. track_gradients: Whether to track gradients. Returns: log_likelihood_trial_sum: log likelihood for each parameter, summed over all batch entries (iid trials) in `x`. """ - - # Repeat `x` in case of evaluation on multiple `theta`. This is needed below in - # when calling nflows in order to have matching shapes of theta and context x - # at neural network evaluation time. - theta_repeated, x_repeated = match_theta_and_x_batch_shapes( - theta=atleast_2d(theta), x=atleast_2d(x) - ) - assert ( - x_repeated.shape[0] == theta_repeated.shape[0] - ), "x and theta must match in batch shape." + # unsqueeze to ensure that the x-batch dimension is the first dimension for the + # broadcasting of the density estimator. + x = torch.as_tensor(x).reshape(-1, x.shape[-1]).unsqueeze(1) assert ( - next(net.parameters()).device == x.device and x.device == theta.device - ), f"""device mismatch: net, x, theta: {next(net.parameters()).device}, {x.device}, + next(estimator.parameters()).device == x.device and x.device == theta.device + ), f"""device mismatch: estimator, x, theta: \ + {next(estimator.parameters()).device}, {x.device}, {theta.device}.""" # Calculate likelihood in one batch. with torch.set_grad_enabled(track_gradients): - log_likelihood_trial_batch = net.log_prob(x_repeated, theta_repeated) - # Reshape to (x-trials x parameters), sum over trial-log likelihoods. - log_likelihood_trial_sum = log_likelihood_trial_batch.reshape( - x.shape[0], -1 - ).sum(0) + log_likelihood_trial_batch = estimator.log_prob(x, condition=theta) + # Reshape to (-1, theta_batch_size), sum over trial-log likelihoods. 
+ log_likelihood_trial_sum = log_likelihood_trial_batch.sum(0) return log_likelihood_trial_sum @@ -179,12 +170,14 @@ def mixed_likelihood_estimator_based_potential( class MixedLikelihoodBasedPotential(LikelihoodBasedPotential): def __init__( self, - likelihood_estimator: MixedDensityEstimator, + likelihood_estimator: MixedDensityEstimator, # type: ignore TODO fix pyright prior: Distribution, x_o: Optional[Tensor], device: str = "cpu", ): - super().__init__(likelihood_estimator, prior, x_o, device) + # TODO Fix pyright issue by making MixedDensityEstimator a subclass + # of DensityEstimator + super().__init__(likelihood_estimator, prior, x_o, device) # type: ignore def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: # Calculate likelihood in one batch. diff --git a/sbi/inference/potentials/posterior_based_potential.py b/sbi/inference/potentials/posterior_based_potential.py index 8582041de..27fcc6324 100644 --- a/sbi/inference/potentials/posterior_based_potential.py +++ b/sbi/inference/potentials/posterior_based_potential.py @@ -4,19 +4,19 @@ from typing import Callable, Optional, Tuple import torch -from pyknos.nflows import flows -from torch import Tensor, nn +from torch import Tensor from torch.distributions import Distribution from sbi.inference.potentials.base_potential import BasePotential +from sbi.neural_nets.density_estimators import DensityEstimator from sbi.types import TorchTransform from sbi.utils import mcmc_transform -from sbi.utils.sbiutils import match_theta_and_x_batch_shapes, within_support +from sbi.utils.sbiutils import within_support from sbi.utils.torchutils import ensure_theta_batched def posterior_estimator_based_potential( - posterior_estimator: nn.Module, + posterior_estimator: DensityEstimator, prior: Distribution, x_o: Optional[Tensor], enable_transform: bool = True, @@ -59,7 +59,7 @@ class PosteriorBasedPotential(BasePotential): def __init__( self, - posterior_estimator: flows.Flow, + posterior_estimator: DensityEstimator, prior: Distribution, x_o: Optional[Tensor], device: str = "cpu", @@ -92,14 +92,17 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: The potential. """ + if self._x_o is None: + raise ValueError( + "No observed data x_o is available. Please reinitialize \ + the potential or manually set self._x_o." + ) + theta = ensure_theta_batched(torch.as_tensor(theta)) - theta, x_repeated = match_theta_and_x_batch_shapes(theta, self.x_o) - theta, x_repeated = theta.to(self.device), x_repeated.to(self.device) + theta, x = theta.to(self.device), self.x_o.to(self.device) with torch.set_grad_enabled(track_gradients): - posterior_log_prob = self.posterior_estimator.log_prob( - theta, context=x_repeated - ) + posterior_log_prob = self.posterior_estimator.log_prob(theta, condition=x) # Force probability to be zero outside prior support. 
in_prior_support = within_support(self.prior, theta) diff --git a/sbi/inference/snle/mnle.py b/sbi/inference/snle/mnle.py index b83bd273b..57ac0a425 100644 --- a/sbi/inference/snle/mnle.py +++ b/sbi/inference/snle/mnle.py @@ -5,6 +5,7 @@ from copy import deepcopy from typing import Any, Callable, Dict, Optional, Union +from torch import Tensor from torch.distributions import Distribution from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior @@ -194,3 +195,14 @@ def build_posterior( self._model_bank.append(deepcopy(self._posterior)) return deepcopy(self._posterior) + + # Temporary: need to rewrite mixed likelihood estimators as DensityEstimator + # objects. + # TODO: Fix and merge issue #968 + def _loss(self, theta: Tensor, x: Tensor) -> Tensor: + r"""Return loss for SNLE, which is the likelihood of $-\log q(x_i | \theta_i)$. + + Returns: + Negative log prob. + """ + return -self._neural_net.log_prob(x, context=theta) diff --git a/sbi/inference/snle/snle_base.py b/sbi/inference/snle/snle_base.py index e97517f33..7cfb713b2 100644 --- a/sbi/inference/snle/snle_base.py +++ b/sbi/inference/snle/snle_base.py @@ -7,8 +7,7 @@ from typing import Any, Callable, Dict, Optional, Union import torch -from pyknos.nflows import flows -from torch import Tensor, nn, optim +from torch import Tensor, optim from torch.distributions import Distribution from torch.nn.utils.clip_grad import clip_grad_norm_ from torch.utils.tensorboard.writer import SummaryWriter @@ -17,6 +16,7 @@ from sbi.inference import NeuralInference from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior from sbi.inference.potentials import likelihood_estimator_based_potential +from sbi.neural_nets.density_estimators import DensityEstimator from sbi.utils import check_estimator_arg, check_prior, x_shape_from_simulation @@ -126,7 +126,7 @@ def train( retrain_from_scratch: bool = False, show_train_summary: bool = False, dataloader_kwargs: Optional[Dict] = None, - ) -> flows.Flow: + ) -> DensityEstimator: r"""Train the density estimator to learn the distribution $p(x|\theta)$. Args: @@ -262,7 +262,7 @@ def train( def build_posterior( self, - density_estimator: Optional[nn.Module] = None, + density_estimator: Optional[DensityEstimator] = None, prior: Optional[Distribution] = None, sample_with: str = "mcmc", mcmc_method: str = "slice_np", @@ -367,4 +367,4 @@ def _loss(self, theta: Tensor, x: Tensor) -> Tensor: Returns: Negative log prob. 
""" - return -self._neural_net.log_prob(x, context=theta) + return self._neural_net.loss(x, condition=theta) diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index 79f8db485..36fedcdd3 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -7,9 +7,7 @@ from typing import Any, Callable, Dict, Optional, Union import torch -import torch.nn as nn from pyknos.mdn.mdn import MultivariateGaussianMDN -from pyknos.nflows import flows from pyknos.nflows.transforms import CompositeTransform from torch import Tensor from torch.distributions import Distribution, MultivariateNormal @@ -17,6 +15,7 @@ import sbi.utils as utils from sbi.inference.posteriors.direct_posterior import DirectPosterior from sbi.inference.snpe.snpe_base import PosteriorEstimator +from sbi.neural_nets.density_estimators.base import DensityEstimator from sbi.types import TensorboardSummaryWriter, TorchModule from sbi.utils import torchutils @@ -110,7 +109,7 @@ def train( show_train_summary: bool = False, dataloader_kwargs: Optional[Dict] = None, component_perturbation: float = 5e-3, - ) -> nn.Module: + ) -> DensityEstimator: r"""Return density estimator that approximates the proposal posterior. [1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional @@ -341,10 +340,10 @@ def _expand_mog(self, eps: float = 1e-5): Args: eps: Standard deviation for the random perturbation. """ - assert isinstance(self._neural_net._distribution, MultivariateGaussianMDN) + assert isinstance(self._neural_net.net._distribution, MultivariateGaussianMDN) # Increase the number of components - self._neural_net._distribution._num_components = self._num_components + self._neural_net.net._distribution._num_components = self._num_components # Expand the 1-dim Gaussian. for name, param in self._neural_net.named_parameters(): @@ -361,9 +360,12 @@ def _expand_mog(self, eps: float = 1e-5): param.grad = None # let autograd construct a new gradient -class SNPE_A_MDN(nn.Module): +class SNPE_A_MDN(DensityEstimator): """Generates a posthoc-corrected MDN which approximates the posterior. + TODO: Adapt this class to the new `DensityEstimator` interface. Maybe even to a + common MDN interface. See #989. + This class takes as input the density estimator (abbreviated with `_d` suffix, aka the proposal posterior) and the proposal prior (abbreviated with `_pp` suffix) from which the simulations were drawn. It uses the algorithm presented in SNPE-A [1] to @@ -380,7 +382,7 @@ class SNPE_A_MDN(nn.Module): def __init__( self, - flow: flows.Flow, + flow: DensityEstimator, proposal: Union["utils.BoxUniform", "MultivariateNormal", "DirectPosterior"], prior: Distribution, device: str, @@ -393,7 +395,8 @@ def __init__( prior: The prior distribution. """ # Call nn.Module's constructor. - super().__init__() + + super().__init__(flow, flow._condition_shape) self._neural_net = flow self._prior = prior @@ -418,18 +421,27 @@ def __init__( # Take care of z-scoring, pre-compute and store prior terms. self._set_state_for_mog_proposal() - def log_prob(self, inputs: Tensor, context: Tensor) -> Tensor: - inputs, context = inputs.to(self._device), context.to(self._device) + def log_prob(self, inputs: Tensor, condition: Tensor, **kwargs) -> Tensor: + """Compute the log-probability of the approximate posterior. 
+ + Args: + inputs: Input values + condition: Condition values + + Returns: + log_prob p(inputs|condition) + """ + inputs, condition = inputs.to(self._device), condition.to(self._device) if not self._apply_correction: - return self._neural_net.log_prob(inputs, context) + return self._neural_net.log_prob(inputs, condition) else: # When we want to compute the approx. posterior, a proposal prior \tilde{p} # has already been observed. To analytically calculate the log-prob of the # Gaussian, we first need to compute the mixture components. # Compute the mixture components of the proposal posterior. - logits_pp, m_pp, prec_pp = self._posthoc_correction(context) + logits_pp, m_pp, prec_pp = self._posthoc_correction(condition) # z-score theta if it z-scoring had been requested. theta = self._maybe_z_score_theta(inputs) @@ -443,16 +455,30 @@ def log_prob(self, inputs: Tensor, context: Tensor) -> Tensor: ) return log_prob_proposal_posterior # \hat{p} from eq (3) in [1] - def sample(self, num_samples: int, context: Tensor, batch_size: int = 1) -> Tensor: - context = context.to(self._device) + def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tensor: + """Sample from the approximate posterior. + + Args: + sample_shape: Shape of the samples. + condition: Condition values. + + Returns: + Samples from the approximate posterior. + """ + + condition = condition.to(self._device) if not self._apply_correction: - return self._neural_net.sample(num_samples, context, batch_size) + return self._neural_net.sample(sample_shape, condition=condition) else: # When we want to sample from the approx. posterior, a proposal prior # \tilde{p} has already been observed. To analytically calculate the # log-prob of the Gaussian, we first need to compute the mixture components. - return self._sample_approx_posterior_mog(num_samples, context, batch_size) + num_samples = torch.Size(sample_shape).numel() + condition_ndim = len(self._condition_shape) + batch_size = condition.shape[:-condition_ndim] + batch_size = torch.Size(batch_size).numel() + return self._sample_approx_posterior_mog(num_samples, condition, batch_size) def _sample_approx_posterior_mog( self, num_samples, x: Tensor, batch_size: int @@ -491,7 +517,7 @@ def _sample_approx_posterior_mog( num_samples, logits_p, m_p, prec_factors_p ) - embedded_context = self._neural_net._embedding_net(x) + embedded_context = self._neural_net.net._embedding_net(x) if embedded_context is not None: # Merge the context dimension with sample dimension in order to # apply the transform. @@ -500,7 +526,9 @@ def _sample_approx_posterior_mog( embedded_context, num_reps=num_samples ) - theta, _ = self._neural_net._transform.inverse(theta, context=embedded_context) + theta, _ = self._neural_net.net._transform.inverse( + theta, context=embedded_context + ) if embedded_context is not None: # Split the context dimension from sample dimension. @@ -521,9 +549,9 @@ def _posthoc_correction(self, x: Tensor): """ # Evaluate the density estimator. - encoded_x = self._neural_net._embedding_net(x) - dist = self._neural_net._distribution # defined to avoid black formatting. - logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(encoded_x) + embedded_x = self._neural_net.net._embedding_net(x) + dist = self._neural_net.net._distribution # defined to avoid black formatting. 
+ logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(embedded_x) norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True) # The following if case is needed because, in the constructor, we call @@ -616,7 +644,9 @@ def _set_state_for_mog_proposal(self) -> None: training step if the prior is Gaussian. """ - self.z_score_theta = isinstance(self._neural_net._transform, CompositeTransform) + self.z_score_theta = isinstance( + self._neural_net.net._transform, CompositeTransform + ) self._set_maybe_z_scored_prior() @@ -648,8 +678,8 @@ def _set_maybe_z_scored_prior(self) -> None: prior will not be exactly have mean=0 and std=1. """ if self.z_score_theta: - scale = self._neural_net._transform._transforms[0]._scale - shift = self._neural_net._transform._transforms[0]._shift + scale = self._neural_net.net._transform._transforms[0]._scale + shift = self._neural_net.net._transform._transforms[0]._shift # Following the definition of the linear transform in # `standardizing_transform` in `sbiutils.py`: @@ -683,7 +713,7 @@ def _maybe_z_score_theta(self, theta: Tensor) -> Tensor: """Return potentially standardized theta if z-scoring was requested.""" if self.z_score_theta: - theta, _ = self._neural_net._transform(theta) + theta, _ = self._neural_net.net._transform(theta) return theta diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index ec665c32b..394139e64 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -7,7 +7,7 @@ from warnings import warn import torch -from torch import Tensor, nn, ones, optim +from torch import Tensor, ones, optim from torch.distributions import Distribution from torch.nn.utils.clip_grad import clip_grad_norm_ from torch.utils.tensorboard.writer import SummaryWriter @@ -22,6 +22,7 @@ ) from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.inference.potentials import posterior_estimator_based_potential +from sbi.neural_nets.density_estimators import DensityEstimator from sbi.utils import ( RestrictedPrior, check_estimator_arg, @@ -215,7 +216,7 @@ def train( retrain_from_scratch: bool = False, show_train_summary: bool = False, dataloader_kwargs: Optional[dict] = None, - ) -> nn.Module: + ) -> DensityEstimator: r"""Return density estimator that approximates the distribution $p(\theta|x)$. Args: @@ -428,7 +429,7 @@ def default_calibration_kernel(x): def build_posterior( self, - density_estimator: Optional[nn.Module] = None, + density_estimator: Optional[DensityEstimator] = None, prior: Optional[Distribution] = None, sample_with: str = "direct", mcmc_method: str = "slice_np", @@ -580,11 +581,13 @@ def _loss( """ if self._round == 0 or force_first_round_loss: # Use posterior log prob (without proposal correction) for first round. - log_prob = self._neural_net.log_prob(theta, x) + loss = self._neural_net.loss(theta, x) else: - log_prob = self._log_prob_proposal_posterior(theta, x, masks, proposal) + # Currently only works for `DensityEstimator` objects. + # Must be extended ones other Estimators are implemented. 
See #966. + loss = -self._log_prob_proposal_posterior(theta, x, masks, proposal) - return -(calibration_kernel(x) * log_prob) + return calibration_kernel(x) * loss def _check_proposal(self, proposal): """ diff --git a/sbi/inference/snpe/snpe_c.py b/sbi/inference/snpe/snpe_c.py index 7f488efca..a368183b5 100644 --- a/sbi/inference/snpe/snpe_c.py +++ b/sbi/inference/snpe/snpe_c.py @@ -163,8 +163,8 @@ def train( proposal = self._proposal_roundwise[-1] self.use_non_atomic_loss = ( isinstance(proposal, DirectPosterior) - and isinstance(proposal.posterior_estimator._distribution, mdn) - and isinstance(self._neural_net._distribution, mdn) + and isinstance(proposal.posterior_estimator.net._distribution, mdn) + and isinstance(self._neural_net.net._distribution, mdn) and check_dist_class( self._prior, class_to_check=(Uniform, MultivariateNormal) )[0] @@ -191,7 +191,9 @@ def _set_state_for_mog_proposal(self) -> None: training step if the prior is Gaussian. """ - self.z_score_theta = isinstance(self._neural_net._transform, CompositeTransform) + self.z_score_theta = isinstance( + self._neural_net.net._transform, CompositeTransform + ) self._set_maybe_z_scored_prior() @@ -221,8 +223,8 @@ def _set_maybe_z_scored_prior(self) -> None: """ if self.z_score_theta: - scale = self._neural_net._transform._transforms[0]._scale - shift = self._neural_net._transform._transforms[0]._shift + scale = self._neural_net.net._transform._transforms[0]._scale + shift = self._neural_net.net._transform._transforms[0]._shift # Following the definition of the linear transform in # `standardizing_transform` in `sbiutils.py`: @@ -276,8 +278,22 @@ def _log_prob_proposal_posterior( """ if self.use_non_atomic_loss: + if not ( + hasattr(self._neural_net.net, "_distribution") + and isinstance(self._neural_net.net._distribution, mdn) + ): + raise ValueError( + "The density estimator must be an MDN for non-atomic loss." + ) + return self._log_prob_proposal_posterior_mog(theta, x, proposal) else: + if not hasattr(self._neural_net, "log_prob"): + raise ValueError( + "The neural estimator must have a log_prob method for\ atomic loss. It should ideally follow the \ sbi.neural_nets 'DensityEstimator' interface." + ) return self._log_prob_proposal_posterior_atomic(theta, x, masks) def _log_prob_proposal_posterior_atomic( @@ -388,16 +404,16 @@ def _log_prob_proposal_posterior_mog( # Evaluate the proposal. MDNs do not have functionality to run the embedding_net # and then get the mixture_components (**without** calling log_prob()). Hence, # we call them separately here. - encoded_x = proposal.posterior_estimator._embedding_net(proposal.default_x) + encoded_x = proposal.posterior_estimator.net._embedding_net(proposal.default_x) dist = ( - proposal.posterior_estimator._distribution + proposal.posterior_estimator.net._distribution ) # defined to avoid ugly black formatting. logits_p, m_p, prec_p, _, _ = dist.get_mixture_components(encoded_x) norm_logits_p = logits_p - torch.logsumexp(logits_p, dim=-1, keepdim=True) # Evaluate the density estimator. - encoded_x = self._neural_net._embedding_net(x) - dist = self._neural_net._distribution # defined to avoid black formatting. + encoded_x = self._neural_net.net._embedding_net(x) + dist = self._neural_net.net._distribution # defined to avoid black formatting.
logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(encoded_x) norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True) @@ -632,6 +648,6 @@ def _maybe_z_score_theta(self, theta: Tensor) -> Tensor: """Return potentially standardized theta if z-scoring was requested.""" if self.z_score_theta: - theta, _ = self._neural_net._transform(theta) + theta, _ = self._neural_net.net._transform(theta) return theta diff --git a/sbi/inference/snre/snre_c.py b/sbi/inference/snre/snre_c.py index fe98cc054..ad99c39ca 100644 --- a/sbi/inference/snre/snre_c.py +++ b/sbi/inference/snre/snre_c.py @@ -147,7 +147,7 @@ def _loss( assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match." batch_size = theta.shape[0] - # We append an contrastive theta to the marginal case because we will remove + # We append a contrastive theta to the marginal case because we will remove # the jointly drawn # sample in the logits_marginal[:, 0] position. That makes the remaining sample # marginally drawn. diff --git a/sbi/neural_nets/density_estimators/__init__.py b/sbi/neural_nets/density_estimators/__init__.py index 5069397e8..0cc2aa068 100644 --- a/sbi/neural_nets/density_estimators/__init__.py +++ b/sbi/neural_nets/density_estimators/__init__.py @@ -1,2 +1,3 @@ from sbi.neural_nets.density_estimators.base import DensityEstimator -from sbi.neural_nets.density_estimators.flow import NFlowsFlow +from sbi.neural_nets.density_estimators.nflows_flow import NFlowsFlow +from sbi.neural_nets.density_estimators.zuko_flow import ZukoFlow diff --git a/sbi/neural_nets/density_estimators/base.py b/sbi/neural_nets/density_estimators/base.py index 55d8637a9..a2af32bc1 100644 --- a/sbi/neural_nets/density_estimators/base.py +++ b/sbi/neural_nets/density_estimators/base.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Optional, Tuple import torch from torch import Tensor, nn @@ -24,13 +24,18 @@ def __init__(self, net: nn.Module, condition_shape: torch.Size) -> None: Args: net: Neural network. - condition_shape: Shape of the input. If not provided, it will assume a 1D - input. + condition_shape: Shape of the condition. If not provided, it will assume a + 1D input. """ super().__init__() self.net = net self._condition_shape = condition_shape + @property + def embedding_net(self) -> Optional[nn.Module]: + r"""Return the embedding network if it exists.""" + return None + def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: r"""Return the log probabilities of the inputs given a condition or multiple i.e. batched conditions. 
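A minimal usage sketch of the `DensityEstimator` interface introduced above: the flow builders now return a `DensityEstimator` whose `log_prob`, `loss`, and `sample` take a `condition=` keyword in place of nflows' `context=`. The tensor sizes and the choice of `build_nsf` below are purely illustrative and not part of this patch.

    import torch

    from sbi.neural_nets.flow import build_nsf

    # Toy batches; the builder only uses them to infer dimensionality and z-scoring.
    theta = torch.randn(100, 3)  # variable being modelled (e.g. parameters)
    x = torch.randn(100, 2)      # conditioning variable (e.g. simulated data)

    # build_nsf now returns an NFlowsFlow (a DensityEstimator) rather than a raw
    # nflows.Flow, so downstream code uses the keyword-based interface.
    estimator = build_nsf(batch_x=theta, batch_y=x)

    log_probs = estimator.log_prob(theta, condition=x)   # one value per batch entry
    loss = estimator.loss(theta, condition=x)            # negative log prob for flows
    samples = estimator.sample((10,), condition=x[:1])   # samples for a single condition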
diff --git a/sbi/neural_nets/density_estimators/flow.py b/sbi/neural_nets/density_estimators/nflows_flow.py similarity index 97% rename from sbi/neural_nets/density_estimators/flow.py rename to sbi/neural_nets/density_estimators/nflows_flow.py index a2510cfd2..ad578d96c 100644 --- a/sbi/neural_nets/density_estimators/flow.py +++ b/sbi/neural_nets/density_estimators/nflows_flow.py @@ -2,7 +2,7 @@ import torch from pyknos.nflows.flows import Flow -from torch import Tensor +from torch import Tensor, nn from sbi.neural_nets.density_estimators.base import DensityEstimator from sbi.types import Shape @@ -18,6 +18,11 @@ class NFlowsFlow(DensityEstimator): def __init__(self, net: Flow, condition_shape: torch.Size) -> None: super().__init__(net, condition_shape) + @property + def embedding_net(self) -> nn.Module: + r"""Return the embedding network.""" + return self.net._embedding_net + def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: r"""Return the log probabilities of the inputs given a condition or multiple i.e. batched conditions. diff --git a/sbi/neural_nets/density_estimators/zuko_flow.py b/sbi/neural_nets/density_estimators/zuko_flow.py new file mode 100644 index 000000000..7cbb240a9 --- /dev/null +++ b/sbi/neural_nets/density_estimators/zuko_flow.py @@ -0,0 +1,150 @@ +from typing import Tuple + +import torch +from torch import Tensor, nn +from zuko.flows import Flow + +from sbi.neural_nets.density_estimators.base import DensityEstimator +from sbi.types import Shape + + +class ZukoFlow(DensityEstimator): + r"""`zuko`-based normalizing flow density estimator. + + Flow type objects already have a .log_prob() and .sample() method, so here we just + wrap them and add the .loss() method. + """ + + def __init__( + self, net: Flow, embedding_net: nn.Module, condition_shape: torch.Size + ): + r"""Initialize the density estimator. + + Args: + net: Flow object. + embedding_net: Embedding network applied to the condition. + condition_shape: Shape of the condition. + """ + + # assert len(condition_shape) == 1, "Zuko Flows require 1D conditions." + super().__init__(net=net, condition_shape=condition_shape) + self._embedding_net = embedding_net + + @property + def embedding_net(self) -> nn.Module: + r"""Return the embedding network.""" + return self._embedding_net + + def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: + r"""Return the log probabilities of the inputs given a condition or multiple + i.e. batched conditions. + + Args: + input: Inputs to evaluate the log probability on of shape + (*batch_shape1, input_size). + condition: Conditions of shape (*batch_shape2, *condition_shape). + + Raises: + RuntimeError: If batch_shape1 and batch_shape2 are not broadcastable. + + Returns: + Sample-wise log probabilities. + + Note: + This function should support PyTorch's automatic broadcasting. This means + the function should behave as follows for different input and condition + shapes: + - (input_size,) + (batch_size,*condition_shape) -> (batch_size,) + - (batch_size, input_size) + (*condition_shape) -> (batch_size,) + - (batch_size, input_size) + (batch_size, *condition_shape) -> (batch_size,) + - (batch_size1, input_size) + (batch_size2, *condition_shape) + -> RuntimeError i.e.
not broadcastable + - (batch_size1,1, input_size) + (batch_size2, *condition_shape) + -> (batch_size1,batch_size2) + - (batch_size1, input_size) + (batch_size2,1, *condition_shape) + -> (batch_size2,batch_size1) + """ + self._check_condition_shape(condition) + condition_dims = len(self._condition_shape) + + # PyTorch's automatic broadcasting + batch_shape_in = input.shape[:-1] + batch_shape_cond = condition.shape[:-condition_dims] + batch_shape = torch.broadcast_shapes(batch_shape_in, batch_shape_cond) + # Expand the input and condition to the same batch shape + input = input.expand(batch_shape + (input.shape[-1],)) + emb_cond = self._embedding_net(condition) + emb_cond = emb_cond.expand(batch_shape + (emb_cond.shape[-1],)) + + dists = self.net(emb_cond) + log_probs = dists.log_prob(input) + + return log_probs + + def loss(self, input: Tensor, condition: Tensor) -> Tensor: + r"""Return the loss for training the density estimator. + + Args: + input: Inputs to evaluate the loss on of shape (batch_size, input_size). + condition: Conditions of shape (batch_size, *condition_shape). + + Returns: + Negative log_probability (batch_size,) + """ + + return -self.log_prob(input, condition) + + def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor: + r"""Return samples from the density estimator. + + Args: + sample_shape: Shape of the samples to return. + condition: Conditions of shape (*batch_shape, *condition_shape). + + Returns: + Samples of shape (*batch_shape, *sample_shape, input_size). + + Note: + This function should support batched conditions and should admit the + following behavior for different condition shapes: + - (*condition_shape) -> (*sample_shape, input_size) + - (*batch_shape, *condition_shape) + -> (*batch_shape, *sample_shape, input_size) + """ + self._check_condition_shape(condition) + + condition_dims = len(self._condition_shape) + batch_shape = condition.shape[:-condition_dims] if condition_dims > 0 else () + + emb_cond = self._embedding_net(condition) + dists = self.net(emb_cond) + # zuko.sample() returns (*sample_shape, *batch_shape, input_size). + samples = dists.sample(sample_shape).reshape(*batch_shape, *sample_shape, -1) + + return samples + + def sample_and_log_prob( + self, sample_shape: torch.Size, condition: Tensor, **kwargs + ) -> Tuple[Tensor, Tensor]: + r"""Return samples and their density from the density estimator. + + Args: + sample_shape: Shape of the samples to return. + condition: Conditions of shape (*batch_shape, *condition_shape). + + Returns: + Samples and associated log probabilities. + """ + self._check_condition_shape(condition) + + condition_dims = len(self._condition_shape) + batch_shape = condition.shape[:-condition_dims] if condition_dims > 0 else () + + emb_cond = self._embedding_net(condition) + dists = self.net(emb_cond) + samples, log_probs = dists.rsample_and_log_prob(sample_shape) + # zuko.sample_and_log_prob() returns (*sample_shape, *batch_shape, ...). 
+ + samples = samples.reshape(*batch_shape, *sample_shape, -1) + log_probs = log_probs.reshape(*batch_shape, *sample_shape) + + return samples, log_probs diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index 33a9443aa..3e7bd8d56 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -3,16 +3,18 @@ from functools import partial -from typing import Optional +from typing import Optional, Sequence, Union from warnings import warn import torch +import zuko from pyknos.nflows import distributions as distributions_ from pyknos.nflows import flows, transforms from pyknos.nflows.nn import nets from pyknos.nflows.transforms.splines import rational_quadratic from torch import Tensor, nn, relu, tanh, tensor, uint8 +from sbi.neural_nets.density_estimators import NFlowsFlow, ZukoFlow from sbi.utils.sbiutils import ( standardizing_net, standardizing_transform, @@ -31,7 +33,7 @@ def build_made( num_mixture_components: int = 10, embedding_net: nn.Module = nn.Identity(), **kwargs, -) -> nn.Module: +) -> NFlowsFlow: """Builds MADE p(x|y). Args: @@ -94,8 +96,9 @@ def build_made( ) neural_net = flows.Flow(transform, distribution, embedding_net) + flow = NFlowsFlow(neural_net, condition_shape=batch_y[0].shape) - return neural_net + return flow def build_maf( @@ -110,7 +113,7 @@ def build_maf( dropout_probability: float = 0.0, use_batch_norm: bool = False, **kwargs, -) -> nn.Module: +) -> NFlowsFlow: """Builds MAF p(x|y). Args: @@ -183,8 +186,9 @@ def build_maf( distribution = get_base_dist(x_numel, **kwargs) neural_net = flows.Flow(transform, distribution, embedding_net) + flow = NFlowsFlow(neural_net, condition_shape=batch_y[0].shape) - return neural_net + return flow def build_maf_rqs( @@ -205,7 +209,7 @@ def build_maf_rqs( min_bin_height: float = rational_quadratic.DEFAULT_MIN_BIN_HEIGHT, min_derivative: float = rational_quadratic.DEFAULT_MIN_DERIVATIVE, **kwargs, -) -> nn.Module: +) -> NFlowsFlow: """Builds MAF p(x|y), where the diffeomorphisms are rational-quadratic splines (RQS). @@ -296,8 +300,9 @@ def build_maf_rqs( distribution = get_base_dist(x_numel, **kwargs) neural_net = flows.Flow(transform, distribution, embedding_net) + flow = NFlowsFlow(neural_net, condition_shape=batch_y[0].shape) - return neural_net + return flow def build_nsf( @@ -315,7 +320,7 @@ def build_nsf( dropout_probability: float = 0.0, use_batch_norm: bool = False, **kwargs, -) -> nn.Module: +) -> NFlowsFlow: """Builds NSF p(x|y). Args: @@ -417,8 +422,100 @@ def mask_in_layer(i): # Combine transforms. transform = transforms.CompositeTransform(transform_list) neural_net = flows.Flow(transform, distribution, embedding_net) + flow = NFlowsFlow(neural_net, condition_shape=batch_y[0].shape) - return neural_net + return flow + + +def build_zuko_maf( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 5, + embedding_net: nn.Module = nn.Identity(), + residual: bool = True, + randperm: bool = False, + **kwargs, +) -> ZukoFlow: + """Builds MAF p(x|y). + + Args: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. 
+ - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: Number of hidden features. + num_transforms: Number of transforms. + embedding_net: Optional embedding network for y. + residual: whether to use residual blocks in the coupling layer. + randperm: Whether features are randomly permuted between transformations or not. + kwargs: Additional arguments that are passed by the build function but are not + relevant for maf and are therefore ignored. + + Returns: + Neural network. + """ + x_numel = batch_x[0].numel() + # Infer the output dimensionality of the embedding_net by making a forward pass. + check_data_device(batch_x, batch_y) + check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) + y_numel = embedding_net(batch_y[:1]).numel() + if x_numel == 1: + warn( + "In one-dimensional output space, this flow is limited to Gaussians", + stacklevel=1, + ) + + if isinstance(hidden_features, int): + hidden_features = [hidden_features] * num_transforms + + if x_numel == 1: + maf = zuko.flows.MAF( + features=x_numel, + context=y_numel, + hidden_features=hidden_features, + transforms=num_transforms, + ) + else: + maf = zuko.flows.MAF( + features=x_numel, + context=y_numel, + hidden_features=hidden_features, + transforms=num_transforms, + randperm=randperm, + residual=residual, + ) + + transforms = maf.transform.transforms + z_score_x_bool, structured_x = z_score_parser(z_score_x) + if z_score_x_bool: + # transforms = transforms + transforms = ( + *transforms, + standardizing_transform(batch_x, structured_x, backend="zuko"), + ) + + z_score_y_bool, structured_y = z_score_parser(z_score_y) + if z_score_y_bool: + # Prepend standardizing transform to y-embedding. + embedding_net = nn.Sequential( + standardizing_net(batch_y, structured_y), embedding_net + ) + + # Combine transforms. + neural_net = zuko.flows.Flow(transforms, maf.base) + + flow = ZukoFlow(neural_net, embedding_net, condition_shape=batch_y[0].shape) + + return flow class ContextSplineMap(nn.Module): diff --git a/sbi/neural_nets/mdn.py b/sbi/neural_nets/mdn.py index 4e4ef167f..9794edacc 100644 --- a/sbi/neural_nets/mdn.py +++ b/sbi/neural_nets/mdn.py @@ -8,6 +8,7 @@ from torch import Tensor, nn import sbi.utils as utils +from sbi.neural_nets.density_estimators import NFlowsFlow from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device @@ -20,7 +21,7 @@ def build_mdn( num_components: int = 10, embedding_net: nn.Module = nn.Identity(), **kwargs, -) -> nn.Module: +) -> NFlowsFlow: """Builds MDN p(x|y). Args: @@ -80,5 +81,6 @@ def build_mdn( ) neural_net = flows.Flow(transform, distribution, embedding_net) + flow = NFlowsFlow(neural_net, condition_shape=batch_y[0].shape) - return neural_net + return flow diff --git a/sbi/neural_nets/mnle.py b/sbi/neural_nets/mnle.py index ac63ec3a4..3ee0c89bc 100644 --- a/sbi/neural_nets/mnle.py +++ b/sbi/neural_nets/mnle.py @@ -309,7 +309,7 @@ def log_prob(self, x: Tensor, context: Tensor) -> Tensor: # Transform to log-space if needed. torch.log(cont_x) if self.log_transform_x else cont_x, # Pass parameters and discrete x as context. - context=torch.cat((context, disc_x), dim=1), + condition=torch.cat((context, disc_x), dim=1), ) # Combine into joint lp. 
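As a usage sketch for the `build_zuko_maf` builder added above: the calls below follow the broadcasting and sampling conventions documented in `zuko_flow.py`; the dimensionalities are invented for illustration only.

    import torch

    from sbi.neural_nets.flow import build_zuko_maf

    batch_x = torch.randn(200, 3)  # 3-dimensional variable to model
    batch_y = torch.randn(200, 2)  # 2-dimensional condition

    flow = build_zuko_maf(batch_x, batch_y, num_transforms=2)

    # Matching batch sizes: one log prob per (input, condition) pair -> shape (200,).
    lp = flow.log_prob(batch_x, condition=batch_y)

    # Pairwise evaluation via broadcasting (see the Note in ZukoFlow.log_prob):
    # (5, 1, 3) inputs against (7, 2) conditions -> (5, 7) log probs.
    lp_pairwise = flow.log_prob(batch_x[:5, None, :], condition=batch_y[:7])

    # Sampling with a batch of conditions returns
    # (*batch_shape, *sample_shape, input_size), here (7, 50, 3).
    samples = flow.sample((50,), condition=batch_y[:7])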
@@ -382,7 +382,7 @@ def log_prob_iid(self, x: Tensor, theta: Tensor) -> Tensor: # Get repeat discrete data and theta to match in batch shape for flow eval. log_probs_cont = self.continuous_net.log_prob( torch.log(x_cont_repeated) if self.log_transform_x else x_cont_repeated, - context=torch.cat((theta_repeated, x_disc_repeated), dim=1), + condition=torch.cat((theta_repeated, x_disc_repeated), dim=1), ) # Combine into joint lp with first dim over trials. diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 18db941b9..5c78cd8dd 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -257,35 +257,31 @@ def accept_reject_sample( num_sampled_total, num_remaining = 0, num_samples accepted, acceptance_rate = [], float("Nan") leakage_warning_raised = False + # Ruff suggestion + if proposal_sampling_kwargs is None: + proposal_sampling_kwargs = {} # To cover cases with few samples without leakage: sampling_batch_size = min(num_samples, max_sampling_batch_size) while num_remaining > 0: # Sample and reject. - # This if-case is annoying, but it will be gone when we move away from - # nflows and towards a flows-framework which takes a tuple as sample_size. - if isinstance(proposal, nn.Module): - candidates = proposal.sample( - sampling_batch_size, - **proposal_sampling_kwargs or {}, # type: ignore - ).reshape(sampling_batch_size, -1) - else: - candidates = proposal.sample( - torch.Size((sampling_batch_size,)), - **proposal_sampling_kwargs or {}, # type: ignore - ) # type: ignore + candidates = proposal.sample( + (sampling_batch_size,), # type: ignore + **proposal_sampling_kwargs, + ) # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) - samples = candidates[are_accepted] - accepted.append(samples) # Update. + # Note: For any condition of shape (*batch_shape, *condition_shape), the + # samples will be of shape(*batch_shape, sampling_batch_size, d) and hence work + # in dim = -2. num_sampled_total += sampling_batch_size - num_remaining -= samples.shape[0] - pbar.update(samples.shape[0]) + num_remaining -= samples.shape[-2] + pbar.update(samples.shape[-2]) # To avoid endless sampling when leakage is high, we raise a warning if the # acceptance rate is too low after the first 1_000 samples. @@ -335,9 +331,9 @@ def accept_reject_sample( pbar.close() # When in case of leakage a batch size was used there could be too many samples. - samples = torch.cat(accepted)[:num_samples] + samples = torch.cat(accepted, dim=-2)[..., :num_samples, :] assert ( - samples.shape[0] == num_samples + samples.shape[-2] == num_samples ), "Number of accepted samples must match required samples." return samples, as_tensor(acceptance_rate) diff --git a/sbi/utils/conditional_density_utils.py b/sbi/utils/conditional_density_utils.py index 0dbc58271..5f8df5d38 100644 --- a/sbi/utils/conditional_density_utils.py +++ b/sbi/utils/conditional_density_utils.py @@ -294,12 +294,16 @@ def __init__( self.device = self.potential_fn.device self.allow_iid_x = allow_iid_x - def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: + def __call__( + self, theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True + ) -> Tensor: r""" Returns the conditional potential $\log(p(\theta_i|\theta_j, x))$. Args: theta: Free parameters $\theta_i$, batch dimension 1. + x_o: Unused keyword argument. Only present to match the signature of the + `Potential` class. 
track_gradients: Whether to track gradients. Returns: diff --git a/sbi/utils/get_nn_models.py b/sbi/utils/get_nn_models.py index 71938df37..0bd96b6d1 100644 --- a/sbi/utils/get_nn_models.py +++ b/sbi/utils/get_nn_models.py @@ -11,7 +11,13 @@ build_mlp_classifier, build_resnet_classifier, ) -from sbi.neural_nets.flow import build_made, build_maf, build_maf_rqs, build_nsf +from sbi.neural_nets.flow import ( + build_made, + build_maf, + build_maf_rqs, + build_nsf, + build_zuko_maf, +) from sbi.neural_nets.mdn import build_mdn from sbi.neural_nets.mnle import build_mnle @@ -168,6 +174,8 @@ def build_fn(batch_theta, batch_x): return build_nsf(batch_x=batch_x, batch_y=batch_theta, **kwargs) elif model == "mnle": return build_mnle(batch_x=batch_x, batch_y=batch_theta, **kwargs) + elif model == "zuko_maf": + return build_zuko_maf(batch_x=batch_x, batch_y=batch_theta, **kwargs) else: raise NotImplementedError @@ -267,6 +275,8 @@ def build_fn(batch_theta, batch_x): return build_maf_rqs(batch_x=batch_theta, batch_y=batch_x, **kwargs) elif model == "nsf": return build_nsf(batch_x=batch_theta, batch_y=batch_x, **kwargs) + elif model == "zuko_maf": + return build_zuko_maf(batch_x=batch_theta, batch_y=batch_x, **kwargs) else: raise NotImplementedError diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index fb4b26b9b..4f9bcd980 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -11,6 +11,7 @@ import pyknos.nflows.transforms as transforms import torch import torch.distributions.transforms as torch_tf +import zuko from pyro.distributions import Empirical from torch import Tensor, ones, optim, zeros from torch import nn as nn @@ -136,8 +137,11 @@ def z_score_parser(z_score_flag: Optional["str"]) -> Tuple[bool, bool]: def standardizing_transform( - batch_t: Tensor, structured_dims: bool = False, min_std: float = 1e-14 -) -> transforms.AffineTransform: + batch_t: Tensor, + structured_dims: bool = False, + min_std: float = 1e-14, + backend: str = "nflows", +) -> Union[transforms.AffineTransform, zuko.transforms.MonotonicAffineTransform]: """Builds standardizing transform Args: @@ -171,7 +175,18 @@ def standardizing_transform( t_std = torch.std(batch_t[is_valid_t], dim=0) t_std[t_std < min_std] = min_std - return transforms.AffineTransform(shift=-t_mean / t_std, scale=1 / t_std) + if backend == "nflows": + return transforms.AffineTransform(shift=-t_mean / t_std, scale=1 / t_std) + elif backend == "zuko": + return zuko.flows.Unconditional( + zuko.transforms.MonotonicAffineTransform, + shift=-t_mean / t_std, + scale=1 / t_std, + buffer=True, + ) + + else: + raise ValueError("Invalid backend. Use 'nflows' or 'zuko'.") class Standardize(nn.Module): @@ -556,15 +571,8 @@ def within_support(distribution: Any, samples: Tensor) -> Tensor: # Try to check using the support property, use log prob method otherwise. try: sample_check = distribution.support.check(samples) + return sample_check - # Before torch v1.7.0, `support.check()` returned bools for every element. - # From v1.8.0 on, it directly considers all dimensions of a sample. E.g., - # for a single sample in 3D, v1.7.0 would return [[True, True, True]] and - # v1.8.0 would return [True]. - if sample_check.ndim > 1: - return torch.all(sample_check, dim=1) - else: - return sample_check # Falling back to log prob method of either the NeuralPosterior's net, or of a # custom wrapper distribution's. 
except (NotImplementedError, AttributeError): diff --git a/sbi/utils/user_input_checks_utils.py b/sbi/utils/user_input_checks_utils.py index 67bdd11a4..7f8acbd6e 100644 --- a/sbi/utils/user_input_checks_utils.py +++ b/sbi/utils/user_input_checks_utils.py @@ -368,7 +368,7 @@ def support(self): # Wrap as `independent` in order to have the correct shape of the # `log_abs_det`, i.e. summed over the parameter dimensions. return constraints.independent( - constraints.cat(supports, dim=1, lengths=self.dims_per_dist), + constraints.cat(supports, dim=-1, lengths=self.dims_per_dist), reinterpreted_batch_ndims=1, ) diff --git a/tests/density_estimator_test.py b/tests/density_estimator_test.py index 86336b238..4d38ebb03 100644 --- a/tests/density_estimator_test.py +++ b/tests/density_estimator_test.py @@ -8,11 +8,11 @@ from torch import eye, zeros from torch.distributions import MultivariateNormal -from sbi.neural_nets.density_estimators.flow import NFlowsFlow -from sbi.neural_nets.flow import build_nsf +from sbi.neural_nets.density_estimators import NFlowsFlow, ZukoFlow +from sbi.neural_nets.flow import build_nsf, build_zuko_maf -@pytest.mark.parametrize("density_estimator", (NFlowsFlow,)) +@pytest.mark.parametrize("density_estimator", (NFlowsFlow, ZukoFlow)) @pytest.mark.parametrize("input_dims", (1, 2)) @pytest.mark.parametrize( "condition_shape", ((1,), (2,), (1, 1), (2, 2), (1, 1, 1), (2, 2, 2)) @@ -44,14 +44,22 @@ def forward(self, x): x = torch.sum(x, dim=-1) return x - net = build_nsf( - batch_input, - batch_context, - hidden_features=10, - num_transforms=2, - embedding_net=EmbeddingNet(), - ) - estimator = density_estimator(net, condition_shape) + if density_estimator == NFlowsFlow: + estimator = build_nsf( + batch_input, + batch_context, + hidden_features=10, + num_transforms=2, + embedding_net=EmbeddingNet(), + ) + elif density_estimator == ZukoFlow: + estimator = build_zuko_maf( + batch_input, + batch_context, + hidden_features=10, + num_transforms=2, + embedding_net=EmbeddingNet(), + ) # Loss is only required to work for batched inputs and contexts loss = estimator.loss(batch_input, batch_context) diff --git a/tests/embedding_net_test.py b/tests/embedding_net_test.py index 8cf97cf5f..5b05deb65 100644 --- a/tests/embedding_net_test.py +++ b/tests/embedding_net_test.py @@ -94,6 +94,7 @@ def test_embedding_api_with_multiple_trials(num_trials, num_dim): inference = SNPE(prior, density_estimator=density_estimator) _ = inference.append_simulations(theta, x).train(max_num_epochs=5) + posterior = inference.build_posterior().set_default_x(x_o) s = posterior.sample((1,)) diff --git a/tests/linearGaussian_snle_test.py b/tests/linearGaussian_snle_test.py index 92f440e32..af1fc25f6 100644 --- a/tests/linearGaussian_snle_test.py +++ b/tests/linearGaussian_snle_test.py @@ -139,7 +139,10 @@ def test_c2st_snl_on_linear_gaussian_different_dims(model_str="maf"): @pytest.mark.slow @pytest.mark.parametrize("num_dim", (1, 2)) @pytest.mark.parametrize("prior_str", ("uniform", "gaussian")) -def test_c2st_and_map_snl_on_linearGaussian_different(num_dim: int, prior_str: str): +@pytest.mark.parametrize("model_str", ("maf", "zuko_maf")) +def test_c2st_and_map_snl_on_linearGaussian_different( + num_dim: int, prior_str: str, model_str: str +): """Test SNL on linear Gaussian, comparing to ground truth posterior via c2st. 
Args: @@ -167,7 +170,7 @@ def test_c2st_and_map_snl_on_linearGaussian_different(num_dim: int, prior_str: s lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior, ) - density_estimator = likelihood_nn("maf", num_transforms=3) + density_estimator = likelihood_nn(model_str, num_transforms=3) inference = SNLE(density_estimator=density_estimator, show_progress_bars=False) theta, x = simulate_for_sbi( diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py index 527c41265..f906e2d48 100644 --- a/tests/linearGaussian_snpe_test.py +++ b/tests/linearGaussian_snpe_test.py @@ -147,8 +147,10 @@ def test_c2st_snpe_on_linearGaussian(snpe_method, num_dim: int, prior_str: str): @pytest.mark.slow -@pytest.mark.parametrize("density_estimtor", ["mdn", "maf", "maf_rqs", "nsf"]) -def test_density_estimators_on_linearGaussian(density_estimtor): +@pytest.mark.parametrize( + "density_estimator", ["mdn", "maf", "maf_rqs", "nsf", "zuko_maf"] +) +def test_density_estimators_on_linearGaussian(density_estimator): """Test SNPE with different density estimators on linear Gaussian example.""" theta_dim = 4 @@ -176,7 +178,7 @@ def test_density_estimators_on_linearGaussian(density_estimtor): prior, ) - inference = SNPE_C(prior, density_estimator=density_estimtor) + inference = SNPE_C(prior, density_estimator=density_estimator) theta, x = simulate_for_sbi( simulator, prior, num_simulations, simulation_batch_size=1000 @@ -190,7 +192,7 @@ def test_density_estimators_on_linearGaussian(density_estimtor): samples = posterior.sample((num_samples,)) # Compute the c2st and assert it is near chance level of 0.5. - check_c2st(samples, target_samples, alg=f"snpe_{density_estimtor}") + check_c2st(samples, target_samples, alg=f"snpe_{density_estimator}") def test_c2st_snpe_on_linearGaussian_different_dims(density_estimator="maf"): @@ -573,7 +575,8 @@ def simulator(theta): proposal=restricted_prior, **mcmc_parameters, ) - cond_samples = mcmc_posterior.sample((num_conditional_samples,), x=x_o) + mcmc_posterior.set_default_x(x_o) # TODO: This test has a bug? Needed to add this + cond_samples = mcmc_posterior.sample((num_conditional_samples,)) _ = analysis.pairplot( cond_samples,
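Finally, a short end-to-end sketch of how the new `zuko_maf` option registered in `get_nn_models.py` is selected from the inference API, mirroring the new test cases above. The toy Gaussian simulator, simulation budget, and training settings are illustrative assumptions, not taken from this patch.

    import torch

    from sbi.inference import SNPE
    from sbi.utils import BoxUniform
    from sbi.utils.get_nn_models import posterior_nn

    prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
    theta = prior.sample((500,))
    x = theta + 0.1 * torch.randn_like(theta)  # stand-in simulator

    # "zuko_maf" is dispatched to build_zuko_maf by posterior_nn (see above).
    density_estimator_build_fn = posterior_nn(model="zuko_maf", num_transforms=3)

    inference = SNPE(prior=prior, density_estimator=density_estimator_build_fn)
    density_estimator = inference.append_simulations(theta, x).train(max_num_epochs=20)

    posterior = inference.build_posterior(density_estimator)
    posterior_samples = posterior.sample((100,), x=x[:1])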