From 544044606c67b37127eaa22dde83dd692bd758ad Mon Sep 17 00:00:00 2001 From: Guy Moss Date: Thu, 7 Mar 2024 15:07:27 +0100 Subject: [PATCH 01/35] build functions return density estimators --- sbi/neural_nets/flow.py | 22 +++++++++++++--------- sbi/neural_nets/mdn.py | 7 ++++--- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index 49da8af77..69e1814f9 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -19,7 +19,7 @@ ) from sbi.utils.torchutils import create_alternating_binary_mask from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device - +from sbi.neural_nets.density_estimators import NFlowsFlow def build_made( batch_x: Tensor, @@ -30,7 +30,7 @@ def build_made( num_mixture_components: int = 10, embedding_net: nn.Module = nn.Identity(), **kwargs, -) -> nn.Module: +) -> NFlowsFlow: """Builds MADE p(x|y). Args: @@ -90,8 +90,9 @@ def build_made( ) neural_net = flows.Flow(transform, distribution, embedding_net) + flow = NFlowsFlow(neural_net,condition_shape = embedding_net(batch_y[0]).shape) - return neural_net + return flow def build_maf( @@ -106,7 +107,7 @@ def build_maf( dropout_probability: float = 0.0, use_batch_norm: bool = False, **kwargs, -) -> nn.Module: +) -> NFlowsFlow: """Builds MAF p(x|y). Args: @@ -176,8 +177,9 @@ def build_maf( distribution = distributions_.StandardNormal((x_numel,)) neural_net = flows.Flow(transform, distribution, embedding_net) + flow = NFlowsFlow(neural_net,condition_shape = embedding_net(batch_y[0]).shape) - return neural_net + return flow def build_maf_rqs( @@ -198,7 +200,7 @@ def build_maf_rqs( min_bin_height: float = rational_quadratic.DEFAULT_MIN_BIN_HEIGHT, min_derivative: float = rational_quadratic.DEFAULT_MIN_DERIVATIVE, **kwargs, -) -> nn.Module: +) -> NFlowsFlow: """Builds MAF p(x|y), where the diffeomorphisms are rational-quadratic splines (RQS). @@ -286,8 +288,9 @@ def build_maf_rqs( distribution = distributions_.StandardNormal((x_numel,)) neural_net = flows.Flow(transform, distribution, embedding_net) + flow = NFlowsFlow(neural_net,condition_shape = embedding_net(batch_y[0]).shape) - return neural_net + return flow def build_nsf( @@ -305,7 +308,7 @@ def build_nsf( dropout_probability: float = 0.0, use_batch_norm: bool = False, **kwargs, -) -> nn.Module: +) -> NFlowsFlow: """Builds NSF p(x|y). Args: @@ -407,8 +410,9 @@ def mask_in_layer(i): # Combine transforms. transform = transforms.CompositeTransform(transform_list) neural_net = flows.Flow(transform, distribution, embedding_net) + flow = NFlowsFlow(neural_net,condition_shape = embedding_net(batch_y[0]).shape) - return neural_net + return flow class ContextSplineMap(nn.Module): diff --git a/sbi/neural_nets/mdn.py b/sbi/neural_nets/mdn.py index 4e4ef167f..f5e1c4d24 100644 --- a/sbi/neural_nets/mdn.py +++ b/sbi/neural_nets/mdn.py @@ -9,7 +9,7 @@ import sbi.utils as utils from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device - +from sbi.neural_nets.density_estimators import NFlowsFlow def build_mdn( batch_x: Tensor, @@ -20,7 +20,7 @@ def build_mdn( num_components: int = 10, embedding_net: nn.Module = nn.Identity(), **kwargs, -) -> nn.Module: +) -> NFlowsFlow: """Builds MDN p(x|y). 
Args: @@ -80,5 +80,6 @@ def build_mdn( ) neural_net = flows.Flow(transform, distribution, embedding_net) + flow = NFlowsFlow(neural_net,condition_shape = embedding_net(batch_y[0]).shape) - return neural_net + return flow From fdd6c05c9ee96fc97d92decbd0255099119fb43a Mon Sep 17 00:00:00 2001 From: Guy Moss Date: Thu, 7 Mar 2024 15:08:14 +0100 Subject: [PATCH 02/35] snle_a uses DensityEstimator instead of nn.Module --- sbi/inference/base.py | 10 +++---- .../potentials/likelihood_based_potential.py | 23 ++++++++-------- sbi/inference/snle/mnle.py | 10 +++++++ sbi/inference/snle/snle_base.py | 27 ++++++++++--------- 4 files changed, 41 insertions(+), 29 deletions(-) diff --git a/sbi/inference/base.py b/sbi/inference/base.py index a7353465d..d7cc01506 100644 --- a/sbi/inference/base.py +++ b/sbi/inference/base.py @@ -122,7 +122,7 @@ def __init__( self._prior = prior self._posterior = None - self._neural_net = None + self._density_estimator = None self._x_shape = None self._show_progress_bars = show_progress_bars @@ -349,20 +349,20 @@ def _converged(self, epoch: int, stop_after_epochs: int) -> bool: """ converged = False - assert self._neural_net is not None - neural_net = self._neural_net + assert self._density_estimator is not None + density_estimator = self._density_estimator # (Re)-start the epoch count with the first epoch or any improvement. if epoch == 0 or self._val_log_prob > self._best_val_log_prob: self._best_val_log_prob = self._val_log_prob self._epochs_since_last_improvement = 0 - self._best_model_state_dict = deepcopy(neural_net.state_dict()) + self._best_model_state_dict = deepcopy(density_estimator.state_dict()) else: self._epochs_since_last_improvement += 1 # If no validation improvement over many epochs, stop training. if self._epochs_since_last_improvement > stop_after_epochs - 1: - neural_net.load_state_dict(self._best_model_state_dict) + density_estimator.load_state_dict(self._best_model_state_dict) converged = True return converged diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index d3aa5e51d..d337ee98f 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -13,10 +13,11 @@ from sbi.utils import mcmc_transform from sbi.utils.sbiutils import match_theta_and_x_batch_shapes from sbi.utils.torchutils import atleast_2d +from sbi.neural_nets.density_estimators import DensityEstimator def likelihood_estimator_based_potential( - likelihood_estimator: nn.Module, + likelihood_estimator: DensityEstimator, prior: Distribution, x_o: Optional[Tensor], enable_transform: bool = True, @@ -27,7 +28,7 @@ def likelihood_estimator_based_potential( unconstrained space. Args: - likelihood_estimator: The neural network modelling the likelihood. + likelihood_estimator: The density estimator modelling the likelihood. prior: The prior distribution. x_o: The observed data at which to evaluate the likelihood. enable_transform: Whether to transform parameters to unconstrained space. @@ -55,7 +56,7 @@ class LikelihoodBasedPotential(BasePotential): def __init__( self, - likelihood_estimator: nn.Module, + likelihood_estimator: DensityEstimator, prior: Distribution, x_o: Optional[Tensor], device: str = "cpu", @@ -63,7 +64,7 @@ def __init__( r"""Returns the potential function for likelihood-based methods. Args: - likelihood_estimator: The neural network modelling the likelihood. + likelihood_estimator: The density estimator modelling the likelihood. 
prior: The prior distribution.
            x_o: The observed data at which to evaluate the likelihood.
            device: The device to which parameters and data are moved before evaluating
@@ -92,7 +93,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
         log_likelihood_trial_sum = _log_likelihoods_over_trials(
             x=self.x_o,
             theta=theta.to(self.device),
-            net=self.likelihood_estimator,
+            estimator=self.likelihood_estimator,
             track_gradients=track_gradients,
         )
 
@@ -100,7 +101,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
 
 
 def _log_likelihoods_over_trials(
-    x: Tensor, theta: Tensor, net: Any, track_gradients: bool = False
+    x: Tensor, theta: Tensor, estimator: DensityEstimator, track_gradients: bool = False
 ) -> Tensor:
     r"""Return log likelihoods summed over iid trials of `x`.
 
@@ -112,8 +113,8 @@ def _log_likelihoods_over_trials(
 
     Args:
         x: batch of iid data.
-        theta: batch of parameters
-        net: neural net with .log_prob()
+        theta: batch of parameters.
+        estimator: DensityEstimator.
        track_gradients: Whether to track gradients.
 
     Returns:
@@ -131,13 +132,13 @@ def _log_likelihoods_over_trials(
         x_repeated.shape[0] == theta_repeated.shape[0]
     ), "x and theta must match in batch shape."
     assert (
-        next(net.parameters()).device == x.device and x.device == theta.device
-    ), f"""device mismatch: net, x, theta: {next(net.parameters()).device}, {x.device},
+        next(estimator.parameters()).device == x.device and x.device == theta.device
+    ), f"""device mismatch: estimator, x, theta: {next(estimator.parameters()).device}, {x.device},
        {theta.device}."""
 
     # Calculate likelihood in one batch.
     with torch.set_grad_enabled(track_gradients):
-        log_likelihood_trial_batch = net.log_prob(x_repeated, theta_repeated)
+        log_likelihood_trial_batch = estimator.log_prob(x_repeated, theta_repeated)
         # Reshape to (x-trials x parameters), sum over trial-log likelihoods.
         log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
             x.shape[0], -1
diff --git a/sbi/inference/snle/mnle.py b/sbi/inference/snle/mnle.py
index 33b4ffe1e..57f0906a2 100644
--- a/sbi/inference/snle/mnle.py
+++ b/sbi/inference/snle/mnle.py
@@ -6,6 +6,7 @@
 from typing import Any, Callable, Dict, Optional, Union
 
 from torch.distributions import Distribution
+from torch import Tensor, nn, optim
 
 from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
 from sbi.inference.potentials import mixed_likelihood_estimator_based_potential
@@ -193,3 +194,12 @@ def build_posterior(
 
         self._model_bank.append(deepcopy(self._posterior))
         return deepcopy(self._posterior)
+
+    #Temporary: need to rewrite mixed likelihood estimators as DensityEstimator objects.
+    def _loss(self, theta: Tensor, x: Tensor) -> Tensor:
+        r"""Return loss for SNLE, i.e. the negative log-likelihood $-\log q(x_i | \theta_i)$.
+
+        Returns:
+            Negative log prob.
+ """ + return -self._density_estimator.log_prob(x, context=theta) diff --git a/sbi/inference/snle/snle_base.py b/sbi/inference/snle/snle_base.py index f4a27a6d0..4f44bea00 100644 --- a/sbi/inference/snle/snle_base.py +++ b/sbi/inference/snle/snle_base.py @@ -18,6 +18,7 @@ from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior from sbi.inference.potentials import likelihood_estimator_based_potential from sbi.utils import check_estimator_arg, check_prior, x_shape_from_simulation +from sbi.neural_nets.density_estimators import DensityEstimator class LikelihoodEstimator(NeuralInference, ABC): @@ -126,7 +127,7 @@ def train( retrain_from_scratch: bool = False, show_train_summary: bool = False, dataloader_kwargs: Optional[Dict] = None, - ) -> flows.Flow: + ) -> DensityEstimator: r"""Train the density estimator to learn the distribution $p(x|\theta)$. Args: @@ -165,11 +166,11 @@ def train( # arguments, which will build the neural network # This is passed into NeuralPosterior, to create a neural posterior which # can `sample()` and `log_prob()`. The network is accessible via `.net`. - if self._neural_net is None or retrain_from_scratch: + if self._density_estimator is None or retrain_from_scratch: # Get theta,x to initialize NN theta, x, _ = self.get_simulations(starting_round=start_idx) # Use only training data for building the neural net (z-scoring transforms) - self._neural_net = self._build_neural_net( + self._density_estimator = self._build_neural_net( theta[self.train_indices].to("cpu"), x[self.train_indices].to("cpu"), ) @@ -179,10 +180,10 @@ def train( len(self._x_shape) < 3 ), "SNLE cannot handle multi-dimensional simulator output." - self._neural_net.to(self._device) + self._density_estimator.to(self._device) if not resume_training: self.optimizer = optim.Adam( - list(self._neural_net.parameters()), + list(self._density_estimator.parameters()), lr=learning_rate, ) self.epoch, self._val_log_prob = 0, float("-Inf") @@ -191,7 +192,7 @@ def train( self.epoch, stop_after_epochs ): # Train for a single epoch. - self._neural_net.train() + self._density_estimator.train() train_log_probs_sum = 0 for batch in train_loader: self.optimizer.zero_grad() @@ -207,7 +208,7 @@ def train( train_loss.backward() if clip_max_norm is not None: clip_grad_norm_( - self._neural_net.parameters(), + self._density_estimator.parameters(), max_norm=clip_max_norm, ) self.optimizer.step() @@ -220,7 +221,7 @@ def train( self._summary["training_log_probs"].append(train_log_prob_average) # Calculate validation performance. - self._neural_net.eval() + self._density_estimator.eval() val_log_prob_sum = 0 with torch.no_grad(): for batch in val_loader: @@ -256,13 +257,13 @@ def train( # Avoid keeping the gradients in the resulting network, which can # cause memory leakage when benchmarking. - self._neural_net.zero_grad(set_to_none=True) + self._density_estimator.zero_grad(set_to_none=True) - return deepcopy(self._neural_net) + return deepcopy(self._density_estimator) def build_posterior( self, - density_estimator: Optional[nn.Module] = None, + density_estimator: Optional[DensityEstimator] = None, prior: Optional[Distribution] = None, sample_with: str = "mcmc", mcmc_method: str = "slice_np", @@ -311,7 +312,7 @@ def build_posterior( check_prior(prior) if density_estimator is None: - likelihood_estimator = self._neural_net + likelihood_estimator = self._density_estimator # If internal net is used device is defined. 
device = self._device else: @@ -367,4 +368,4 @@ def _loss(self, theta: Tensor, x: Tensor) -> Tensor: Returns: Negative log prob. """ - return -self._neural_net.log_prob(x, context=theta) + return self._density_estimator.loss(x, condition=theta) From d8c126c645ced9512f2328d4be1192db535aa1a0 Mon Sep 17 00:00:00 2001 From: Guy Moss Date: Thu, 7 Mar 2024 18:47:02 +0100 Subject: [PATCH 03/35] build_nn context shape should be of unembedded context --- sbi/neural_nets/flow.py | 8 ++++---- sbi/neural_nets/mdn.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index 69e1814f9..5d175d83d 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -90,7 +90,7 @@ def build_made( ) neural_net = flows.Flow(transform, distribution, embedding_net) - flow = NFlowsFlow(neural_net,condition_shape = embedding_net(batch_y[0]).shape) + flow = NFlowsFlow(neural_net,condition_shape = batch_y[0].shape) return flow @@ -177,7 +177,7 @@ def build_maf( distribution = distributions_.StandardNormal((x_numel,)) neural_net = flows.Flow(transform, distribution, embedding_net) - flow = NFlowsFlow(neural_net,condition_shape = embedding_net(batch_y[0]).shape) + flow = NFlowsFlow(neural_net,condition_shape = batch_y[0].shape) return flow @@ -288,7 +288,7 @@ def build_maf_rqs( distribution = distributions_.StandardNormal((x_numel,)) neural_net = flows.Flow(transform, distribution, embedding_net) - flow = NFlowsFlow(neural_net,condition_shape = embedding_net(batch_y[0]).shape) + flow = NFlowsFlow(neural_net,condition_shape = batch_y[0].shape) return flow @@ -410,7 +410,7 @@ def mask_in_layer(i): # Combine transforms. transform = transforms.CompositeTransform(transform_list) neural_net = flows.Flow(transform, distribution, embedding_net) - flow = NFlowsFlow(neural_net,condition_shape = embedding_net(batch_y[0]).shape) + flow = NFlowsFlow(neural_net,condition_shape = batch_y[0].shape) return flow diff --git a/sbi/neural_nets/mdn.py b/sbi/neural_nets/mdn.py index f5e1c4d24..6d47df31f 100644 --- a/sbi/neural_nets/mdn.py +++ b/sbi/neural_nets/mdn.py @@ -80,6 +80,6 @@ def build_mdn( ) neural_net = flows.Flow(transform, distribution, embedding_net) - flow = NFlowsFlow(neural_net,condition_shape = embedding_net(batch_y[0]).shape) + flow = NFlowsFlow(neural_net,condition_shape = batch_y[0].shape) return flow From 537737cccdfa95a7c5d4053834de4dbd52b6cfc3 Mon Sep 17 00:00:00 2001 From: Guy Moss Date: Thu, 7 Mar 2024 18:47:57 +0100 Subject: [PATCH 04/35] zukoFlow density estimator and some tests --- pyproject.toml | 1 + .../density_estimators/__init__.py | 3 +- .../{flow.py => nflows_flow.py} | 2 +- .../density_estimators/zuko_flow.py | 138 ++++++++++++++++++ tests/density_estimator_test.py | 27 ++-- 5 files changed, 159 insertions(+), 12 deletions(-) rename sbi/neural_nets/density_estimators/{flow.py => nflows_flow.py} (99%) create mode 100644 sbi/neural_nets/density_estimators/zuko_flow.py diff --git a/pyproject.toml b/pyproject.toml index ca75649f4..ea6a3b1ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ "tensorboard", "torch>=1.8.0", "tqdm", + "zuko>=1.0.0", ] [project.optional-dependencies] diff --git a/sbi/neural_nets/density_estimators/__init__.py b/sbi/neural_nets/density_estimators/__init__.py index 5069397e8..0cc2aa068 100644 --- a/sbi/neural_nets/density_estimators/__init__.py +++ b/sbi/neural_nets/density_estimators/__init__.py @@ -1,2 +1,3 @@ from sbi.neural_nets.density_estimators.base import DensityEstimator 
-from sbi.neural_nets.density_estimators.flow import NFlowsFlow
+from sbi.neural_nets.density_estimators.nflows_flow import NFlowsFlow
+from sbi.neural_nets.density_estimators.zuko_flow import ZukoFlow
diff --git a/sbi/neural_nets/density_estimators/flow.py b/sbi/neural_nets/density_estimators/nflows_flow.py
similarity index 99%
rename from sbi/neural_nets/density_estimators/flow.py
rename to sbi/neural_nets/density_estimators/nflows_flow.py
index 9ab94f2a9..2a2fb0a53 100644
--- a/sbi/neural_nets/density_estimators/flow.py
+++ b/sbi/neural_nets/density_estimators/nflows_flow.py
@@ -148,4 +148,4 @@ def sample_and_log_prob(
         samples = samples.reshape((*batch_shape, *sample_shape, -1))
         log_probs = log_probs.reshape((*batch_shape, *sample_shape))
 
-        return samples, log_probs
+        return samples, log_probs
\ No newline at end of file
diff --git a/sbi/neural_nets/density_estimators/zuko_flow.py b/sbi/neural_nets/density_estimators/zuko_flow.py
new file mode 100644
index 000000000..011ccae26
--- /dev/null
+++ b/sbi/neural_nets/density_estimators/zuko_flow.py
@@ -0,0 +1,138 @@
+from typing import Tuple
+
+import torch
+from zuko.flows import Flow
+from torch import Tensor
+
+from sbi.neural_nets.density_estimators.base import DensityEstimator
+from sbi.types import Shape
+
+class ZukoFlow(DensityEstimator):
+    r"""`zuko`-based normalizing flow density estimator.
+
+    Flow type objects already have a .log_prob() and .sample() method, so here we just
+    wrap them and add the .loss() method.
+    """
+    def __init__(self, net: Flow, condition_shape: torch.Size):
+        r"""Initialize the density estimator.
+
+        Args:
+            net: Flow object.
+            condition_shape: Shape of the condition.
+        """
+
+        assert len(condition_shape) == 1, "Zuko Flows require 1D conditions."
+        super().__init__(net=net, condition_shape=condition_shape)
+
+
+    def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
+        r"""Return the log probabilities of the inputs given a condition or multiple,
+        i.e. batched, conditions.
+
+        Args:
+            input: Inputs to evaluate the log probability on of shape
+                    (*batch_shape1, input_size).
+            condition: Conditions of shape (*batch_shape2, *condition_shape).
+
+        Raises:
+            RuntimeError: If batch_shape1 and batch_shape2 are not broadcastable.
+
+        Returns:
+            Sample-wise log probabilities.
+
+        Note:
+            This function should support PyTorch's automatic broadcasting. This means
+            the function should behave as follows for different input and condition
+            shapes:
+            - (input_size,) + (batch_size,*condition_shape) -> (batch_size,)
+            - (batch_size, input_size) + (*condition_shape) -> (batch_size,)
+            - (batch_size, input_size) + (batch_size, *condition_shape) -> (batch_size,)
+            - (batch_size1, input_size) + (batch_size2, *condition_shape)
+                                                  -> RuntimeError i.e.
not broadcastable + - (batch_size1,1, input_size) + (batch_size2, *condition_shape) + -> (batch_size1,batch_size2) + - (batch_size1, input_size) + (batch_size2,1, *condition_shape) + -> (batch_size2,batch_size1) + """ + self._check_condition_shape(condition) + condition_dims = len(self._condition_shape) + + # PyTorch's automatic broadcasting + batch_shape_in = input.shape[:-1] + batch_shape_cond = condition.shape[:-condition_dims] + batch_shape = torch.broadcast_shapes(batch_shape_in, batch_shape_cond) + # Expand the input and condition to the same batch shape + input = input.expand(batch_shape + (input.shape[-1],)) + condition = condition.expand(batch_shape + self._condition_shape) + + dists = self.net(condition) + log_probs = dists.log_prob(input) + + return log_probs + + def loss(self, input: Tensor, condition: Tensor) -> Tensor: + r"""Return the loss for training the density estimator. + + Args: + input: Inputs to evaluate the loss on of shape (batch_size, input_size). + condition: Conditions of shape (batch_size, *condition_shape). + + Returns: + Negative log_probability (batch_size,) + """ + + return -self.log_prob(input, condition) + + def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor: + r"""Return samples from the density estimator. + + Args: + sample_shape: Shape of the samples to return. + condition: Conditions of shape (*batch_shape, *condition_shape). + + Returns: + Samples of shape (*batch_shape, *sample_shape, input_size). + + Note: + This function should support batched conditions and should admit the + following behavior for different condition shapes: + - (*condition_shape) -> (*sample_shape, input_size) + - (*batch_shape, *condition_shape) + -> (*batch_shape, *sample_shape, input_size) + """ + self._check_condition_shape(condition) + + + condition_dims = len(self._condition_shape) + batch_shape = condition.shape[:-condition_dims] if condition_dims > 0 else () + + dists = self.net(condition) + # zuko.sample() returns (*sample_shape, *batch_shape, input_size). + samples = dists.sample(sample_shape).reshape(*batch_shape, *sample_shape, -1) + + return samples + + def sample_and_log_prob( + self, sample_shape: torch.Size, condition: Tensor, **kwargs + ) -> Tuple[Tensor, Tensor]: + r"""Return samples and their density from the density estimator. + + Args: + sample_shape: Shape of the samples to return. + condition: Conditions of shape (*batch_shape, *condition_shape). + + Returns: + Samples and associated log probabilities. + """ + condition_dims = len(self._condition_shape) + batch_shape = condition.shape[:-condition_dims] if condition_dims > 0 else () + + dists = self.net(condition) + samples,log_probs = dists.rsample_and_log_prob(sample_shape) + # zuko.sample_and_log_prob() returns (*sample_shape, *batch_shape, ...). 
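+        # The reshape below re-labels these axes into sbi's
+        # (*batch_shape, *sample_shape, ...) convention. Since reshape does not
+        # transpose, the two layouts coincide exactly only when one of the two
+        # shapes has a single element, e.g. sample_shape=(100,) with a single
+        # condition of shape (1, C) yields samples of shape (1, 100, input_size).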
+ + samples = samples.reshape(*batch_shape, *sample_shape, -1) + log_probs = log_probs.reshape(*batch_shape, *sample_shape) + + + return samples, log_probs diff --git a/tests/density_estimator_test.py b/tests/density_estimator_test.py index 35e877f7a..d82fd30b6 100644 --- a/tests/density_estimator_test.py +++ b/tests/density_estimator_test.py @@ -8,11 +8,13 @@ from torch import eye, zeros from torch.distributions import MultivariateNormal -from sbi.neural_nets.density_estimators.flow import NFlowsFlow +from sbi.neural_nets.density_estimators import NFlowsFlow,ZukoFlow from sbi.neural_nets.flow import build_nsf +from zuko.flows import NSF -@pytest.mark.parametrize("density_estimator", (NFlowsFlow,)) + +@pytest.mark.parametrize("density_estimator", (NFlowsFlow,ZukoFlow)) @pytest.mark.parametrize("input_dims", (1, 2)) @pytest.mark.parametrize( "condition_shape", ((1,), (2,), (1, 1), (2, 2), (1, 1, 1), (2, 2, 2)) @@ -44,14 +46,19 @@ def forward(self, x): x = torch.sum(x, dim=-1) return x - net = build_nsf( - batch_input, - batch_context, - hidden_features=10, - num_transforms=2, - embedding_net=EmbeddingNet(), - ) - estimator = density_estimator(net, condition_shape) + if density_estimator == NFlowsFlow: + estimator = build_nsf( + batch_input, + batch_context, + hidden_features=10, + num_transforms=2, + embedding_net=EmbeddingNet(), + ) + elif density_estimator == ZukoFlow: + if len(condition_shape) > 1: + pytest.skip("ZukoFlow does not support multi-dimensional contexts.") + net = NSF(features=input_dims,context=condition_shape[-1],transforms=2,hidden_features=(10,),bins=8) + estimator = density_estimator(net, condition_shape) # Loss is only required to work for batched inputs and contexts loss = estimator.loss(batch_input, batch_context) From fc00b5b4bd6835c1920b8f7dc03555be66bc0bde Mon Sep 17 00:00:00 2001 From: Guy Moss Date: Thu, 7 Mar 2024 18:55:04 +0100 Subject: [PATCH 05/35] formatting --- .../potentials/likelihood_based_potential.py | 2 +- sbi/inference/snle/mnle.py | 11 ++++------- sbi/inference/snle/snle_base.py | 2 +- sbi/neural_nets/density_estimators/nflows_flow.py | 2 +- sbi/neural_nets/density_estimators/zuko_flow.py | 11 +++++------ sbi/neural_nets/flow.py | 11 ++++++----- sbi/neural_nets/mdn.py | 5 +++-- tests/density_estimator_test.py | 15 ++++++++++----- 8 files changed, 31 insertions(+), 28 deletions(-) diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index d337ee98f..7fb9196c7 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -8,12 +8,12 @@ from torch.distributions import Distribution from sbi.inference.potentials.base_potential import BasePotential +from sbi.neural_nets.density_estimators import DensityEstimator from sbi.neural_nets.mnle import MixedDensityEstimator from sbi.types import TorchTransform from sbi.utils import mcmc_transform from sbi.utils.sbiutils import match_theta_and_x_batch_shapes from sbi.utils.torchutils import atleast_2d -from sbi.neural_nets.density_estimators import DensityEstimator def likelihood_estimator_based_potential( diff --git a/sbi/inference/snle/mnle.py b/sbi/inference/snle/mnle.py index 57f0906a2..6a3c3940c 100644 --- a/sbi/inference/snle/mnle.py +++ b/sbi/inference/snle/mnle.py @@ -5,8 +5,8 @@ from copy import deepcopy from typing import Any, Callable, Dict, Optional, Union -from torch.distributions import Distribution from torch import Tensor, nn, optim +from torch.distributions 
import Distribution
 
 from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
 from sbi.inference.potentials import mixed_likelihood_estimator_based_potential
@@ -152,10 +152,7 @@ def build_posterior(
         ), f"""net must be of type MixedDensityEstimator but is {type
             (likelihood_estimator)}."""
 
-    (
-        potential_fn,
-        theta_transform,
-    ) = mixed_likelihood_estimator_based_potential(
+    (potential_fn, theta_transform,) = mixed_likelihood_estimator_based_potential(
         likelihood_estimator=likelihood_estimator, prior=prior, x_o=None
     )
 
@@ -194,8 +191,8 @@ def build_posterior(
 
         self._model_bank.append(deepcopy(self._posterior))
         return deepcopy(self._posterior)
-
-    #Temporary: need to rewrite mixed likelihood estimators as DensityEstimator objects.
+
+    # Temporary: need to rewrite mixed likelihood estimators as DensityEstimator objects.
     def _loss(self, theta: Tensor, x: Tensor) -> Tensor:
         r"""Return loss for SNLE, i.e. the negative log-likelihood $-\log q(x_i | \theta_i)$.
 
diff --git a/sbi/inference/snle/snle_base.py b/sbi/inference/snle/snle_base.py
index 4f44bea00..c79f923c3 100644
--- a/sbi/inference/snle/snle_base.py
+++ b/sbi/inference/snle/snle_base.py
@@ -17,8 +17,8 @@
 from sbi.inference import NeuralInference
 from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
 from sbi.inference.potentials import likelihood_estimator_based_potential
-from sbi.utils import check_estimator_arg, check_prior, x_shape_from_simulation
 from sbi.neural_nets.density_estimators import DensityEstimator
+from sbi.utils import check_estimator_arg, check_prior, x_shape_from_simulation
diff --git a/sbi/neural_nets/density_estimators/nflows_flow.py b/sbi/neural_nets/density_estimators/nflows_flow.py
index 2a2fb0a53..9ab94f2a9 100644
--- a/sbi/neural_nets/density_estimators/nflows_flow.py
+++ b/sbi/neural_nets/density_estimators/nflows_flow.py
@@ -148,4 +148,4 @@ def sample_and_log_prob(
         samples = samples.reshape((*batch_shape, *sample_shape, -1))
         log_probs = log_probs.reshape((*batch_shape, *sample_shape))
 
-        return samples, log_probs
\ No newline at end of file
+        return samples, log_probs
diff --git a/sbi/neural_nets/density_estimators/zuko_flow.py b/sbi/neural_nets/density_estimators/zuko_flow.py
index 011ccae26..bef9adc13 100644
--- a/sbi/neural_nets/density_estimators/zuko_flow.py
+++ b/sbi/neural_nets/density_estimators/zuko_flow.py
@@ -1,18 +1,20 @@
 from typing import Tuple
 
 import torch
-from zuko.flows import Flow
 from torch import Tensor
+from zuko.flows import Flow
 
 from sbi.neural_nets.density_estimators.base import DensityEstimator
 from sbi.types import Shape
 
+
 class ZukoFlow(DensityEstimator):
     r"""`zuko`-based normalizing flow density estimator.
 
     Flow type objects already have a .log_prob() and .sample() method, so here we just
     wrap them and add the .loss() method.
     """
+
     def __init__(self, net: Flow, condition_shape: torch.Size):
         r"""Initialize the density estimator.
 
         Args:
             net: Flow object.
             condition_shape: Shape of the condition.
         """
 
         assert len(condition_shape) == 1, "Zuko Flows require 1D conditions."
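         # Note: zuko flows condition on a flat feature vector, hence the 1D
         # requirement; higher-dimensional data should be reduced to a vector
         # by an embedding net before it reaches the flow.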
super().__init__(net=net, condition_shape=condition_shape) - def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: r"""Return the log probabilities of the inputs given a condition or multiple @@ -64,7 +65,7 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: # Expand the input and condition to the same batch shape input = input.expand(batch_shape + (input.shape[-1],)) condition = condition.expand(batch_shape + self._condition_shape) - + dists = self.net(condition) log_probs = dists.log_prob(input) @@ -102,7 +103,6 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor: """ self._check_condition_shape(condition) - condition_dims = len(self._condition_shape) batch_shape = condition.shape[:-condition_dims] if condition_dims > 0 else () @@ -128,11 +128,10 @@ def sample_and_log_prob( batch_shape = condition.shape[:-condition_dims] if condition_dims > 0 else () dists = self.net(condition) - samples,log_probs = dists.rsample_and_log_prob(sample_shape) + samples, log_probs = dists.rsample_and_log_prob(sample_shape) # zuko.sample_and_log_prob() returns (*sample_shape, *batch_shape, ...). samples = samples.reshape(*batch_shape, *sample_shape, -1) log_probs = log_probs.reshape(*batch_shape, *sample_shape) - return samples, log_probs diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index 5d175d83d..c53b2952e 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -12,6 +12,7 @@ from pyknos.nflows.transforms.splines import rational_quadratic from torch import Tensor, nn, relu, tanh, tensor, uint8 +from sbi.neural_nets.density_estimators import NFlowsFlow from sbi.utils.sbiutils import ( standardizing_net, standardizing_transform, @@ -19,7 +20,7 @@ ) from sbi.utils.torchutils import create_alternating_binary_mask from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device -from sbi.neural_nets.density_estimators import NFlowsFlow + def build_made( batch_x: Tensor, @@ -90,7 +91,7 @@ def build_made( ) neural_net = flows.Flow(transform, distribution, embedding_net) - flow = NFlowsFlow(neural_net,condition_shape = batch_y[0].shape) + flow = NFlowsFlow(neural_net, condition_shape=batch_y[0].shape) return flow @@ -177,7 +178,7 @@ def build_maf( distribution = distributions_.StandardNormal((x_numel,)) neural_net = flows.Flow(transform, distribution, embedding_net) - flow = NFlowsFlow(neural_net,condition_shape = batch_y[0].shape) + flow = NFlowsFlow(neural_net, condition_shape=batch_y[0].shape) return flow @@ -288,7 +289,7 @@ def build_maf_rqs( distribution = distributions_.StandardNormal((x_numel,)) neural_net = flows.Flow(transform, distribution, embedding_net) - flow = NFlowsFlow(neural_net,condition_shape = batch_y[0].shape) + flow = NFlowsFlow(neural_net, condition_shape=batch_y[0].shape) return flow @@ -410,7 +411,7 @@ def mask_in_layer(i): # Combine transforms. 
transform = transforms.CompositeTransform(transform_list) neural_net = flows.Flow(transform, distribution, embedding_net) - flow = NFlowsFlow(neural_net,condition_shape = batch_y[0].shape) + flow = NFlowsFlow(neural_net, condition_shape=batch_y[0].shape) return flow diff --git a/sbi/neural_nets/mdn.py b/sbi/neural_nets/mdn.py index 6d47df31f..9794edacc 100644 --- a/sbi/neural_nets/mdn.py +++ b/sbi/neural_nets/mdn.py @@ -8,8 +8,9 @@ from torch import Tensor, nn import sbi.utils as utils -from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device from sbi.neural_nets.density_estimators import NFlowsFlow +from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device + def build_mdn( batch_x: Tensor, @@ -80,6 +81,6 @@ def build_mdn( ) neural_net = flows.Flow(transform, distribution, embedding_net) - flow = NFlowsFlow(neural_net,condition_shape = batch_y[0].shape) + flow = NFlowsFlow(neural_net, condition_shape=batch_y[0].shape) return flow diff --git a/tests/density_estimator_test.py b/tests/density_estimator_test.py index d82fd30b6..d4d5b976b 100644 --- a/tests/density_estimator_test.py +++ b/tests/density_estimator_test.py @@ -7,14 +7,13 @@ import torch from torch import eye, zeros from torch.distributions import MultivariateNormal +from zuko.flows import NSF -from sbi.neural_nets.density_estimators import NFlowsFlow,ZukoFlow +from sbi.neural_nets.density_estimators import NFlowsFlow, ZukoFlow from sbi.neural_nets.flow import build_nsf -from zuko.flows import NSF - -@pytest.mark.parametrize("density_estimator", (NFlowsFlow,ZukoFlow)) +@pytest.mark.parametrize("density_estimator", (NFlowsFlow, ZukoFlow)) @pytest.mark.parametrize("input_dims", (1, 2)) @pytest.mark.parametrize( "condition_shape", ((1,), (2,), (1, 1), (2, 2), (1, 1, 1), (2, 2, 2)) @@ -57,7 +56,13 @@ def forward(self, x): elif density_estimator == ZukoFlow: if len(condition_shape) > 1: pytest.skip("ZukoFlow does not support multi-dimensional contexts.") - net = NSF(features=input_dims,context=condition_shape[-1],transforms=2,hidden_features=(10,),bins=8) + net = NSF( + features=input_dims, + context=condition_shape[-1], + transforms=2, + hidden_features=(10,), + bins=8, + ) estimator = density_estimator(net, condition_shape) # Loss is only required to work for batched inputs and contexts From c2c1b477f954d7441dbcf310ec816c0b69b7d799 Mon Sep 17 00:00:00 2001 From: Guy Moss Date: Fri, 8 Mar 2024 16:09:23 +0100 Subject: [PATCH 06/35] revert estimator of base NeuralInference back to self._neural_net --- sbi/inference/base.py | 10 +++++----- sbi/inference/snle/mnle.py | 2 +- sbi/inference/snle/snle_base.py | 22 +++++++++++----------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/sbi/inference/base.py b/sbi/inference/base.py index d7cc01506..a7353465d 100644 --- a/sbi/inference/base.py +++ b/sbi/inference/base.py @@ -122,7 +122,7 @@ def __init__( self._prior = prior self._posterior = None - self._density_estimator = None + self._neural_net = None self._x_shape = None self._show_progress_bars = show_progress_bars @@ -349,20 +349,20 @@ def _converged(self, epoch: int, stop_after_epochs: int) -> bool: """ converged = False - assert self._density_estimator is not None - density_estimator = self._density_estimator + assert self._neural_net is not None + neural_net = self._neural_net # (Re)-start the epoch count with the first epoch or any improvement. 
if epoch == 0 or self._val_log_prob > self._best_val_log_prob: self._best_val_log_prob = self._val_log_prob self._epochs_since_last_improvement = 0 - self._best_model_state_dict = deepcopy(density_estimator.state_dict()) + self._best_model_state_dict = deepcopy(neural_net.state_dict()) else: self._epochs_since_last_improvement += 1 # If no validation improvement over many epochs, stop training. if self._epochs_since_last_improvement > stop_after_epochs - 1: - density_estimator.load_state_dict(self._best_model_state_dict) + neural_net.load_state_dict(self._best_model_state_dict) converged = True return converged diff --git a/sbi/inference/snle/mnle.py b/sbi/inference/snle/mnle.py index 6a3c3940c..d67bc7921 100644 --- a/sbi/inference/snle/mnle.py +++ b/sbi/inference/snle/mnle.py @@ -199,4 +199,4 @@ def _loss(self, theta: Tensor, x: Tensor) -> Tensor: Returns: Negative log prob. """ - return -self._density_estimator.log_prob(x, context=theta) + return -self._neural_net.log_prob(x, context=theta) diff --git a/sbi/inference/snle/snle_base.py b/sbi/inference/snle/snle_base.py index c79f923c3..6c8dc1524 100644 --- a/sbi/inference/snle/snle_base.py +++ b/sbi/inference/snle/snle_base.py @@ -166,11 +166,11 @@ def train( # arguments, which will build the neural network # This is passed into NeuralPosterior, to create a neural posterior which # can `sample()` and `log_prob()`. The network is accessible via `.net`. - if self._density_estimator is None or retrain_from_scratch: + if self._neural_net is None or retrain_from_scratch: # Get theta,x to initialize NN theta, x, _ = self.get_simulations(starting_round=start_idx) # Use only training data for building the neural net (z-scoring transforms) - self._density_estimator = self._build_neural_net( + self._neural_net = self._build_neural_net( theta[self.train_indices].to("cpu"), x[self.train_indices].to("cpu"), ) @@ -180,10 +180,10 @@ def train( len(self._x_shape) < 3 ), "SNLE cannot handle multi-dimensional simulator output." - self._density_estimator.to(self._device) + self._neural_net.to(self._device) if not resume_training: self.optimizer = optim.Adam( - list(self._density_estimator.parameters()), + list(self._neural_net.parameters()), lr=learning_rate, ) self.epoch, self._val_log_prob = 0, float("-Inf") @@ -192,7 +192,7 @@ def train( self.epoch, stop_after_epochs ): # Train for a single epoch. - self._density_estimator.train() + self._neural_net.train() train_log_probs_sum = 0 for batch in train_loader: self.optimizer.zero_grad() @@ -208,7 +208,7 @@ def train( train_loss.backward() if clip_max_norm is not None: clip_grad_norm_( - self._density_estimator.parameters(), + self._neural_net.parameters(), max_norm=clip_max_norm, ) self.optimizer.step() @@ -221,7 +221,7 @@ def train( self._summary["training_log_probs"].append(train_log_prob_average) # Calculate validation performance. - self._density_estimator.eval() + self._neural_net.eval() val_log_prob_sum = 0 with torch.no_grad(): for batch in val_loader: @@ -257,9 +257,9 @@ def train( # Avoid keeping the gradients in the resulting network, which can # cause memory leakage when benchmarking. 
-        self._density_estimator.zero_grad(set_to_none=True)
+        self._neural_net.zero_grad(set_to_none=True)
 
-        return deepcopy(self._density_estimator)
+        return deepcopy(self._neural_net)
 
     def build_posterior(
         self,
@@ -312,7 +312,7 @@ def build_posterior(
         check_prior(prior)
 
         if density_estimator is None:
-            likelihood_estimator = self._density_estimator
+            likelihood_estimator = self._neural_net
             # If internal net is used device is defined.
             device = self._device
         else:
@@ -368,4 +368,4 @@ def _loss(self, theta: Tensor, x: Tensor) -> Tensor:
         Returns:
             Negative log prob.
         """
-        return self._density_estimator.loss(x, condition=theta)
+        return self._neural_net.loss(x, condition=theta)

From d240d40cf3ce8b1ef4db5fe8263a25d8c73d3192 Mon Sep 17 00:00:00 2001
From: manuelgloeckler
Date: Mon, 11 Mar 2024 11:24:50 +0100
Subject: [PATCH 07/35] NPE integration first approach

---
 sbi/inference/base.py                         |  6 ++---
 sbi/inference/posteriors/direct_posterior.py  | 19 ++++++++--------
 .../potentials/likelihood_based_potential.py  |  2 +-
 .../potentials/posterior_based_potential.py   | 14 +++++++-----
 sbi/inference/snle/mnle.py                    |  7 ++++--
 sbi/inference/snle/snle_base.py               | 22 +++++++++----------
 sbi/inference/snpe/snpe_a.py                  |  2 +-
 sbi/inference/snpe/snpe_base.py               |  5 +++--
 sbi/neural_nets/mnle.py                       |  3 ++-
 sbi/samplers/rejection/rejection.py           | 17 +++++---------
 tests/embedding_net_test.py                   |  3 ++-
 11 files changed, 51 insertions(+), 49 deletions(-)

diff --git a/sbi/inference/base.py b/sbi/inference/base.py
index d7cc01506..19bed6ce2 100644
--- a/sbi/inference/base.py
+++ b/sbi/inference/base.py
@@ -122,7 +122,7 @@ def __init__(
 
         self._prior = prior
         self._posterior = None
-        self._density_estimator = None
+        self._neural_net = None
 
         self._x_shape = None
         self._show_progress_bars = show_progress_bars
@@ -349,8 +349,8 @@ def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
         """
         converged = False
 
-        assert self._density_estimator is not None
-        density_estimator = self._density_estimator
+        assert self._neural_net is not None
+        density_estimator = self._neural_net
 
         # (Re)-start the epoch count with the first epoch or any improvement.
         if epoch == 0 or self._val_log_prob > self._best_val_log_prob:
diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py
index e1cf8dd4e..5a292c357 100644
--- a/sbi/inference/posteriors/direct_posterior.py
+++ b/sbi/inference/posteriors/direct_posterior.py
@@ -11,6 +11,7 @@
 from sbi.inference.potentials.posterior_based_potential import (
     posterior_estimator_based_potential,
 )
+from sbi.neural_nets.density_estimators.base import DensityEstimator
 from sbi.samplers.rejection.rejection import accept_reject_sample
 from sbi.types import Shape
 from sbi.utils import check_prior, match_theta_and_x_batch_shapes, within_support
@@ -33,7 +34,7 @@ class DirectPosterior(NeuralPosterior):
 
     def __init__(
         self,
-        posterior_estimator: flows.Flow,
+        posterior_estimator: DensityEstimator,
         prior: Distribution,
         max_sampling_batch_size: int = 10_000,
         device: Optional[str] = None,
@@ -121,9 +122,11 @@ def sample(
             num_samples=num_samples,
             show_progress_bars=show_progress_bars,
             max_sampling_batch_size=max_sampling_batch_size,
-            proposal_sampling_kwargs={"context": x},
+            proposal_sampling_kwargs={"condition": x},
             alternative_method="build_posterior(..., sample_with='mcmc')",
         )[0]
+        samples = samples.view(sample_shape + (-1,))  # TODO: Why is this necessary?
+
         return samples
 
     def log_prob(
@@ -163,17 +166,15 @@ def log_prob(
 
         # TODO Train exited here, entered after sampling?
self.posterior_estimator.eval() - theta = ensure_theta_batched(torch.as_tensor(theta)) - theta_repeated, x_repeated = match_theta_and_x_batch_shapes(theta, x) + # theta = ensure_theta_batched(torch.as_tensor(theta)) + # theta_repeated, x_repeated = match_theta_and_x_batch_shapes(theta, x) with torch.set_grad_enabled(track_gradients): # Evaluate on device, move back to cpu for comparison with prior. - unnorm_log_prob = self.posterior_estimator.log_prob( - theta_repeated, context=x_repeated - ) + unnorm_log_prob = self.posterior_estimator.log_prob(theta, condition=x) # Force probability to be zero outside prior support. - in_prior_support = within_support(self.prior, theta_repeated) + in_prior_support = within_support(self.prior, theta) masked_log_prob = torch.where( in_prior_support, @@ -227,7 +228,7 @@ def acceptance_at(x: Tensor) -> Tensor: show_progress_bars=show_progress_bars, sample_for_correction_factor=True, max_sampling_batch_size=rejection_sampling_batch_size, - proposal_sampling_kwargs={"context": x}, + proposal_sampling_kwargs={"condition": x}, )[1] # Check if the provided x matches the default x (short-circuit on identity). diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index 7fb9196c7..35a898d49 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -180,7 +180,7 @@ def mixed_likelihood_estimator_based_potential( class MixedLikelihoodBasedPotential(LikelihoodBasedPotential): def __init__( self, - likelihood_estimator: MixedDensityEstimator, + likelihood_estimator: MixedDensityEstimator, # type: ignore TODO fix pyright prior: Distribution, x_o: Optional[Tensor], device: str = "cpu", diff --git a/sbi/inference/potentials/posterior_based_potential.py b/sbi/inference/potentials/posterior_based_potential.py index 8582041de..6f5bdd1e8 100644 --- a/sbi/inference/potentials/posterior_based_potential.py +++ b/sbi/inference/potentials/posterior_based_potential.py @@ -9,6 +9,7 @@ from torch.distributions import Distribution from sbi.inference.potentials.base_potential import BasePotential +from sbi.neural_nets.density_estimators import DensityEstimator from sbi.types import TorchTransform from sbi.utils import mcmc_transform from sbi.utils.sbiutils import match_theta_and_x_batch_shapes, within_support @@ -16,7 +17,7 @@ def posterior_estimator_based_potential( - posterior_estimator: nn.Module, + posterior_estimator: DensityEstimator, prior: Distribution, x_o: Optional[Tensor], enable_transform: bool = True, @@ -59,7 +60,7 @@ class PosteriorBasedPotential(BasePotential): def __init__( self, - posterior_estimator: flows.Flow, + posterior_estimator: DensityEstimator, prior: Distribution, x_o: Optional[Tensor], device: str = "cpu", @@ -92,13 +93,14 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: The potential. 
""" - theta = ensure_theta_batched(torch.as_tensor(theta)) - theta, x_repeated = match_theta_and_x_batch_shapes(theta, self.x_o) - theta, x_repeated = theta.to(self.device), x_repeated.to(self.device) + # NOTE: This is no longer necessary, as the `log_prob` will broadcast + # theta = ensure_theta_batched(torch.as_tensor(theta)) + # theta, x_repeated = match_theta_and_x_batch_shapes(theta, self.x_o) + # theta, x_repeated = theta.to(self.device), x_repeated.to(self.device) with torch.set_grad_enabled(track_gradients): posterior_log_prob = self.posterior_estimator.log_prob( - theta, context=x_repeated + theta, condition=self._x_o ) # Force probability to be zero outside prior support. diff --git a/sbi/inference/snle/mnle.py b/sbi/inference/snle/mnle.py index 6a3c3940c..fd2d9cdd2 100644 --- a/sbi/inference/snle/mnle.py +++ b/sbi/inference/snle/mnle.py @@ -152,7 +152,10 @@ def build_posterior( ), f"""net must be of type MixedDensityEstimator but is {type (likelihood_estimator)}.""" - (potential_fn, theta_transform,) = mixed_likelihood_estimator_based_potential( + ( + potential_fn, + theta_transform, + ) = mixed_likelihood_estimator_based_potential( likelihood_estimator=likelihood_estimator, prior=prior, x_o=None ) @@ -199,4 +202,4 @@ def _loss(self, theta: Tensor, x: Tensor) -> Tensor: Returns: Negative log prob. """ - return -self._density_estimator.log_prob(x, context=theta) + return -self._neural_net.log_prob(x, context=theta) diff --git a/sbi/inference/snle/snle_base.py b/sbi/inference/snle/snle_base.py index c79f923c3..6c8dc1524 100644 --- a/sbi/inference/snle/snle_base.py +++ b/sbi/inference/snle/snle_base.py @@ -166,11 +166,11 @@ def train( # arguments, which will build the neural network # This is passed into NeuralPosterior, to create a neural posterior which # can `sample()` and `log_prob()`. The network is accessible via `.net`. - if self._density_estimator is None or retrain_from_scratch: + if self._neural_net is None or retrain_from_scratch: # Get theta,x to initialize NN theta, x, _ = self.get_simulations(starting_round=start_idx) # Use only training data for building the neural net (z-scoring transforms) - self._density_estimator = self._build_neural_net( + self._neural_net = self._build_neural_net( theta[self.train_indices].to("cpu"), x[self.train_indices].to("cpu"), ) @@ -180,10 +180,10 @@ def train( len(self._x_shape) < 3 ), "SNLE cannot handle multi-dimensional simulator output." - self._density_estimator.to(self._device) + self._neural_net.to(self._device) if not resume_training: self.optimizer = optim.Adam( - list(self._density_estimator.parameters()), + list(self._neural_net.parameters()), lr=learning_rate, ) self.epoch, self._val_log_prob = 0, float("-Inf") @@ -192,7 +192,7 @@ def train( self.epoch, stop_after_epochs ): # Train for a single epoch. - self._density_estimator.train() + self._neural_net.train() train_log_probs_sum = 0 for batch in train_loader: self.optimizer.zero_grad() @@ -208,7 +208,7 @@ def train( train_loss.backward() if clip_max_norm is not None: clip_grad_norm_( - self._density_estimator.parameters(), + self._neural_net.parameters(), max_norm=clip_max_norm, ) self.optimizer.step() @@ -221,7 +221,7 @@ def train( self._summary["training_log_probs"].append(train_log_prob_average) # Calculate validation performance. 
- self._density_estimator.eval() + self._neural_net.eval() val_log_prob_sum = 0 with torch.no_grad(): for batch in val_loader: @@ -257,9 +257,9 @@ def train( # Avoid keeping the gradients in the resulting network, which can # cause memory leakage when benchmarking. - self._density_estimator.zero_grad(set_to_none=True) + self._neural_net.zero_grad(set_to_none=True) - return deepcopy(self._density_estimator) + return deepcopy(self._neural_net) def build_posterior( self, @@ -312,7 +312,7 @@ def build_posterior( check_prior(prior) if density_estimator is None: - likelihood_estimator = self._density_estimator + likelihood_estimator = self._neural_net # If internal net is used device is defined. device = self._device else: @@ -368,4 +368,4 @@ def _loss(self, theta: Tensor, x: Tensor) -> Tensor: Returns: Negative log prob. """ - return self._density_estimator.loss(x, condition=theta) + return self._neural_net.loss(x, condition=theta) diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index 98a19545d..3a29ffbe2 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -500,7 +500,7 @@ def _sample_approx_posterior_mog( embedded_context = torchutils.repeat_rows( embedded_context, num_reps=num_samples ) - + # TODO theta, _ = self._neural_net._transform.inverse(theta, context=embedded_context) if embedded_context is not None: diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index 3d668b111..fbe9924cc 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -22,6 +22,7 @@ ) from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.inference.potentials import posterior_estimator_based_potential +from sbi.neural_nets.density_estimators import DensityEstimator from sbi.utils import ( RestrictedPrior, check_estimator_arg, @@ -218,7 +219,7 @@ def train( retrain_from_scratch: bool = False, show_train_summary: bool = False, dataloader_kwargs: Optional[dict] = None, - ) -> nn.Module: + ) -> DensityEstimator: r"""Return density estimator that approximates the distribution $p(\theta|x)$. Args: @@ -431,7 +432,7 @@ def default_calibration_kernel(x): def build_posterior( self, - density_estimator: Optional[nn.Module] = None, + density_estimator: Optional[DensityEstimator] = None, prior: Optional[Distribution] = None, sample_with: str = "direct", mcmc_method: str = "slice_np", diff --git a/sbi/neural_nets/mnle.py b/sbi/neural_nets/mnle.py index 52022df40..27c2d58b4 100644 --- a/sbi/neural_nets/mnle.py +++ b/sbi/neural_nets/mnle.py @@ -10,6 +10,7 @@ from torch.distributions import Categorical from torch.nn import Sigmoid, Softmax +from sbi.neural_nets.density_estimators.base import DensityEstimator from sbi.neural_nets.flow import build_nsf from sbi.utils.sbiutils import match_theta_and_x_batch_shapes, standardizing_net from sbi.utils.torchutils import atleast_2d @@ -216,7 +217,7 @@ class MixedDensityEstimator(nn.Module): def __init__( self, discrete_net: CategoricalNet, - continuous_net: flows.Flow, + continuous_net: DensityEstimator, log_transform_x: bool = False, ): """Initialize class for combining density estimators for MNLE. diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 2b8c82821..497532d20 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -259,16 +259,9 @@ def accept_reject_sample( sampling_batch_size = min(num_samples, max_sampling_batch_size) while num_remaining > 0: # Sample and reject. 
- # This if-case is annoying, but it will be gone when we move away from - # nflows and towards a flows-framework which takes a tuple as sample_size. - if isinstance(proposal, nn.Module): - candidates = proposal.sample( - sampling_batch_size, **proposal_sampling_kwargs # type: ignore - ).reshape(sampling_batch_size, -1) - else: - candidates = proposal.sample( - (sampling_batch_size,), **proposal_sampling_kwargs # type: ignore - ) # type: ignore + candidates = proposal.sample( + (sampling_batch_size,), **proposal_sampling_kwargs # type: ignore + ) # type: ignore # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) @@ -330,9 +323,9 @@ def accept_reject_sample( pbar.close() # When in case of leakage a batch size was used there could be too many samples. - samples = torch.cat(accepted)[:num_samples] + samples = torch.cat(accepted, dim=-2)[..., :num_samples, :] assert ( - samples.shape[0] == num_samples + samples.shape[-2] == num_samples ), "Number of accepted samples must match required samples." return samples, as_tensor(acceptance_rate) diff --git a/tests/embedding_net_test.py b/tests/embedding_net_test.py index 8cf97cf5f..6d568608c 100644 --- a/tests/embedding_net_test.py +++ b/tests/embedding_net_test.py @@ -59,7 +59,7 @@ def test_embedding_net_api(method, num_dim: int, embedding_net: str): else: raise NameError - _ = inference.append_simulations(theta, x).train(max_num_epochs=2) + _ = inference.append_simulations(theta, x).train(max_num_epochs=5) posterior = inference.build_posterior( mcmc_method="slice_np_vectorized", mcmc_parameters=dict(num_chains=2, warmup_steps=10, thin=5), @@ -94,6 +94,7 @@ def test_embedding_api_with_multiple_trials(num_trials, num_dim): inference = SNPE(prior, density_estimator=density_estimator) _ = inference.append_simulations(theta, x).train(max_num_epochs=5) + # Increased to 5 as otherwise 99.99 rejection rate for SNPE posterior = inference.build_posterior().set_default_x(x_o) s = posterior.sample((1,)) From d57030a2bc349b9ec083caa2d22ce4dd6e41a49f Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Mon, 11 Mar 2024 11:59:03 +0100 Subject: [PATCH 08/35] Fix rejection bug, ignore pyright for mnle for the moment --- .../potentials/likelihood_based_potential.py | 5 +++-- sbi/inference/potentials/posterior_based_potential.py | 1 + sbi/neural_nets/mnle.py | 2 +- sbi/samplers/rejection/rejection.py | 11 ++++++----- sbi/utils/sbiutils.py | 11 ++++++----- 5 files changed, 17 insertions(+), 13 deletions(-) diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index 35a898d49..8a160df91 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -180,12 +180,13 @@ def mixed_likelihood_estimator_based_potential( class MixedLikelihoodBasedPotential(LikelihoodBasedPotential): def __init__( self, - likelihood_estimator: MixedDensityEstimator, # type: ignore TODO fix pyright + likelihood_estimator: MixedDensityEstimator, # type: ignore TODO fix pyright prior: Distribution, x_o: Optional[Tensor], device: str = "cpu", ): - super().__init__(likelihood_estimator, prior, x_o, device) + # TODO Fix pyright issue by making MixedDensityEstimator a subclass of DensityEstimator + super().__init__(likelihood_estimator, prior, x_o, device) # type: ignore def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: # Calculate likelihood in one batch. 
diff --git a/sbi/inference/potentials/posterior_based_potential.py b/sbi/inference/potentials/posterior_based_potential.py index 6f5bdd1e8..14bd4fb41 100644 --- a/sbi/inference/potentials/posterior_based_potential.py +++ b/sbi/inference/potentials/posterior_based_potential.py @@ -97,6 +97,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: # theta = ensure_theta_batched(torch.as_tensor(theta)) # theta, x_repeated = match_theta_and_x_batch_shapes(theta, self.x_o) # theta, x_repeated = theta.to(self.device), x_repeated.to(self.device) + assert self._x_o is not None, "No observed data is available." with torch.set_grad_enabled(track_gradients): posterior_log_prob = self.posterior_estimator.log_prob( diff --git a/sbi/neural_nets/mnle.py b/sbi/neural_nets/mnle.py index 27c2d58b4..833ce921c 100644 --- a/sbi/neural_nets/mnle.py +++ b/sbi/neural_nets/mnle.py @@ -217,7 +217,7 @@ class MixedDensityEstimator(nn.Module): def __init__( self, discrete_net: CategoricalNet, - continuous_net: DensityEstimator, + continuous_net: flows.Flow, log_transform_x: bool = False, ): """Initialize class for combining density estimators for MNLE. diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 497532d20..c0d672bd0 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -262,18 +262,19 @@ def accept_reject_sample( candidates = proposal.sample( (sampling_batch_size,), **proposal_sampling_kwargs # type: ignore ) # type: ignore - + print(candidates.shape) # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) - + print(are_accepted.shape) samples = candidates[are_accepted] - + print(samples.shape) + print(are_accepted.shape) accepted.append(samples) # Update. num_sampled_total += sampling_batch_size - num_remaining -= samples.shape[0] - pbar.update(samples.shape[0]) + num_remaining -= samples.shape[-2] + pbar.update(samples.shape[-2]) # To avoid endless sampling when leakage is high, we raise a warning if the # acceptance rate is too low after the first 1_000 samples. diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index c2e8a51d0..4a7c93e29 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -553,15 +553,16 @@ def within_support(distribution: Any, samples: Tensor) -> Tensor: # Try to check using the support property, use log prob method otherwise. try: sample_check = distribution.support.check(samples) - + return sample_check # Before torch v1.7.0, `support.check()` returned bools for every element. # From v1.8.0 on, it directly considers all dimensions of a sample. E.g., # for a single sample in 3D, v1.7.0 would return [[True, True, True]] and # v1.8.0 would return [True]. - if sample_check.ndim > 1: - return torch.all(sample_check, dim=1) - else: - return sample_check + # TODO Check + # if sample_check.ndim > 1: + # return torch.all(sample_check, dim=-1) + # else: + # return sample_check # Falling back to log prob method of either the NeuralPosterior's net, or of a # custom wrapper distribution's. 
except (NotImplementedError, AttributeError): From 0f8541dd91f3fe4dda85d005ae8abd13db58ecfb Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Mon, 11 Mar 2024 12:09:20 +0100 Subject: [PATCH 09/35] Fix mnle bug context -> condition --- sbi/neural_nets/mnle.py | 4 +- tutorials/00_getting_started.ipynb | 366 +++++++++++++++++++++++++++-- 2 files changed, 350 insertions(+), 20 deletions(-) diff --git a/sbi/neural_nets/mnle.py b/sbi/neural_nets/mnle.py index 833ce921c..2680d80a9 100644 --- a/sbi/neural_nets/mnle.py +++ b/sbi/neural_nets/mnle.py @@ -312,7 +312,7 @@ def log_prob(self, x: Tensor, context: Tensor) -> Tensor: # Transform to log-space if needed. torch.log(cont_x) if self.log_transform_x else cont_x, # Pass parameters and discrete x as context. - context=torch.cat((context, disc_x), dim=1), + condition=torch.cat((context, disc_x), dim=1), ) # Combine into joint lp. @@ -385,7 +385,7 @@ def log_prob_iid(self, x: Tensor, theta: Tensor) -> Tensor: # Get repeat discrete data and theta to match in batch shape for flow eval. log_probs_cont = self.continuous_net.log_prob( torch.log(x_cont_repeated) if self.log_transform_x else x_cont_repeated, - context=torch.cat((theta_repeated, x_disc_repeated), dim=1), + condition=torch.cat((theta_repeated, x_disc_repeated), dim=1), ) # Combine into joint lp with first dim over trials. diff --git a/tutorials/00_getting_started.ipynb b/tutorials/00_getting_started.ipynb index 6ccb72d79..f405a3c3e 100644 --- a/tutorials/00_getting_started.ipynb +++ b/tutorials/00_getting_started.ipynb @@ -26,6 +26,318 @@ "from sbi.inference.base import infer" ] }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from __future__ import annotations\n", + "\n", + "import pytest\n", + "import torch\n", + "from torch import eye, ones, zeros\n", + "from torch.distributions import MultivariateNormal\n", + "\n", + "from sbi import utils\n", + "from sbi.inference import SNLE, SNPE, SNRE\n", + "from sbi.neural_nets.embedding_nets import (\n", + " CNNEmbedding,\n", + " FCEmbedding,\n", + " PermutationInvariantEmbedding,\n", + ")\n", + "from sbi.simulators.linear_gaussian import (\n", + " linear_gaussian,\n", + " true_posterior_linear_gaussian_mvn_prior,\n", + ")\n", + "from sbi.utils import classifier_nn, likelihood_nn, posterior_nn\n", + "\n", + "import pytest\n", + "from torch import eye, ones, zeros\n", + "from torch.distributions import MultivariateNormal\n", + "\n", + "from sbi.inference import (\n", + " SNPE_A,\n", + " SNPE_C,\n", + " DirectPosterior,\n", + " prepare_for_sbi,\n", + " simulate_for_sbi,\n", + ")\n", + "from sbi.simulators.linear_gaussian import diagonal_linear_gaussian\n", + "\n", + "from sbi.utils.metrics import c2st\n", + "\n", + "from sbi.inference import (\n", + " SNLE,\n", + " SNPE,\n", + " DirectPosterior,\n", + " MCMCPosterior,\n", + " likelihood_estimator_based_potential,\n", + " prepare_for_sbi,\n", + " simulate_for_sbi,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/root/sbi/sbi/utils/user_input_checks.py:209: UserWarning: Casting 1D Uniform prior to BoxUniform to match sbi batch requirements.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "801c6bc1be314eb590d3950a01088e75", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Running 100 simulations.: 0%| | 0/100 [00:00 Tensor:\n", + " return 
linear_gaussian(theta, likelihood_shift, likelihood_cov)\n", + "\n", + "simulator, prior = prepare_for_sbi(simulator, prior)\n", + "inference = SNPE(density_estimator=\"mdn\")\n", + "\n", + "theta, x = simulate_for_sbi(simulator, prior, 100)\n", + "posterior_estimator = inference.append_simulations(theta, x).train()\n", + "posterior = DirectPosterior(posterior_estimator=posterior_estimator, prior=prior)\n", + "samples = posterior.sample((num_samples,), x=x_o)\n", + "log_probs = posterior.log_prob(samples, x=x_o)\n", + "\n", + "assert log_probs.shape == torch.Size([num_samples])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([100, 1])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "samples.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([], size=(100, 0))" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "samples" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1.]])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x_o" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 10, 32])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "posterior.posterior_estimator.sample((10,), condition=xo).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 32])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xo.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "15bfe30b17e944689124849428a21b1f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Drawing 1 posterior samples: 0%| | 0/1 [00:00" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Only 0.010% proposal samples are\n", + " accepted. It may take a long time to collect the remaining\n", + " 9999 samples. 
Consider interrupting (Ctrl-C) and switching to\n", + " `build_posterior(..., sample_with='mcmc')`.\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m samples \u001b[38;5;241m=\u001b[39m \u001b[43mposterior\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m10000\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mobservation\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m log_probability \u001b[38;5;241m=\u001b[39m posterior\u001b[38;5;241m.\u001b[39mlog_prob(samples, x\u001b[38;5;241m=\u001b[39mobservation)\n\u001b[1;32m 3\u001b[0m _ \u001b[38;5;241m=\u001b[39m analysis\u001b[38;5;241m.\u001b[39mpairplot(samples, limits\u001b[38;5;241m=\u001b[39m[[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m2\u001b[39m], [\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m2\u001b[39m], [\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m2\u001b[39m]], figsize\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m6\u001b[39m, \u001b[38;5;241m6\u001b[39m))\n", + "File \u001b[0;32m~/sbi/sbi/inference/posteriors/direct_posterior.py:118\u001b[0m, in \u001b[0;36mDirectPosterior.sample\u001b[0;34m(self, sample_shape, x, max_sampling_batch_size, sample_with, show_progress_bars)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m sample_with \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 113\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou set `sample_with=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msample_with\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m`. As of sbi v0.18.0, setting \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`sample_with` is no longer supported. 
You have to rerun \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 115\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`.build_posterior(sample_with=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msample_with\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m).`\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 116\u001b[0m )\n\u001b[0;32m--> 118\u001b[0m samples \u001b[38;5;241m=\u001b[39m \u001b[43maccept_reject_sample\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 119\u001b[0m \u001b[43m \u001b[49m\u001b[43mproposal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mposterior_estimator\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 120\u001b[0m \u001b[43m \u001b[49m\u001b[43maccept_reject_fn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mlambda\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtheta\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mwithin_support\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprior\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtheta\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 121\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_samples\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_samples\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 122\u001b[0m \u001b[43m \u001b[49m\u001b[43mshow_progress_bars\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mshow_progress_bars\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 123\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_sampling_batch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_sampling_batch_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 124\u001b[0m \u001b[43m \u001b[49m\u001b[43mproposal_sampling_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcondition\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 125\u001b[0m \u001b[43m \u001b[49m\u001b[43malternative_method\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mbuild_posterior(..., sample_with=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmcmc\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m)\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 126\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m samples\n", + "File \u001b[0;32m~/miniconda3/envs/sbi3/lib/python3.11/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/sbi/sbi/samplers/rejection/rejection.py:262\u001b[0m, in \u001b[0;36maccept_reject_sample\u001b[0;34m(proposal, accept_reject_fn, num_samples, show_progress_bars, warn_acceptance, sample_for_correction_factor, max_sampling_batch_size, proposal_sampling_kwargs, alternative_method, **kwargs)\u001b[0m\n\u001b[1;32m 259\u001b[0m sampling_batch_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmin\u001b[39m(num_samples, max_sampling_batch_size)\n\u001b[1;32m 260\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m num_remaining \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 261\u001b[0m \u001b[38;5;66;03m# Sample and reject.\u001b[39;00m\n\u001b[0;32m--> 262\u001b[0m candidates \u001b[38;5;241m=\u001b[39m \u001b[43mproposal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 263\u001b[0m \u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43msampling_batch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mproposal_sampling_kwargs\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# type: ignore\u001b[39;49;00m\n\u001b[1;32m 264\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[1;32m 266\u001b[0m \u001b[38;5;66;03m# SNPE-style rejection-sampling when the proposal is the neural net.\u001b[39;00m\n\u001b[1;32m 267\u001b[0m are_accepted \u001b[38;5;241m=\u001b[39m accept_reject_fn(candidates)\n", + "File \u001b[0;32m~/sbi/sbi/neural_nets/density_estimators/nflows_flow.py:110\u001b[0m, in \u001b[0;36mNFlowsFlow.sample\u001b[0;34m(self, sample_shape, condition)\u001b[0m\n\u001b[1;32m 108\u001b[0m batch_shape \u001b[38;5;241m=\u001b[39m condition\u001b[38;5;241m.\u001b[39mshape[:\u001b[38;5;241m-\u001b[39mcondition_dims]\n\u001b[1;32m 109\u001b[0m condition \u001b[38;5;241m=\u001b[39m condition\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_condition_shape)\n\u001b[0;32m--> 110\u001b[0m samples \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnet\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnum_samples\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcondition\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mreshape(\n\u001b[1;32m 111\u001b[0m (\u001b[38;5;241m*\u001b[39mbatch_shape, \u001b[38;5;241m*\u001b[39msample_shape, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 112\u001b[0m )\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m samples\n", + "File \u001b[0;32m~/miniconda3/envs/sbi3/lib/python3.11/site-packages/nflows/distributions/base.py:65\u001b[0m, in \u001b[0;36mDistribution.sample\u001b[0;34m(self, num_samples, context, batch_size)\u001b[0m\n\u001b[1;32m 62\u001b[0m context \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mas_tensor(context)\n\u001b[1;32m 64\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batch_size \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m---> 65\u001b[0m 
\u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sample\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnum_samples\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 67\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m check\u001b[38;5;241m.\u001b[39mis_positive_int(batch_size):\n", + "File \u001b[0;32m~/miniconda3/envs/sbi3/lib/python3.11/site-packages/nflows/flows/base.py:54\u001b[0m, in \u001b[0;36mFlow._sample\u001b[0;34m(self, num_samples, context)\u001b[0m\n\u001b[1;32m 49\u001b[0m noise \u001b[38;5;241m=\u001b[39m torchutils\u001b[38;5;241m.\u001b[39mmerge_leading_dims(noise, num_dims\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 50\u001b[0m embedded_context \u001b[38;5;241m=\u001b[39m torchutils\u001b[38;5;241m.\u001b[39mrepeat_rows(\n\u001b[1;32m 51\u001b[0m embedded_context, num_reps\u001b[38;5;241m=\u001b[39mnum_samples\n\u001b[1;32m 52\u001b[0m )\n\u001b[0;32m---> 54\u001b[0m samples, _ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_transform\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minverse\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnoise\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43membedded_context\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m embedded_context \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 57\u001b[0m \u001b[38;5;66;03m# Split the context dimension from sample dimension.\u001b[39;00m\n\u001b[1;32m 58\u001b[0m samples \u001b[38;5;241m=\u001b[39m torchutils\u001b[38;5;241m.\u001b[39msplit_leading_dim(samples, shape\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, num_samples])\n", + "File \u001b[0;32m~/miniconda3/envs/sbi3/lib/python3.11/site-packages/nflows/transforms/base.py:60\u001b[0m, in \u001b[0;36mCompositeTransform.inverse\u001b[0;34m(self, inputs, context)\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minverse\u001b[39m(\u001b[38;5;28mself\u001b[39m, inputs, context\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 59\u001b[0m funcs \u001b[38;5;241m=\u001b[39m (transform\u001b[38;5;241m.\u001b[39minverse \u001b[38;5;28;01mfor\u001b[39;00m transform \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_transforms[::\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n\u001b[0;32m---> 60\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_cascade\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfuncs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/sbi3/lib/python3.11/site-packages/nflows/transforms/base.py:50\u001b[0m, in \u001b[0;36mCompositeTransform._cascade\u001b[0;34m(inputs, funcs, context)\u001b[0m\n\u001b[1;32m 48\u001b[0m total_logabsdet \u001b[38;5;241m=\u001b[39m inputs\u001b[38;5;241m.\u001b[39mnew_zeros(batch_size)\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m func 
\u001b[38;5;129;01min\u001b[39;00m funcs:\n\u001b[0;32m---> 50\u001b[0m outputs, logabsdet \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 51\u001b[0m total_logabsdet \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m logabsdet\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs, total_logabsdet\n", + "File \u001b[0;32m~/miniconda3/envs/sbi3/lib/python3.11/site-packages/nflows/transforms/permutations.py:44\u001b[0m, in \u001b[0;36mPermutation.inverse\u001b[0;34m(self, inputs, context)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, inputs, context\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 42\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_permute(inputs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_permutation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dim)\n\u001b[0;32m---> 44\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minverse\u001b[39m(\u001b[38;5;28mself\u001b[39m, inputs, context\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_permute(inputs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_inverse_permutation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dim)\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] } ], "source": [ @@ -186,7 +516,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.5" + "version": "3.11.7" } }, "nbformat": 4, From 593406203f299690d62aadb64d4bf7ef14406aab Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Mon, 11 Mar 2024 12:20:15 +0100 Subject: [PATCH 10/35] Temp fix of SNPE_A and C ... Need to normalize interface MOG... --- sbi/inference/snpe/snpe_a.py | 24 ++++++++++++++---------- sbi/inference/snpe/snpe_c.py | 16 +++++++++------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index 3a29ffbe2..32f800e25 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -340,10 +340,10 @@ def _expand_mog(self, eps: float = 1e-5): Args: eps: Standard deviation for the random perturbation. """ - assert isinstance(self._neural_net._distribution, MultivariateGaussianMDN) + assert isinstance(self._neural_net.net._distribution, MultivariateGaussianMDN) # Increase the number of components - self._neural_net._distribution._num_components = self._num_components + self._neural_net.net._distribution._num_components = self._num_components # Expand the 1-dim Gaussian. for name, param in self._neural_net.named_parameters(): @@ -492,7 +492,7 @@ def _sample_approx_posterior_mog( num_samples, logits_p, m_p, prec_factors_p ) - embedded_context = self._neural_net._embedding_net(x) + embedded_context = self._neural_net.net._embedding_net(x) if embedded_context is not None: # Merge the context dimension with sample dimension in order to # apply the transform. 
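The hunks in this patch all apply one mechanical rewrite: `self._neural_net` is no longer the raw nflows `Flow` but a `DensityEstimator` wrapper that stores the flow in its `.net` attribute, so the MDN internals (`_distribution`, `_transform`, `_embedding_net`) gain one extra hop. Schematically, with the attribute names used in this series (the wrapper class below is a stand-in sketch, not quoted sbi code):

from torch import nn

class Wrapper(nn.Module):                 # stand-in for sbi's DensityEstimator
    def __init__(self, net: nn.Module, condition_shape):
        super().__init__()
        self.net = net                    # the underlying nflows flows.Flow
        self._condition_shape = condition_shape

# Access pattern before vs. after the refactor:
#   before: self._neural_net._distribution
#   after:  self._neural_net.net._distribution
# and likewise for ._transform and ._embedding_net.
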
@@ -501,7 +501,9 @@ def _sample_approx_posterior_mog( embedded_context, num_reps=num_samples ) # TODO - theta, _ = self._neural_net._transform.inverse(theta, context=embedded_context) + theta, _ = self._neural_net.net._transform.inverse( + theta, context=embedded_context + ) if embedded_context is not None: # Split the context dimension from sample dimension. @@ -522,8 +524,8 @@ def _posthoc_correction(self, x: Tensor): """ # Evaluate the density estimator. - encoded_x = self._neural_net._embedding_net(x) - dist = self._neural_net._distribution # defined to avoid black formatting. + encoded_x = self._neural_net.net._embedding_net(x) + dist = self._neural_net.net._distribution # defined to avoid black formatting. logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(encoded_x) norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True) @@ -617,7 +619,9 @@ def _set_state_for_mog_proposal(self) -> None: training step if the prior is Gaussian. """ - self.z_score_theta = isinstance(self._neural_net._transform, CompositeTransform) + self.z_score_theta = isinstance( + self._neural_net.net._transform, CompositeTransform + ) self._set_maybe_z_scored_prior() @@ -649,8 +653,8 @@ def _set_maybe_z_scored_prior(self) -> None: prior will not be exactly have mean=0 and std=1. """ if self.z_score_theta: - scale = self._neural_net._transform._transforms[0]._scale - shift = self._neural_net._transform._transforms[0]._shift + scale = self._neural_net.net._transform._transforms[0]._scale + shift = self._neural_net.net._transform._transforms[0]._shift # Following the definition of the linear transform in # `standardizing_transform` in `sbiutils.py`: @@ -684,7 +688,7 @@ def _maybe_z_score_theta(self, theta: Tensor) -> Tensor: """Return potentially standardized theta if z-scoring was requested.""" if self.z_score_theta: - theta, _ = self._neural_net._transform(theta) + theta, _ = self._neural_net.net._transform(theta) return theta diff --git a/sbi/inference/snpe/snpe_c.py b/sbi/inference/snpe/snpe_c.py index 7f488efca..deaa73f1d 100644 --- a/sbi/inference/snpe/snpe_c.py +++ b/sbi/inference/snpe/snpe_c.py @@ -164,7 +164,7 @@ def train( self.use_non_atomic_loss = ( isinstance(proposal, DirectPosterior) and isinstance(proposal.posterior_estimator._distribution, mdn) - and isinstance(self._neural_net._distribution, mdn) + and isinstance(self._neural_net.net._distribution, mdn) and check_dist_class( self._prior, class_to_check=(Uniform, MultivariateNormal) )[0] @@ -191,7 +191,9 @@ def _set_state_for_mog_proposal(self) -> None: training step if the prior is Gaussian. """ - self.z_score_theta = isinstance(self._neural_net._transform, CompositeTransform) + self.z_score_theta = isinstance( + self._neural_net.net._transform, CompositeTransform + ) self._set_maybe_z_scored_prior() @@ -221,8 +223,8 @@ def _set_maybe_z_scored_prior(self) -> None: """ if self.z_score_theta: - scale = self._neural_net._transform._transforms[0]._scale - shift = self._neural_net._transform._transforms[0]._shift + scale = self._neural_net.net._transform._transforms[0]._scale + shift = self._neural_net.net._transform._transforms[0]._shift # Following the definintion of the linear transform in # `standardizing_transform` in `sbiutils.py`: @@ -396,8 +398,8 @@ def _log_prob_proposal_posterior_mog( norm_logits_p = logits_p - torch.logsumexp(logits_p, dim=-1, keepdim=True) # Evaluate the density estimator. - encoded_x = self._neural_net._embedding_net(x) - dist = self._neural_net._distribution # defined to avoid black formatting. 
+ encoded_x = self._neural_net.net._embedding_net(x) + dist = self._neural_net.net._distribution # defined to avoid black formatting. logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(encoded_x) norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True) @@ -632,6 +634,6 @@ def _maybe_z_score_theta(self, theta: Tensor) -> Tensor: """Return potentially standardized theta if z-scoring was requested.""" if self.z_score_theta: - theta, _ = self._neural_net._transform(theta) + theta, _ = self._neural_net.net._transform(theta) return theta From 73fa99ddd6df40fd7a675d1c3309ffdfd4870359 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Mon, 11 Mar 2024 13:10:44 +0100 Subject: [PATCH 11/35] Fix multi round bugs --- sbi/inference/snpe/snpe_a.py | 31 +++++++++++++++++++------------ sbi/inference/snpe/snpe_c.py | 6 +++--- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index 32f800e25..5f8641da5 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -17,6 +17,7 @@ import sbi.utils as utils from sbi.inference.posteriors.direct_posterior import DirectPosterior from sbi.inference.snpe.snpe_base import PosteriorEstimator +from sbi.neural_nets.density_estimators.base import DensityEstimator from sbi.types import TensorboardSummaryWriter, TorchModule from sbi.utils import torchutils @@ -110,7 +111,7 @@ def train( show_train_summary: bool = False, dataloader_kwargs: Optional[Dict] = None, component_perturbation: float = 5e-3, - ) -> nn.Module: + ) -> DensityEstimator: r"""Return density estimator that approximates the proposal posterior. [1] _Fast epsilon-free Inference of Simulation Models with Bayesian Conditional @@ -360,7 +361,7 @@ def _expand_mog(self, eps: float = 1e-5): param.grad = None # let autograd construct a new gradient -class SNPE_A_MDN(nn.Module): +class SNPE_A_MDN(DensityEstimator): """Generates a posthoc-corrected MDN which approximates the posterior. This class takes as input the density estimator (abbreviated with `_d` suffix, aka @@ -379,7 +380,7 @@ class SNPE_A_MDN(nn.Module): def __init__( self, - flow: flows.Flow, + flow: DensityEstimator, proposal: Union["utils.BoxUniform", "MultivariateNormal", "DirectPosterior"], prior: Distribution, device: str, @@ -392,7 +393,8 @@ def __init__( prior: The prior distribution. """ # Call nn.Module's constructor. - super().__init__() + + super().__init__(flow, flow._condition_shape) self._neural_net = flow self._prior = prior @@ -419,18 +421,19 @@ def __init__( # Take care of z-scoring, pre-compute and store prior terms. self._set_state_for_mog_proposal() - def log_prob(self, inputs: Tensor, context: Tensor) -> Tensor: - inputs, context = inputs.to(self._device), context.to(self._device) + def log_prob(self, inputs: Tensor, condition: Tensor, **kwargs) -> Tensor: + + inputs, condition = inputs.to(self._device), condition.to(self._device) if not self._apply_correction: - return self._neural_net.log_prob(inputs, context) + return self._neural_net.log_prob(inputs, condition) else: # When we want to compute the approx. posterior, a proposal prior \tilde{p} # has already been observed. To analytically calculate the log-prob of the # Gaussian, we first need to compute the mixture components. # Compute the mixture components of the proposal posterior. - logits_pp, m_pp, prec_pp = self._posthoc_correction(context) + logits_pp, m_pp, prec_pp = self._posthoc_correction(condition) # z-score theta if it z-scoring had been requested. 
theta = self._maybe_z_score_theta(inputs) @@ -444,16 +447,20 @@ def log_prob(self, inputs: Tensor, context: Tensor) -> Tensor: ) return log_prob_proposal_posterior # \hat{p} from eq (3) in [1] - def sample(self, num_samples: int, context: Tensor, batch_size: int = 1) -> Tensor: - context = context.to(self._device) + def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tensor: + condition = condition.to(self._device) if not self._apply_correction: - return self._neural_net.sample(num_samples, context, batch_size) + return self._neural_net.sample(sample_shape, condition=condition) else: # When we want to sample from the approx. posterior, a proposal prior # \tilde{p} has already been observed. To analytically calculate the # log-prob of the Gaussian, we first need to compute the mixture components. - return self._sample_approx_posterior_mog(num_samples, context, batch_size) + num_samples = torch.Size(sample_shape).numel() + condition_ndim = len(self._condition_shape) + batch_size = condition.shape[:-condition_ndim] + batch_size = torch.Size(batch_size).numel() + return self._sample_approx_posterior_mog(num_samples, condition, batch_size) def _sample_approx_posterior_mog( self, num_samples, x: Tensor, batch_size: int diff --git a/sbi/inference/snpe/snpe_c.py b/sbi/inference/snpe/snpe_c.py index deaa73f1d..90145311e 100644 --- a/sbi/inference/snpe/snpe_c.py +++ b/sbi/inference/snpe/snpe_c.py @@ -163,7 +163,7 @@ def train( proposal = self._proposal_roundwise[-1] self.use_non_atomic_loss = ( isinstance(proposal, DirectPosterior) - and isinstance(proposal.posterior_estimator._distribution, mdn) + and isinstance(proposal.posterior_estimator.net._distribution, mdn) and isinstance(self._neural_net.net._distribution, mdn) and check_dist_class( self._prior, class_to_check=(Uniform, MultivariateNormal) @@ -390,9 +390,9 @@ def _log_prob_proposal_posterior_mog( # Evaluate the proposal. MDNs do not have functionality to run the embedding_net # and then get the mixture_components (**without** calling log_prob()). Hence, # we call them separately here. - encoded_x = proposal.posterior_estimator._embedding_net(proposal.default_x) + encoded_x = proposal.posterior_estimator.net._embedding_net(proposal.default_x) dist = ( - proposal.posterior_estimator._distribution + proposal.posterior_estimator.net._distribution ) # defined to avoid ugly black formatting. 
logits_p, m_p, prec_p, _, _ = dist.get_mixture_components(encoded_x) norm_logits_p = logits_p - torch.logsumexp(logits_p, dim=-1, keepdim=True) From b1cb88e2094d7ace46aec4879ec5dc6e3af7ffa8 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Mon, 11 Mar 2024 15:46:38 +0100 Subject: [PATCH 12/35] Fixing hopefully last tests that fail --- sbi/analysis/conditional_density.py | 5 +++-- sbi/utils/conditional_density_utils.py | 2 +- sbi/utils/user_input_checks_utils.py | 2 +- tests/linearGaussian_snpe_test.py | 1 + 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/sbi/analysis/conditional_density.py b/sbi/analysis/conditional_density.py index 293271df5..c1afd67f8 100644 --- a/sbi/analysis/conditional_density.py +++ b/sbi/analysis/conditional_density.py @@ -10,6 +10,7 @@ from torch import Tensor, nn from torch.distributions import Distribution +from sbi.neural_nets.density_estimators.base import DensityEstimator from sbi.types import Shape, TorchTransform from sbi.utils.conditional_density_utils import ( ConditionedPotential, @@ -187,7 +188,7 @@ def conditional_corrcoeff( class ConditionedMDN: def __init__( self, - net: nn.Module, + net: DensityEstimator, # TODO: Must be MDN!!! x_o: Tensor, condition: Tensor, dims_to_sample: List[int], @@ -207,7 +208,7 @@ def __init__( """ condition = atleast_2d_float32_tensor(condition) - logits, means, precfs, _ = extract_and_transform_mog(net=net, context=x_o) + logits, means, precfs, _ = extract_and_transform_mog(net=net.net, context=x_o) self.logits, self.means, self.precfs, self.sumlogdiag = condition_mog( condition, dims_to_sample, logits, means, precfs ) diff --git a/sbi/utils/conditional_density_utils.py b/sbi/utils/conditional_density_utils.py index 69c4b08f6..3bbbf71e7 100644 --- a/sbi/utils/conditional_density_utils.py +++ b/sbi/utils/conditional_density_utils.py @@ -294,7 +294,7 @@ def __init__( self.device = self.potential_fn.device self.allow_iid_x = allow_iid_x - def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: + def __call__(self, theta: Tensor, track_gradients: bool = True, x_o=None) -> Tensor: r""" Returns the conditional potential $\log(p(\theta_i|\theta_j, x))$. diff --git a/sbi/utils/user_input_checks_utils.py b/sbi/utils/user_input_checks_utils.py index d36ceeb4e..499cfec14 100644 --- a/sbi/utils/user_input_checks_utils.py +++ b/sbi/utils/user_input_checks_utils.py @@ -365,7 +365,7 @@ def support(self): # Wrap as `independent` in order to have the correct shape of the # `log_abs_det`, i.e. summed over the parameter dimensions. return constraints.independent( - constraints.cat(supports, dim=1, lengths=self.dims_per_dist), + constraints.cat(supports, dim=-1, lengths=self.dims_per_dist), reinterpreted_batch_ndims=1, ) diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py index d4b4151eb..001fd1a00 100644 --- a/tests/linearGaussian_snpe_test.py +++ b/tests/linearGaussian_snpe_test.py @@ -573,6 +573,7 @@ def simulator(theta): proposal=restricted_prior, **mcmc_parameters, ) + mcmc_posterior.set_default_x(x_o) # TODO: This test has a bug? 
Needed to add this cond_samples = mcmc_posterior.sample((num_conditional_samples,)) _ = analysis.pairplot( From 4fbb3fff7016a8f01bf490a0cf07fc42ac47e0eb Mon Sep 17 00:00:00 2001 From: Guy Moss Date: Mon, 11 Mar 2024 16:37:56 +0100 Subject: [PATCH 13/35] add embedding net to ZukoFlow --- .../density_estimators/zuko_flow.py | 25 +++++++++++++------ tests/density_estimator_test.py | 6 ++--- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/sbi/neural_nets/density_estimators/zuko_flow.py b/sbi/neural_nets/density_estimators/zuko_flow.py index bef9adc13..b5cb36781 100644 --- a/sbi/neural_nets/density_estimators/zuko_flow.py +++ b/sbi/neural_nets/density_estimators/zuko_flow.py @@ -1,8 +1,8 @@ from typing import Tuple import torch -from torch import Tensor -from zuko.flows import Flow +from torch import nn, Tensor +from zuko.flows import Flow,LazyComposedTransform from sbi.neural_nets.density_estimators.base import DensityEstimator from sbi.types import Shape @@ -15,7 +15,7 @@ class ZukoFlow(DensityEstimator): wrap them and add the .loss() method. """ - def __init__(self, net: Flow, condition_shape: torch.Size): + def __init__(self, net: Flow, embedding_net: nn.Module, condition_shape: torch.Size): r"""Initialize the density estimator. Args: @@ -23,8 +23,12 @@ def __init__(self, net: Flow, condition_shape: torch.Size): condition_shape: Shape of the condition. """ - assert len(condition_shape) == 1, "Zuko Flows require 1D conditions." + # assert len(condition_shape) == 1, "Zuko Flows require 1D conditions." super().__init__(net=net, condition_shape=condition_shape) + self._embedding_net = embedding_net + + def _maybe_z_score(self)->bool: + return True def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: r"""Return the log probabilities of the inputs given a condition or multiple @@ -64,9 +68,10 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: batch_shape = torch.broadcast_shapes(batch_shape_in, batch_shape_cond) # Expand the input and condition to the same batch shape input = input.expand(batch_shape + (input.shape[-1],)) - condition = condition.expand(batch_shape + self._condition_shape) + emb_cond = self._embedding_net(condition) + emb_cond = emb_cond.expand(batch_shape + (emb_cond.shape[-1],)) - dists = self.net(condition) + dists = self.net(emb_cond) log_probs = dists.log_prob(input) return log_probs @@ -106,7 +111,8 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor: condition_dims = len(self._condition_shape) batch_shape = condition.shape[:-condition_dims] if condition_dims > 0 else () - dists = self.net(condition) + emb_cond = self._embedding_net(condition) + dists = self.net(emb_cond) # zuko.sample() returns (*sample_shape, *batch_shape, input_size). samples = dists.sample(sample_shape).reshape(*batch_shape, *sample_shape, -1) @@ -124,10 +130,13 @@ def sample_and_log_prob( Returns: Samples and associated log probabilities. """ + self._check_condition_shape(condition) + condition_dims = len(self._condition_shape) batch_shape = condition.shape[:-condition_dims] if condition_dims > 0 else () - dists = self.net(condition) + emb_cond = self._embedding_net(condition) + dists = self.net(emb_cond) samples, log_probs = dists.rsample_and_log_prob(sample_shape) # zuko.sample_and_log_prob() returns (*sample_shape, *batch_shape, ...). 
diff --git a/tests/density_estimator_test.py b/tests/density_estimator_test.py index d4d5b976b..a48fc28c6 100644 --- a/tests/density_estimator_test.py +++ b/tests/density_estimator_test.py @@ -54,8 +54,8 @@ def forward(self, x): embedding_net=EmbeddingNet(), ) elif density_estimator == ZukoFlow: - if len(condition_shape) > 1: - pytest.skip("ZukoFlow does not support multi-dimensional contexts.") + # if len(condition_shape) > 1: + # pytest.skip("ZukoFlow does not support multi-dimensional contexts.") net = NSF( features=input_dims, context=condition_shape[-1], @@ -63,7 +63,7 @@ def forward(self, x): hidden_features=(10,), bins=8, ) - estimator = density_estimator(net, condition_shape) + estimator = density_estimator(net,embedding_net=EmbeddingNet(), condition_shape=condition_shape) # Loss is only required to work for batched inputs and contexts loss = estimator.loss(batch_input, batch_context) From 5f25941145302438003ed913d6b789bca2230e5c Mon Sep 17 00:00:00 2001 From: Guy Moss Date: Mon, 11 Mar 2024 16:38:40 +0100 Subject: [PATCH 14/35] add zuko_maf_builder and test for SNLE --- sbi/neural_nets/flow.py | 77 ++++++++++++++++++++++++++++++- sbi/utils/get_nn_models.py | 6 ++- sbi/utils/sbiutils.py | 13 ++++-- tests/linearGaussian_snle_test.py | 5 +- 4 files changed, 93 insertions(+), 8 deletions(-) diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index c53b2952e..8ac2994df 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -3,7 +3,7 @@ from functools import partial -from typing import Optional +from typing import Optional, Union, Sequence from warnings import warn from pyknos.nflows import distributions as distributions_ @@ -12,7 +12,7 @@ from pyknos.nflows.transforms.splines import rational_quadratic from torch import Tensor, nn, relu, tanh, tensor, uint8 -from sbi.neural_nets.density_estimators import NFlowsFlow +from sbi.neural_nets.density_estimators import NFlowsFlow,ZukoFlow from sbi.utils.sbiutils import ( standardizing_net, standardizing_transform, @@ -20,6 +20,7 @@ ) from sbi.utils.torchutils import create_alternating_binary_mask from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device +import zuko def build_made( @@ -415,6 +416,78 @@ def mask_in_layer(i): return flow +def build_zuko_maf( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int],int] = 50, + num_transforms: int = 5, + embedding_net: nn.Module = nn.Identity(), + residual: bool=True, + randperm: bool=False, + **kwargs, +) -> ZukoFlow: + """Builds MAF p(x|y). + + Args: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. + - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: Number of hidden features. + num_transforms: Number of transforms. + embedding_net: Optional embedding network for y. + residual: whether to use residual blocks in the coupling layer. 
+ randperm: Whether features are randomly permuted between transformations or not. + kwargs: Additional arguments that are passed by the build function but are not + relevant for maf and are therefore ignored. + + Returns: + Neural network. + """ + x_numel = batch_x[0].numel() + # Infer the output dimensionality of the embedding_net by making a forward pass. + check_data_device(batch_x, batch_y) + check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) + y_numel = embedding_net(batch_y[:1]).numel() + if x_numel == 1: + warn("In one-dimensional output space, this flow is limited to Gaussians") + + + if isinstance(hidden_features,int): + hidden_features = [hidden_features]*num_transforms + if x_numel == 1: + maf = zuko.flows.MAF(features = x_numel, context=y_numel, hidden_features = hidden_features, transforms = num_transforms) + else: + maf = zuko.flows.MAF(features = x_numel, context=y_numel, hidden_features = hidden_features, transforms = num_transforms, randperm= randperm,residual=residual) + + transforms = maf.transform.transforms + z_score_x_bool, structured_x = z_score_parser(z_score_x) + if z_score_x_bool: + #transforms = transforms + transforms = (*transforms,standardizing_transform(batch_x, structured_x,backend="zuko")) + + z_score_y_bool, structured_y = z_score_parser(z_score_y) + if z_score_y_bool: + # Prepend standardizing transform to y-embedding. + embedding_net = nn.Sequential( + standardizing_net(batch_y, structured_y), embedding_net + ) + + # Combine transforms. + neural_net = zuko.flows.Flow(transforms,maf.base) + + flow = ZukoFlow(neural_net, embedding_net, condition_shape=batch_y[0].shape) + + return flow + class ContextSplineMap(nn.Module): """ diff --git a/sbi/utils/get_nn_models.py b/sbi/utils/get_nn_models.py index 71938df37..b732fa2c7 100644 --- a/sbi/utils/get_nn_models.py +++ b/sbi/utils/get_nn_models.py @@ -11,7 +11,7 @@ build_mlp_classifier, build_resnet_classifier, ) -from sbi.neural_nets.flow import build_made, build_maf, build_maf_rqs, build_nsf +from sbi.neural_nets.flow import build_made, build_maf, build_maf_rqs, build_nsf,build_zuko_maf from sbi.neural_nets.mdn import build_mdn from sbi.neural_nets.mnle import build_mnle @@ -168,6 +168,10 @@ def build_fn(batch_theta, batch_x): return build_nsf(batch_x=batch_x, batch_y=batch_theta, **kwargs) elif model == "mnle": return build_mnle(batch_x=batch_x, batch_y=batch_theta, **kwargs) + elif model == "zuko_maf": + del kwargs["num_bins"] + del kwargs["num_components"] + return build_zuko_maf(batch_x=batch_x, batch_y=batch_theta, **kwargs) else: raise NotImplementedError diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index 4a7c93e29..bfaf8ea20 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -9,6 +9,7 @@ import numpy as np import pyknos.nflows.transforms as transforms +import zuko import torch import torch.distributions.transforms as torch_tf from pyro.distributions import Empirical @@ -134,8 +135,8 @@ def z_score_parser(z_score_flag: Optional["str"]) -> Tuple[bool, bool]: def standardizing_transform( - batch_t: Tensor, structured_dims: bool = False, min_std: float = 1e-14 -) -> transforms.AffineTransform: + batch_t: Tensor, structured_dims: bool = False, min_std: float = 1e-14, backend: str = "nflows" +) -> Union[transforms.AffineTransform,zuko.transforms.MonotonicAffineTransform]: """Builds standardizing transform Args: @@ -169,7 +170,13 @@ def standardizing_transform( t_std = torch.std(batch_t[is_valid_t], dim=0) t_std[t_std < min_std] = min_std - return 
transforms.AffineTransform(shift=-t_mean / t_std, scale=1 / t_std) + if backend == "nflows": + return transforms.AffineTransform(shift=-t_mean / t_std, scale=1 / t_std) + elif backend == "zuko": + return zuko.flows.Unconditional(zuko.transforms.MonotonicAffineTransform,shift=-t_mean / t_std, scale=1 / t_std,buffer=True) + + else: + raise ValueError("Invalid backend. Use 'nflows' or 'zuko'.") class Standardize(nn.Module): diff --git a/tests/linearGaussian_snle_test.py b/tests/linearGaussian_snle_test.py index 6457ec0d2..4ba69895d 100644 --- a/tests/linearGaussian_snle_test.py +++ b/tests/linearGaussian_snle_test.py @@ -139,7 +139,8 @@ def test_c2st_snl_on_linear_gaussian_different_dims(model_str="maf"): @pytest.mark.slow @pytest.mark.parametrize("num_dim", (1, 2)) @pytest.mark.parametrize("prior_str", ("uniform", "gaussian")) -def test_c2st_and_map_snl_on_linearGaussian_different(num_dim: int, prior_str: str): +@pytest.mark.parametrize("model_str", ("maf", "zuko_maf")) +def test_c2st_and_map_snl_on_linearGaussian_different(num_dim: int, prior_str: str,model_str: str): """Test SNL on linear Gaussian, comparing to ground truth posterior via c2st. Args: @@ -167,7 +168,7 @@ def test_c2st_and_map_snl_on_linearGaussian_different(num_dim: int, prior_str: s lambda theta: linear_gaussian(theta, likelihood_shift, likelihood_cov), prior, ) - density_estimator = likelihood_nn("maf", num_transforms=3) + density_estimator = likelihood_nn(model_str, num_transforms=3) inference = SNLE(density_estimator=density_estimator, show_progress_bars=False) theta, x = simulate_for_sbi( From eb2c3010654591c66996d2eef5eef527318a9864 Mon Sep 17 00:00:00 2001 From: Guy Moss Date: Mon, 11 Mar 2024 16:42:43 +0100 Subject: [PATCH 15/35] formatting --- .../density_estimators/zuko_flow.py | 10 +++-- sbi/neural_nets/flow.py | 43 +++++++++++++------ sbi/utils/get_nn_models.py | 8 +++- sbi/utils/sbiutils.py | 18 +++++--- tests/density_estimator_test.py | 4 +- tests/linearGaussian_snle_test.py | 4 +- 6 files changed, 61 insertions(+), 26 deletions(-) diff --git a/sbi/neural_nets/density_estimators/zuko_flow.py b/sbi/neural_nets/density_estimators/zuko_flow.py index b5cb36781..a1c93970e 100644 --- a/sbi/neural_nets/density_estimators/zuko_flow.py +++ b/sbi/neural_nets/density_estimators/zuko_flow.py @@ -1,8 +1,8 @@ from typing import Tuple import torch -from torch import nn, Tensor -from zuko.flows import Flow,LazyComposedTransform +from torch import Tensor, nn +from zuko.flows import Flow, LazyComposedTransform from sbi.neural_nets.density_estimators.base import DensityEstimator from sbi.types import Shape @@ -15,7 +15,9 @@ class ZukoFlow(DensityEstimator): wrap them and add the .loss() method. """ - def __init__(self, net: Flow, embedding_net: nn.Module, condition_shape: torch.Size): + def __init__( + self, net: Flow, embedding_net: nn.Module, condition_shape: torch.Size + ): r"""Initialize the density estimator. 
Args: @@ -27,7 +29,7 @@ def __init__(self, net: Flow, embedding_net: nn.Module, condition_shape: torch.S super().__init__(net=net, condition_shape=condition_shape) self._embedding_net = embedding_net - def _maybe_z_score(self)->bool: + def _maybe_z_score(self) -> bool: return True def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index 8ac2994df..477bf02f2 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -3,16 +3,17 @@ from functools import partial -from typing import Optional, Union, Sequence +from typing import Optional, Sequence, Union from warnings import warn +import zuko from pyknos.nflows import distributions as distributions_ from pyknos.nflows import flows, transforms from pyknos.nflows.nn import nets from pyknos.nflows.transforms.splines import rational_quadratic from torch import Tensor, nn, relu, tanh, tensor, uint8 -from sbi.neural_nets.density_estimators import NFlowsFlow,ZukoFlow +from sbi.neural_nets.density_estimators import NFlowsFlow, ZukoFlow from sbi.utils.sbiutils import ( standardizing_net, standardizing_transform, @@ -20,7 +21,6 @@ ) from sbi.utils.torchutils import create_alternating_binary_mask from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device -import zuko def build_made( @@ -416,16 +416,17 @@ def mask_in_layer(i): return flow + def build_zuko_maf( batch_x: Tensor, batch_y: Tensor, z_score_x: Optional[str] = "independent", z_score_y: Optional[str] = "independent", - hidden_features: Union[Sequence[int],int] = 50, + hidden_features: Union[Sequence[int], int] = 50, num_transforms: int = 5, embedding_net: nn.Module = nn.Identity(), - residual: bool=True, - randperm: bool=False, + residual: bool = True, + randperm: bool = False, **kwargs, ) -> ZukoFlow: """Builds MAF p(x|y). @@ -460,19 +461,33 @@ def build_zuko_maf( if x_numel == 1: warn("In one-dimensional output space, this flow is limited to Gaussians") - - if isinstance(hidden_features,int): - hidden_features = [hidden_features]*num_transforms + if isinstance(hidden_features, int): + hidden_features = [hidden_features] * num_transforms if x_numel == 1: - maf = zuko.flows.MAF(features = x_numel, context=y_numel, hidden_features = hidden_features, transforms = num_transforms) + maf = zuko.flows.MAF( + features=x_numel, + context=y_numel, + hidden_features=hidden_features, + transforms=num_transforms, + ) else: - maf = zuko.flows.MAF(features = x_numel, context=y_numel, hidden_features = hidden_features, transforms = num_transforms, randperm= randperm,residual=residual) + maf = zuko.flows.MAF( + features=x_numel, + context=y_numel, + hidden_features=hidden_features, + transforms=num_transforms, + randperm=randperm, + residual=residual, + ) transforms = maf.transform.transforms z_score_x_bool, structured_x = z_score_parser(z_score_x) if z_score_x_bool: - #transforms = transforms - transforms = (*transforms,standardizing_transform(batch_x, structured_x,backend="zuko")) + # transforms = transforms + transforms = ( + *transforms, + standardizing_transform(batch_x, structured_x, backend="zuko"), + ) z_score_y_bool, structured_y = z_score_parser(z_score_y) if z_score_y_bool: @@ -482,7 +497,7 @@ def build_zuko_maf( ) # Combine transforms. 
- neural_net = zuko.flows.Flow(transforms,maf.base) + neural_net = zuko.flows.Flow(transforms, maf.base) flow = ZukoFlow(neural_net, embedding_net, condition_shape=batch_y[0].shape) diff --git a/sbi/utils/get_nn_models.py b/sbi/utils/get_nn_models.py index b732fa2c7..6ac836c01 100644 --- a/sbi/utils/get_nn_models.py +++ b/sbi/utils/get_nn_models.py @@ -11,7 +11,13 @@ build_mlp_classifier, build_resnet_classifier, ) -from sbi.neural_nets.flow import build_made, build_maf, build_maf_rqs, build_nsf,build_zuko_maf +from sbi.neural_nets.flow import ( + build_made, + build_maf, + build_maf_rqs, + build_nsf, + build_zuko_maf, +) from sbi.neural_nets.mdn import build_mdn from sbi.neural_nets.mnle import build_mnle diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index bfaf8ea20..e7fbb9f51 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -9,9 +9,9 @@ import numpy as np import pyknos.nflows.transforms as transforms -import zuko import torch import torch.distributions.transforms as torch_tf +import zuko from pyro.distributions import Empirical from torch import Tensor from torch import nn as nn @@ -135,8 +135,11 @@ def z_score_parser(z_score_flag: Optional["str"]) -> Tuple[bool, bool]: def standardizing_transform( - batch_t: Tensor, structured_dims: bool = False, min_std: float = 1e-14, backend: str = "nflows" -) -> Union[transforms.AffineTransform,zuko.transforms.MonotonicAffineTransform]: + batch_t: Tensor, + structured_dims: bool = False, + min_std: float = 1e-14, + backend: str = "nflows", +) -> Union[transforms.AffineTransform, zuko.transforms.MonotonicAffineTransform]: """Builds standardizing transform Args: @@ -173,8 +176,13 @@ def standardizing_transform( if backend == "nflows": return transforms.AffineTransform(shift=-t_mean / t_std, scale=1 / t_std) elif backend == "zuko": - return zuko.flows.Unconditional(zuko.transforms.MonotonicAffineTransform,shift=-t_mean / t_std, scale=1 / t_std,buffer=True) - + return zuko.flows.Unconditional( + zuko.transforms.MonotonicAffineTransform, + shift=-t_mean / t_std, + scale=1 / t_std, + buffer=True, + ) + else: raise ValueError("Invalid backend. Use 'nflows' or 'zuko'.") diff --git a/tests/density_estimator_test.py b/tests/density_estimator_test.py index a48fc28c6..f2e2980f8 100644 --- a/tests/density_estimator_test.py +++ b/tests/density_estimator_test.py @@ -63,7 +63,9 @@ def forward(self, x): hidden_features=(10,), bins=8, ) - estimator = density_estimator(net,embedding_net=EmbeddingNet(), condition_shape=condition_shape) + estimator = density_estimator( + net, embedding_net=EmbeddingNet(), condition_shape=condition_shape + ) # Loss is only required to work for batched inputs and contexts loss = estimator.loss(batch_input, batch_context) diff --git a/tests/linearGaussian_snle_test.py b/tests/linearGaussian_snle_test.py index 4ba69895d..3846c5d6c 100644 --- a/tests/linearGaussian_snle_test.py +++ b/tests/linearGaussian_snle_test.py @@ -140,7 +140,9 @@ def test_c2st_snl_on_linear_gaussian_different_dims(model_str="maf"): @pytest.mark.parametrize("num_dim", (1, 2)) @pytest.mark.parametrize("prior_str", ("uniform", "gaussian")) @pytest.mark.parametrize("model_str", ("maf", "zuko_maf")) -def test_c2st_and_map_snl_on_linearGaussian_different(num_dim: int, prior_str: str,model_str: str): +def test_c2st_and_map_snl_on_linearGaussian_different( + num_dim: int, prior_str: str, model_str: str +): """Test SNL on linear Gaussian, comparing to ground truth posterior via c2st. 
Args: From a5e6fc3e1b56a4fcdaa6da2f69004a31c5bfba40 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Mon, 11 Mar 2024 15:53:06 +0100 Subject: [PATCH 16/35] If else MAF possibly unbound fix --- sbi/neural_nets/flow.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index 477bf02f2..4f81a4b9e 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -463,22 +463,23 @@ def build_zuko_maf( if isinstance(hidden_features, int): hidden_features = [hidden_features] * num_transforms - if x_numel == 1: - maf = zuko.flows.MAF( - features=x_numel, - context=y_numel, - hidden_features=hidden_features, - transforms=num_transforms, - ) - else: - maf = zuko.flows.MAF( - features=x_numel, - context=y_numel, - hidden_features=hidden_features, - transforms=num_transforms, - randperm=randperm, - residual=residual, - ) + + if x_numel == 1: + maf = zuko.flows.MAF( + features=x_numel, + context=y_numel, + hidden_features=hidden_features, + transforms=num_transforms, + ) + else: + maf = zuko.flows.MAF( + features=x_numel, + context=y_numel, + hidden_features=hidden_features, + transforms=num_transforms, + randperm=randperm, + residual=residual, + ) transforms = maf.transform.transforms z_score_x_bool, structured_x = z_score_parser(z_score_x) From 1ca7ab2f8deecda7847c32afc397ee6b1c69eb11 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Mon, 11 Mar 2024 16:21:28 +0100 Subject: [PATCH 17/35] Reverting change to notebook :/ --- tutorials/00_getting_started.ipynb | 366 ++--------------------------- 1 file changed, 18 insertions(+), 348 deletions(-) diff --git a/tutorials/00_getting_started.ipynb b/tutorials/00_getting_started.ipynb index f405a3c3e..6ccb72d79 100644 --- a/tutorials/00_getting_started.ipynb +++ b/tutorials/00_getting_started.ipynb @@ -26,318 +26,6 @@ "from sbi.inference.base import infer" ] }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from __future__ import annotations\n", - "\n", - "import pytest\n", - "import torch\n", - "from torch import eye, ones, zeros\n", - "from torch.distributions import MultivariateNormal\n", - "\n", - "from sbi import utils\n", - "from sbi.inference import SNLE, SNPE, SNRE\n", - "from sbi.neural_nets.embedding_nets import (\n", - " CNNEmbedding,\n", - " FCEmbedding,\n", - " PermutationInvariantEmbedding,\n", - ")\n", - "from sbi.simulators.linear_gaussian import (\n", - " linear_gaussian,\n", - " true_posterior_linear_gaussian_mvn_prior,\n", - ")\n", - "from sbi.utils import classifier_nn, likelihood_nn, posterior_nn\n", - "\n", - "import pytest\n", - "from torch import eye, ones, zeros\n", - "from torch.distributions import MultivariateNormal\n", - "\n", - "from sbi.inference import (\n", - " SNPE_A,\n", - " SNPE_C,\n", - " DirectPosterior,\n", - " prepare_for_sbi,\n", - " simulate_for_sbi,\n", - ")\n", - "from sbi.simulators.linear_gaussian import diagonal_linear_gaussian\n", - "\n", - "from sbi.utils.metrics import c2st\n", - "\n", - "from sbi.inference import (\n", - " SNLE,\n", - " SNPE,\n", - " DirectPosterior,\n", - " MCMCPosterior,\n", - " likelihood_estimator_based_potential,\n", - " prepare_for_sbi,\n", - " simulate_for_sbi,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/root/sbi/sbi/utils/user_input_checks.py:209: UserWarning: Casting 1D Uniform prior to 
BoxUniform to match sbi batch requirements.\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "801c6bc1be314eb590d3950a01088e75", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Running 100 simulations.: 0%| | 0/100 [00:00 Tensor:\n", - " return linear_gaussian(theta, likelihood_shift, likelihood_cov)\n", - "\n", - "simulator, prior = prepare_for_sbi(simulator, prior)\n", - "inference = SNPE(density_estimator=\"mdn\")\n", - "\n", - "theta, x = simulate_for_sbi(simulator, prior, 100)\n", - "posterior_estimator = inference.append_simulations(theta, x).train()\n", - "posterior = DirectPosterior(posterior_estimator=posterior_estimator, prior=prior)\n", - "samples = posterior.sample((num_samples,), x=x_o)\n", - "log_probs = posterior.log_prob(samples, x=x_o)\n", - "\n", - "assert log_probs.shape == torch.Size([num_samples])" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([100, 1])" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([], size=(100, 0))" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "samples" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[1.]])" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x_o" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 10, 32])" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "posterior.posterior_estimator.sample((10,), condition=xo).shape" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 32])" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "xo.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "15bfe30b17e944689124849428a21b1f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Drawing 1 posterior samples: 0%| | 0/1 [00:00 1\u001b[0m samples \u001b[38;5;241m=\u001b[39m \u001b[43mposterior\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m10000\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mobservation\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m log_probability \u001b[38;5;241m=\u001b[39m posterior\u001b[38;5;241m.\u001b[39mlog_prob(samples, x\u001b[38;5;241m=\u001b[39mobservation)\n\u001b[1;32m 3\u001b[0m _ \u001b[38;5;241m=\u001b[39m analysis\u001b[38;5;241m.\u001b[39mpairplot(samples, limits\u001b[38;5;241m=\u001b[39m[[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m2\u001b[39m], [\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m2\u001b[39m], 
[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m2\u001b[39m]], figsize\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m6\u001b[39m, \u001b[38;5;241m6\u001b[39m))\n", - "File \u001b[0;32m~/sbi/sbi/inference/posteriors/direct_posterior.py:118\u001b[0m, in \u001b[0;36mDirectPosterior.sample\u001b[0;34m(self, sample_shape, x, max_sampling_batch_size, sample_with, show_progress_bars)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m sample_with \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 113\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou set `sample_with=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msample_with\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m`. As of sbi v0.18.0, setting \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`sample_with` is no longer supported. You have to rerun \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 115\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`.build_posterior(sample_with=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msample_with\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m).`\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 116\u001b[0m )\n\u001b[0;32m--> 118\u001b[0m samples \u001b[38;5;241m=\u001b[39m \u001b[43maccept_reject_sample\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 119\u001b[0m \u001b[43m \u001b[49m\u001b[43mproposal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mposterior_estimator\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 120\u001b[0m \u001b[43m \u001b[49m\u001b[43maccept_reject_fn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mlambda\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtheta\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mwithin_support\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprior\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtheta\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 121\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_samples\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_samples\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 122\u001b[0m \u001b[43m \u001b[49m\u001b[43mshow_progress_bars\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mshow_progress_bars\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 123\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_sampling_batch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_sampling_batch_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 124\u001b[0m \u001b[43m \u001b[49m\u001b[43mproposal_sampling_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcondition\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 125\u001b[0m \u001b[43m \u001b[49m\u001b[43malternative_method\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mbuild_posterior(..., 
sample_with=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmcmc\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m)\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 126\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m samples\n", - "File \u001b[0;32m~/miniconda3/envs/sbi3/lib/python3.11/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/sbi/sbi/samplers/rejection/rejection.py:262\u001b[0m, in \u001b[0;36maccept_reject_sample\u001b[0;34m(proposal, accept_reject_fn, num_samples, show_progress_bars, warn_acceptance, sample_for_correction_factor, max_sampling_batch_size, proposal_sampling_kwargs, alternative_method, **kwargs)\u001b[0m\n\u001b[1;32m 259\u001b[0m sampling_batch_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmin\u001b[39m(num_samples, max_sampling_batch_size)\n\u001b[1;32m 260\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m num_remaining \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 261\u001b[0m \u001b[38;5;66;03m# Sample and reject.\u001b[39;00m\n\u001b[0;32m--> 262\u001b[0m candidates \u001b[38;5;241m=\u001b[39m \u001b[43mproposal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 263\u001b[0m \u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43msampling_batch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mproposal_sampling_kwargs\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# type: ignore\u001b[39;49;00m\n\u001b[1;32m 264\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[1;32m 266\u001b[0m \u001b[38;5;66;03m# SNPE-style rejection-sampling when the proposal is the neural net.\u001b[39;00m\n\u001b[1;32m 267\u001b[0m are_accepted \u001b[38;5;241m=\u001b[39m accept_reject_fn(candidates)\n", - "File \u001b[0;32m~/sbi/sbi/neural_nets/density_estimators/nflows_flow.py:110\u001b[0m, in \u001b[0;36mNFlowsFlow.sample\u001b[0;34m(self, sample_shape, condition)\u001b[0m\n\u001b[1;32m 108\u001b[0m batch_shape \u001b[38;5;241m=\u001b[39m condition\u001b[38;5;241m.\u001b[39mshape[:\u001b[38;5;241m-\u001b[39mcondition_dims]\n\u001b[1;32m 109\u001b[0m condition \u001b[38;5;241m=\u001b[39m condition\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_condition_shape)\n\u001b[0;32m--> 110\u001b[0m samples \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnet\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnum_samples\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcondition\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mreshape(\n\u001b[1;32m 111\u001b[0m (\u001b[38;5;241m*\u001b[39mbatch_shape, \u001b[38;5;241m*\u001b[39msample_shape, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 112\u001b[0m )\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m samples\n", - "File \u001b[0;32m~/miniconda3/envs/sbi3/lib/python3.11/site-packages/nflows/distributions/base.py:65\u001b[0m, in \u001b[0;36mDistribution.sample\u001b[0;34m(self, num_samples, context, batch_size)\u001b[0m\n\u001b[1;32m 62\u001b[0m context \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mas_tensor(context)\n\u001b[1;32m 64\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batch_size \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m---> 65\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sample\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnum_samples\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 67\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m check\u001b[38;5;241m.\u001b[39mis_positive_int(batch_size):\n", - "File \u001b[0;32m~/miniconda3/envs/sbi3/lib/python3.11/site-packages/nflows/flows/base.py:54\u001b[0m, in \u001b[0;36mFlow._sample\u001b[0;34m(self, num_samples, context)\u001b[0m\n\u001b[1;32m 49\u001b[0m noise \u001b[38;5;241m=\u001b[39m torchutils\u001b[38;5;241m.\u001b[39mmerge_leading_dims(noise, num_dims\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 50\u001b[0m embedded_context \u001b[38;5;241m=\u001b[39m torchutils\u001b[38;5;241m.\u001b[39mrepeat_rows(\n\u001b[1;32m 51\u001b[0m embedded_context, num_reps\u001b[38;5;241m=\u001b[39mnum_samples\n\u001b[1;32m 52\u001b[0m )\n\u001b[0;32m---> 54\u001b[0m samples, _ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_transform\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minverse\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnoise\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43membedded_context\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m embedded_context \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 57\u001b[0m \u001b[38;5;66;03m# Split the context dimension from sample dimension.\u001b[39;00m\n\u001b[1;32m 58\u001b[0m samples \u001b[38;5;241m=\u001b[39m torchutils\u001b[38;5;241m.\u001b[39msplit_leading_dim(samples, shape\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, num_samples])\n", - "File \u001b[0;32m~/miniconda3/envs/sbi3/lib/python3.11/site-packages/nflows/transforms/base.py:60\u001b[0m, in \u001b[0;36mCompositeTransform.inverse\u001b[0;34m(self, inputs, context)\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minverse\u001b[39m(\u001b[38;5;28mself\u001b[39m, 
inputs, context\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 59\u001b[0m funcs \u001b[38;5;241m=\u001b[39m (transform\u001b[38;5;241m.\u001b[39minverse \u001b[38;5;28;01mfor\u001b[39;00m transform \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_transforms[::\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n\u001b[0;32m---> 60\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_cascade\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfuncs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/miniconda3/envs/sbi3/lib/python3.11/site-packages/nflows/transforms/base.py:50\u001b[0m, in \u001b[0;36mCompositeTransform._cascade\u001b[0;34m(inputs, funcs, context)\u001b[0m\n\u001b[1;32m 48\u001b[0m total_logabsdet \u001b[38;5;241m=\u001b[39m inputs\u001b[38;5;241m.\u001b[39mnew_zeros(batch_size)\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m func \u001b[38;5;129;01min\u001b[39;00m funcs:\n\u001b[0;32m---> 50\u001b[0m outputs, logabsdet \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 51\u001b[0m total_logabsdet \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m logabsdet\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs, total_logabsdet\n", - "File \u001b[0;32m~/miniconda3/envs/sbi3/lib/python3.11/site-packages/nflows/transforms/permutations.py:44\u001b[0m, in \u001b[0;36mPermutation.inverse\u001b[0;34m(self, inputs, context)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, inputs, context\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 42\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_permute(inputs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_permutation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dim)\n\u001b[0;32m---> 44\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minverse\u001b[39m(\u001b[38;5;28mself\u001b[39m, inputs, context\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_permute(inputs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_inverse_permutation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dim)\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAWUAAAGBCAYAAAC+UKAvAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAlSklEQVR4nO3dXYxk513n8e//OefUS/e03fY4xs7YwYtWG62UBUNCssqYCwaiDYlQIhEkXrRopdwgcUGk5WJEJJRFQszVwiVC2axzgRYWARLaEKGsJlISe5O18SZxbBMrBId4Eojf2tMv9XbO+e/Fc6q7eqZ7prunqs/T3b+P1HJVdXXN066uXz31P8/zP+buiIhIGkLbAxARkR0KZRGRhCiURUQSolAWEUmIQllEJCEKZRGRhORH+BmtoUuHzeuB3hd+Uc9rIj5X//ncnlfQc5uSgzy3mimLiCREoSwikhCFsohIQhTKIiIJSSqUL165ysUrV9sehohIa46y+mJhrq0N2h6CiEirkpkpT2fIF1b7mi2LyJmVTChPZ8lPXL6kGbOInFnJhLKIiCiURUSSklQov3Tlg20PQUSkVUmF8pQO9onIWZVkKOtgn4icVUmGsojIWZVsKKuEISJnUbKhrBKGiJxFyYayiMhZpFAWEUmIQllEJCEKZRGRhCiURUQSolAWEUmIQllEJCEKZRGRhCQdyhdW+zxy+TPa2SciZ0br5+i7eOUq19YGXFjt3/S9Jy5fAuCRy5857mGJiLSi9VC+tjZQH2URkUbS5QsRkbNGoSwikhCFsohIQhTKIiIJUSiLiCREoSwikhCFsohIQhTKIiIJUSiLiCREoSwikhCFsohIQhTKIiIJUSiLiCREoSwikhCFsohIQhTKIiIJUSiLiCREoSwikpBWQ/nilat7nptPROSsavUcfTo/n4jIbipfiIgk5ESE8oXVPhevXG17GCIiC9daKB+mnvzE5UtcWxsseEQiIu1rraaserKIyM1ORPkCVMIQkbPhxISyShgichacmFAWETkLFMoiIglRKMtimMUvETkUhbIshoX4JWeP3pDviF41sjhe6wV6lkyf6+mbsZ77I2m194WcUmYxkG93n1nuixuPLN5sGMPO5dv9HchNFMoyH82L0sJM2GYZlufgjlcV1N7MngNM79fc5lUVryuc0zZ9M519rqcBPPucBsOyDOoar3aCOf4dVMc44JNHoSxHM31xuu8KZMtzvAlWyzKs140v0rKEqoK6eYGG6UfcGnfbCWxJ0+wnm5nnelsIzSckx6nic98p8KrGQhPCzXPvdbX770d2USjLnWkCOXSKODPudNh++QaD5oVrnQLKEh+Odv98CFhdQzC8DkATzHqxpmX6fFicAVsWoCgws/g8W4AsgDs2KeNtWRZny9bZ/dyHLL4B6zneUyuhrOb2p8B0hgzxY2xRYFnAup3d9zOLs6g8g3GASbk9kwbi7Dk+CBZqXJ9sk2dZ2ClNheZgXpZhIYay280lLAfch7GsQRafZz3Ze2ollNWM6PSwTmd75mS9Hn7PXXiW4b0cD3EW5c1sKowrwuYIm5TY1hAvSxiN4kdcJs1rVC/UJNx4ILapG4dOgZ1bxooCel28yKmXe5AZngdsUhG2xk05ymPJYjTBioJwbjm+KZclAK7Z8p5UvpCjsxA/nk6/ioJ6qUPdzSnPFXgw3AADD0aY1BSdjDCckFU1FgJeVRglXmd6gabCbjiAR/NpyJqSRa8HRU693Md7OZO7unhu1LmRjWqKZsZM5YRxCc2BPpvOot2xusZLtuvQsuNEhfK0U9wTly+1PZSza/boe5bFi1nAlpbwpR7VuQ6TczkbD8Q/rVBBXcBk2ciGztKrGflWQaeTkw0mhLVY0mBrACHgkziL0kfbBDQrZaYH9Wx5Cet08JUlqrv7bD7Ub950DQ9QZxDKjMm5DKvAaicb1hTXC8Jggr25EQO6OeBnWbaz6ka2nahQfuLyJR65/Jm2h3G2NbMny7KdJVBZhncLvFdQdTMmyxmje+L3sjFUXRivOvmmESaBojCsKmJ9cdCJB4vKEvMaC4bXmjklxcL2QVx6XbxbUC4XbN0X/xayMbiB52AVVN0Mq5xQQt4xrO6QA9nGzFLI6eOa6w34BicqlCUB09lTkccXaq+H9brUS13KlS5b9+cM7w1sPlzjRQ2Zx32jec14HBidzyjeDJx7uaCzkdMvAvnGmAxwgMEQM8dn/z19vD1+03XI0xU1nQL6PepzPYY/tMT4royqa5TLMLyvxnPHO46NjXwrkG0FOm9CsWnx2IJBWMuxqoYiHvijqS2rhLGbQlkOxx2aj6yW53G1RZHjnZyqlzE+Z4xXgHvGFJ2Sc0sj3I2yDpRlxuRczqjo0FnPcINiM8OqglDk8fGaj8LmcUa1vUxOL9rjZwEzi+WpoogH9XoFk5WMyVKgLqBcdsIDQ4pOSa8zYTDqMFzvUm5kWJ2BGfnAyLemxx7UE+V2FMpyOGa7d+2Z4UVOtVwwWcmZ3GVM7nbO37vB+aVN/s1dP6AbSgqreGW8wjfX7uf7djejtT7UTfBmBkUel81tP27QZpK2uAM1hLw5gBvLFtU9S4zu6bL5QxmTczC8z6lWS37k/tcpQkUeat4Y9vnBJKMERnVckVFsGuUgUC91CVW187xOl0HqjXcXhbIczl6znGBURaDqGFUXql7Nvf0tHly6zjuWr7EcRqyEAd8Zv4U3Jz3Wh13W+32KDvFovBGXU4UQ1zTPHvxRMLcrBCzP4htvL6dcCpTLcYZcrVQUKyMeXLpO0ezaK+vA60WFu1EtGdXQqIr4Rd48v9NjBsGgnglm1ZYBhbIcxfZSuLiLy7OMujMNZcdzZ3MSN5G8rXiNnk0orCTr1nAPDKuC54pVqo4xORewOiPPA57Hj7hW17hX8YCfZk/HZ6+lcGbNwb0OVTfE0lJTCg7nJqyuDHhk6TUKi4G6kg/JQ81311ZZX7sLKyFUTqh8Z4t9p8DK5lNSWeITbbGfpVCWo5muWzWLB/IMPMQvLM6YKjdWwxa97VfxFj/ceYXz3bfhmeM5VB2j7th274RYUw5NIOuFeuxubCoVDM8zvMiaA3ZGk7/kRcVyZ8x9xfp2KGdWs9nr8mpnmfUKrDKsdqwGq5o32axZk17V8fnWwdxdTlzFXWe1bsm0N27T0a0eDGM/gyKHEAgTJx85+cAIo2aplDmFVayECf+qqLk3DHmlvIvcasL5EeO7a6ou1Pl0aV2Is6g8b3YJZrFPgnryLtZNfZADlueEfi/u0ux2YulhKsQlcHUVWC7G/Ifl53nf8t/zk/1v82CxxhvjPoNxgVVGNoLOupNvVti4jH0xRiN8OMLH4+3dfbLjxIWyzmrdMm9msFUVvyxuo47rUh0r48fb2o3KjYnHg3c9ix/K1usetRvd7gQvHL8xb6frn8MerUBl4Wz6/z7EtcnkedxCnRnWTGanz1ldG8GcBzJ4SzBWwzjOlCddJmUWZ8c1ZGMnTOLMmKpp5dm0clWJ6mbHXr5QM6ITavaF08ykAGw4IgBZPyfvBY
qNjKoXeHOjz4t2P/89e4zVYsDbuq/xRrnMi5v388JrDzB6aYX+q4Hu9ZpivSJsDGM/jMEwbr02w3FtJDkOM82lvHao4+zVpt+r4nPgGc1yRchGxuT1Di+EB/j40iUqN94YL3Ft426+9717sc2MzkagWIfOWkk2LPFujo2bFTYhxPJFsNi6dfpvyfGHspoRnRLTfsjjCZZlhElNmDjZ2MmGxmCY82bo82JxPyudEa/3ltksO7w2XObNjR7F9UCxCfmwJhvVcat1WcVtt7MvTtUbj5/XOyskiD0rtq83q+WsgmwQKNcL/t+rF6jqwGBcsLXZJazlZEMjG0I2dLJhSZjUKkMdkA70yeHNnilkOMLcCd2CjjvnuoHOeiBMekxWerz0YA/r1HyzP6GqAtUgJ7yZ092EfNPJt2rCqNr9gm3OVnFTQMviTNcmA5bHNqzU3nTyizsuO0VGGNVgBeMVw3PD6px/9vPbD5NtBHqvBDrXneV/rijWS7KNETacxE9Co3FchbHd8F7lixsplOVoao8Fw6rCyxIbxRlzsVFidUZVGNnYqHo5dccZD7N44Gdg5JuBfAj5yAnjmlDuscpCKy+O37SMMa0pA9SOlRUeSsJwQk7chVlngWxo4IZbtr36ptg0ig3oXHd6r47JhiU2qbDxBB8M4ici99hTW8/xnhTKciReToBip+44KQmbAzp1Tb7UAfoUW4FsFKjzQN0JWAn50MkHNd03a4qNks4rmzAaYxtbOhKfiuYALlnTQGgYe2Bn4wmh38UN8q2MzmZG2Q2M7zLwpqQxqcmHTrFekV8fxhny5gAfjfHxZGetMmjZ4z4UynJ404+bXuPu2PQcfO6xlAHkmx2sin2UPTh1YYTSyYcee+5eL8kGE2wwio3PJ5PdO/max5fjNy0pmDUlDJo1xc0B2GzYJZQ1YVyT9zNCFVdahEn8W7DSyTfLGMijcSxZTM/RCE35QoG8n2MNZa28OEXMYs13MMAtYGUZ65DdLlZWdNzxLKP7Stx4UPVzrPSdj7OjCTaawGCIl3F5nVd1nC1Xexzwk8WbretXFb6+vtOyM1g84WlVk01KyDO8U1AUOcXazinAbFITtkaxXLGxFX9mUjZnsY5v4grkWzvWUNbKi9PHyzJu8BiPt0+U6bVvdxfzYQZFjk06WFljw/gR1saTuGa1bE45X7vWribGp59+smxmcbLHHXhZhpUV5BlZXe9sQBk3B/TGk7hJZHpKqGkduQnk7aWOep5vciLLFzoDSQJm1rbidTxjSLWzisK3trAix/p9cCe8WseZdJ7HWdhk0qyBbWbI47FmyKnweAJbyzKonXrUnIW69tjGs9OBUMbgDYbNnKHcy5J61DyXzVlkvKp2TielGfJtHduOvnmWLrSrLxE+M9upZxoIVRU+GuHjCV7GAPbNTXxrELdmlxWMJ83HWu3uSs6N68SravsNc7pUcbvUNCnx4XDnazSGyaR5c22WNdbV9vV9/x3ZdmwzZZUuTrnpDHl6jr0sizXEzc14vY4vYt/cgqqinpldxR/UJpHW7QrjZo3ybE8MiEE7nTln2c4ZYqZ1YmtWbGy/ydY7j1fvHCCW/Z3I8oUkbPsFl8Ua4qTcOS9bVcW2nFUd68iSvplNJdhM9z4LGDvP4c7tjoW8eZOtb55173pc2YtCWebD4wkwt/sYTJc/WYCK7ROiumZJJ892gO7MevEqZqzZ9lnN4/ea4wN7ha6C+EAUyjJfN73wpttpg0oUJ90+QauVFPOlUJb5shs6fs3OsnZ1mlNnsOTd+Bzt95ypFDVXJ66f8pSa3Z8wCl+RAzmxoTxdo6xgTowfcFnbQe8n7bnxOdJzdiyOJZQXtb1a65VF5LRZeChPZ7KL2n2nMoaInCYLD+Vra4OFbofWbFlETpMTW1MWETmNFrYk7uKVq1xbGxxbq85Fl0lERI6DuY6miogkQ+ULEZGEKJRFRBKiUBYRSYhCWUQkIYdefWFm3wCGCxjLcboPeLXtQcxBz93f0fYgRGR+jrIkbuju75r7SI6RmT190n8HiL9H22MQkflS+UJEJCEKZRGRhBwllP947qM4fqfhd4DT83uISEM7+kREEqLyhYhIQhTKIiIJOVQom9mvmtnXzexZM3vSzH5sUQNbFDN7v5l908y+ZWaX2x7PYZnZw2b2eTN73syeM7PfbHtMIjI/h6opm9l7gRfc/Q0z+zngE+7+noWNbs7MLANeBN4HvAw8Bfyyuz/f6sAOwcweBB5092fMbAX4O+DDJ+l3EJH9HWqm7O5PuvsbzdUvAw/Nf0gL9W7gW+7+bXcfA38KfKjlMR2Ku3/f3Z9pLq8DLwAX2h2ViMzLndSUPwp8dl4DOSYXgO/OXH+ZExxoZvYI8OPAV1oeiojMyZHOPGJmP00M5cfmOxw5KDM7B/wF8DF3v972eERkPm47Uzaz3zCzrzZfbzWzHwU+CXzI3V9b/BDn6hrw8Mz1h5rbThQzK4iB/Cfu/pdtj0dE5uewB/reBlwFfs3dn1zYqBbEzHLigb6fIYbxU8CvuPtzrQ7sEMzMgE8Dr7v7x1oejojM2WFD+ZPALwDfaW4qT1q3NTP7APCHQAZ8yt1/r90RHY6ZPQZ8EXgWqJubf9vd/+YOH1pbO9Nh83yw94Vf1HObiM/Vf37b51bbrGVKfwjpUCifUgcJZe3oExFJiEJZRCQhCmURkYQolOXYXLxylYtXrrY9DJGkKZTlWEzD+NraoOWRiKTtVIWymX3CzH6rufy7Zvazd/BYnzKzHzRn75Y7dG1twBOXL7U9DJHknapQnuXuv+Pu//sOHuJx4P1zGs6ZdvHKVS6s9tsehsiJcOJD2cw+bmYvmtmXgLfP3P64mX2kufySmf1+s1X8aTP7CTP7WzP7BzP79b0e192/ALx+PL/F6aZZssjBHakhUSrM7J3ALwGPEn+XZ4j9hffyT+7+qJn9AXEWfBHoAd8A/mjhgxUROYATHcrATwF/5e5bAGb217e47/R7zwLnml7E62Y2MrNVd19b7FBFRG7vxJcvDmHU/LeeuTy9ftLfnETklDjpofwF4MNm1m9OjfTzbQ9IROROnOhQbk6L9GfA14hnQXlqXo9tZv8D+D/A283sZTP76LweW0RkP+oSJ1ML+UN45PJnAHjpyge3r19Y7Ws1xq2pS9wppS5xkoQb1yhrV5/I/hTKsnCzs+LpjFlE9qZQFhFJiEJZRCQhCmURkYQolEVEEqJQFhFJiEJZRCQhCmURkYQolOXYXVjt61x9IvtQKMvC7HfGkScuX9KuPpF9qGWlLMy1tYF28IkckmbKIiIJUSiLiCREoSwikhCFsohIQhTKIiIJUSiLiCREoSwikhCFsizEfhtHROTWtHlEFkIbR0SORjNlEZGEKJRFRBKiUBYRSYhCWUQkIQplaYV6KovsTaEsrVBPZZG9KZRFRBKiUBYRSYhCWUQkIQplEZGEKJRFRBKiUBYRSYhCWeZOHeJEjk5d4mTu1CFO5Og0UxYRSYhCWUQkIQplEZGEKJRFRBKiUBYRSYhCWVqj9p0iN1MoS2vUv
lPkZgplEZGEKJRFZD7Mbr5+421yWwplEbm1g4Tr9Psz/7UsW+y4TiltsxaRg9kvmC3MXLTd1zPwslz0yE4VhbLMlZoRnULuzWw5hq2FW8yaLWBFDnWNu0N1TGM8RRTKMldqRnQKNTPk7TCemQkTDJudQWdZE8oO4zGeZRjgtcfvex1DXvalUBaR/dlMEO8VxiHE+4RmFp0FyHOo6hjQVYVbwDLAa7wO4NXO4yqgb6JQFpGbNQfqtme4AF5jeY71+5AFLM8hz/F+d9fPkQWYlITBCKoq1pTHE+rBMIZzluFVBbVqG3tRKIvI3poZrlcz4RkC1u3E2XCR452CeqUXw3g66zXDxvHgnlU1VsbZMqNR821rZs3H/QudDAplEdkxWx/2GiwQet04K+52sU6Bn1uCPKPuF1S9nPE9nfijtYNDKB3qDuHePtmgJHt9E3PHOp04c56Uu/8tlTB2USiLSLTXkrdgWKeAooMt9eLMeKmLF4G6nzNZyhmuxvXIoXKsgmw8nTFDkRnZepxVWxbwum5myR5r1Jou30ShLK2aNiV64vKltociU7vWGWexhtwp8F4X7xWUd3epehmj1YzhPcb6D4PnjueQbRm91wLZEDrXHaszqpUuITNCWcFwhJclZr67LCLbtKNPWqWmRAmzAFkWA7nI41cnp+plTJYDo7uM0T1G9eAIe+uQ/tvWKS+MGJ53xndB1YOyG6i7Gd4t8F4H6xRxp1+wW693PsM0UxaRHdMNIllG6Peg28U7Bd7vMLl3iaqXMbwnY3S3sfHDMLl/zHv+9Uus5CNWiy3+cfM8X+++la21Hh5yqq6RTToU3YxOXRPMsOEQm5T4eAwW8HIS/23VlgGFsohM3bAOmSLH8gzPM7yTU/Yzqn6g7BuTFWNy74SV85u88+7vcHc24Hy2wd35gNeGy1wDJmvLhMqYLBlWBfJujo0KQlHgAGXZ7BZUbXmWyhfSOjW7T0CzTG0ajnFjSBY3hVQ1VA4Byq4xPG+Mzjvn7t/kR+55nX/Xe5lHe9/hvb3v8a6lf+RHVl7jLasblKsVkxVnsmyUvekmkz3+bQXyLgplaZ3qyi27YWnadq03NN3h6hprvudZUyvuO3f3h/xQ/zpvydZ5IBtxf7bEA9l17u1sstIZQbei6kBdGJ6Bq4R8ICpfiJx101puyHa6vIUQywt5Dt0OngesdMIEwhjC2CjrwKAqeK1aZug5r1Qlz44e5h83z7M27IMbVoOVjk0nw3XTNa529Vreh0JZ5kYd4k62nUBuwrL22O2t6adsHjeIWAVWQVUHxnXOlncZVgUTz/ne+B7eHPcZjAuopj/TfDlxxu0eSxZBH9T3olCWuVGHuJMtrhuusLyIKZo1KzGGIwhgVWd7xhtKeH1tmdrhavffMqpz/mVwF68Pl/iXtRXG17t0v1fQuQ7FZk2xXpG/vokNx3ieQ1XBZBL/q1UXuyiURSTymfaa08NN7lhZQVnHWXIz8w2lUQ9y1vMeL22eZ1AW/GD9HKNxznizQ9jIKDYg33TykZMPK2w4hslOw3uv6t0NjwRQKIvITefWC7FB/dYg1pTdMXeKVwJh3KPs9hhvGHhB1c157vuPgINVRjaCc9eNbBB39HU2a3qvTsg3J3ieYZMSRiN8PNlZnyy7KJRlLlRPPiWma5Vrx5vThlhZwsiwIicUGZ2NDhAol4w6N+piJ9SzEXTX4gHBYlCTb9ZkW5PtrnG445MylkpUttiTQlnmQvXkE2xXONZ4GbvDmRsOcfddHQ/QBaBY75JvGb3XjboTGK+EWIIeOWHi5IOKUNaEUUXYGhPe3ISywrdnyGWsJatL3J4UyiKyYzsg41lCzBx3x6oKxmCjCdmw6ZVc1tTdHKsLrHTyrQqrasK4wsoaG03igb2tIZTlThhP68jaybcnhbIkQd3i0jPdROLjcTy1E+BbW+TfrSHP8E5BCIF8LYthPonbpm1SQlXjk0ksU5RlLFlMz2odLJ5QdXrmkdkG+aIdfZIG7epLzOxKjKraabM5KfH1DXxzCxtPsOEIu76JrW9hgxG2OcC3hvhgAIMhPhrHpvZVtWuWrFUX+9NMWUR2a2q9Xpbb5+qj9jhjhqYfRoWPxvFEqc33qaq4aqOeBvlMacLr7XXQ8frsuf8U0LMUyiKyP3e8dizUeEVTephpTm+GdToxjKsq7gB0jyHtN6xDVvgeiEJZRHa7MTzrKh6Pm13PbCGe7RrHB8P4Y1W1/4E7BfKBqaYsyVALz8Rt963wnTafTVnCqyoeuJu9z+zPyYEplCUZOth3grjvDuO6unln4I3hLAei8oXcMe3mO6MUuAuhmbLckWm5QeuLRSE9H5opyx3R9mqR+dJMWUQkIQplSYpWYMhZp1CWI1vEAb5pbVrBLGeVQlmOZJEH+KZL4xTMchbpQJ8c2sUrV4/lAJ/WLMtZpFCWW5rOVq+tDbiw2t/+76K9dOWDXLxylUcufwaItWYtu5OzwFxrC0VEkqGasohIQhTKIiIJUSiLiCREoSwikhCtvhAAzOwbwLDtcczBfcCrbQ/iDvXc/R1tD0LaoVCWqaG7v6vtQdwpM3v6pP8eZvZ022OQ9qh8ISKSEIWyiEhCFMoy9cdtD2BOTsPvcRp+Bzki7egTEUmIZsoiIglRKAtm9qtm9nUze9bMnjSzH2t7TIdlZu83s2+a2bfM7HLb4zksM3vYzD5vZs+b2XNm9pttj0naofKFYGbvBV5w9zfM7OeAT7j7e9oe10GZWQa8CLwPeBl4Cvhld3++1YEdgpk9CDzo7s+Y2Qrwd8CHT9LvIPOhmbLg7k+6+xvN1S8DD7U5niN4N/Atd/+2u4+BPwU+1PKYDsXdv+/uzzSX14EXgAvtjkraoFCWG30U+GzbgzikC8B3Z66/zAkONDN7BPhx4CstD0VaoB19ss3MfpoYyo+1PZazyszOAX8BfMzdr7c9Hjl+mimfUWb2G2b21ebrrWb2o8AngQ+5+2ttj++QrgEPz1x/qLntRDGzghjIf+Luf9n2eKQdOtAnmNnbgKvAr7n7k22P57DMLCce6PsZYhg/BfyKuz/X6sAOwcwM+DTwurt/rOXhSIsUyoKZfRL4BeA7zU3lSWvqY2YfAP4QyIBPufvvtTuiwzGzx4AvAs8CdXPzb7v738zh4fUiT4Md6E4KZZFTTy/yNBwolFVTFhFJiEJZRCQhCmURkYQolEVkLi5eucrFK1fbHsaJp80jInJHpkF8bW3Q8khOB82UJXlm9gkz+63m8u+a2c8e8XHUiW0Brq0NFMhzpJmynCju/jt38OMl8J9nO7GZ2efUiU1SopmyJMnMPm5mL5rZl4C3z9z+uJl9pLn8kpn9frNV/Gkz+wkz+1sz+wcz+/UbH1Od2OQk0ExZkmNm7wR+CXiU+Df6DLG/8F7+yd0fNbM/AB4HLgI94BvAH93i33gEdWKTBCmUJUU/BfyVu28BmNlf3+K+0+89C5xrZsDrZjYys1V3X7vxB9SJTVKm8oWcdKPmv/XM5en1myYd6sS2GBdW+1xY7bc9jFNBoSwp+gLwYTPrNwfk
fn4eD9p0YvtvxFNf/dd5PKZET1y+xBOXL7U9jFNBoSzJaQ7G/RnwNeJZUJ6a00NfBP4jcGmml/QH5vTYInOhmrIkqWm9eVP7TXf/TzOXH5m5/DjxQN9N35u57UscsFOXSFs0UxYRSYhCWUQkIQplEZGEKJRF5MguXrmqpXBzpgN9InJk19YGvHTlg20P41TRTFlEJCEKZRGZmwurfTW6v0MKZRGZmycuX1Jv5TukUBYRSYhCWUQkIQplEZGEKJRFRBKiUBaRI9HGkcVQKIvIkVxbG+zZQ1nL4u6MQllE5krL4u6MQllEJCEKZRGRhCiURUQSolAWEUmIQllEJCEKZRGRhCiURUQSolAWEUmIQllEJCEKZRGRhCiURUQSolAWkUO7XYc4NSU6OoWyiBzafh3iptSU6OgUyiIiCVEoi4gkRKEsIpIQhbKISEIUyiIiCVEoi8hCaFnc0SiURWQhtCzuaBTKIiIJUSiLyKHcbjef3Jm87QGIyMlybW3AS1c+2PYwTi3NlEVEEqJQFhFJiEJZRCQhCmURkYQolEVkYbSB5PAUyiKyMNpAcngKZRGRhCiURUQSolAWkQPTbr7F044+ETkw7eZbPM2URUQSolAWEUmIQllEJCEKZRGRhCiURUQSolAWEUmIQllEJCEKZRGRhCiURUQSolAWkYVS+87DUSiLyEKpfefhKJRF5EDUjOh4qCGRiByImhEdD82URUQSolAWEUmIQllEbkv15OOjUBaR27q2NuCJy5eO/PNaFndwCmURWTgtizs4hbKISEIUyiJyLFTCOBiFsogcC5UwDkahLCKSEIWyiEhCFMoicktao3y81PtCRG5JPS+Ol2bKInJstALj9hTKIrKveZcutALj9lS+EJF9qXRx/DRTFhFJiEJZRPakVRftUPlCRPak0kU7NFMWkWOlFRi3plAWkZsssnQx7cusYN6buXvbYxCRxTrwi/zilatcWxtwYbV/R03tU/u3EmEHupNCWeTUu+WL/MYZ63EH5COXPwNwFsJZoSwictKopiwikhCFsohIQhTKIiIJUSiLiCREO/pETjkz+wYwbHscd+g+4NW2B3GHeu7+jtvdSaEscvoN3f1dbQ/iTpjZ06fhdzjI/VS+EBFJiEJZRCQhCmWR0++P2x7AHJyZ30E7+kREEqKZsohIQhTKIqeUmf2qmX3dzJ41syfN7MfaHtNhmdn7zeybZvYtM7vc9niOwsweNrPPm9nzZvacmf3mLe+v8oXI6WRm7wVecPc3zOzngE+4+3vaHtdBmVkGvAi8D3gZeAr4ZXd/vtWBHZKZPQg86O7PmNkK8HfAh/f7PTRTFjml3P1Jd3+jufpl4KE2x3ME7wa+5e7fdvcx8KfAh1oe06G5+/fd/Znm8jrwAnBhv/srlEXOho8Cn217EId0AfjuzPWXuUWYnQRm9gjw48BX9ruPdvSJnHJm9tPEUH6s7bGcZWZ2DvgL4GPufn2/+2mmLHKKmNlvmNlXm6+3mtmPAp8EPuTur7U9vkO6Bjw8c/2h5rYTx8wKYiD/ibv/5S3vqwN9IqeTmb0NuAr8mrs/2fZ4DsvMcuKBvp8hhvFTwK+4+3OtDuyQzMyATwOvu/vHbnt/hbLI6WRmnwR+AfhOc1N50pr6mNkHgD8EMuBT7v577Y7o8MzsMeCLwLNA3dz82+7+N3veX6EsIpIO1ZRFRBKiUBYRSYhCWUQkIQplEZGEKJRFRBKiUBaR1pjZJ8zst5rLv2tmP3vEx+mZ2f81s681ndj+y3xHeny0zVpEkuDuv3MHPz4CLrn7RrN77ktm9ll3//KchndsNFMWkWNlZh83sxfN7EvA22duf9zMPtJcfsnMfr/ZLv60mf2Emf2tmf2Dmf36jY/p0UZztWi+TuQmDIWyiBwbM3sn8EvAo8AHgJ+8xd3/yd0fJe6Gexz4CPDvgT1LE2aWmdlXgR8An3P3fTuxpUyhLCLH6aeAv3L3raZT2l/f4r7T7z0LfMXd1939FWBkZqs33tndqybEHwLebWbvmO/Qj4dCWURSNWr+W89cnl7f93iYu68Bnwfev7CRLZBCWUSO0xeAD5tZvzk10s/P40HN7C3T2bOZ9YmnkPr7eTz2cdPqCxE5Ns156v4M+Bqx9vvUnB76QeDTzXn9AvA/3f1/zemxj5W6xImIJETlCxGRhCiURUQSolAWEUmIQllEJCEKZRGRhCiURUQSolAWEUmIQllEJCH/HyiwHdMlwo1eAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" } ], "source": [ @@ -516,7 +186,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.9.5" } }, "nbformat": 4, From 6c2060c29af92bbd0326ba81d489d8b30fec7478 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Mon, 11 Mar 2024 21:26:11 +0100 Subject: [PATCH 18/35] Using the loss of the density estimator for single round --- sbi/inference/snpe/snpe_a.py | 2 +- sbi/inference/snpe/snpe_b.py | 2 +- sbi/inference/snpe/snpe_base.py | 8 ++++---- sbi/inference/snpe/snpe_c.py | 7 ++++++- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index 5f8641da5..92a6ce0e5 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -307,7 +307,7 @@ def build_posterior( ) return deepcopy(self._posterior) - def _log_prob_proposal_posterior( + def _loss_proposal_posterior( self, theta: Tensor, x: Tensor, diff --git a/sbi/inference/snpe/snpe_b.py b/sbi/inference/snpe/snpe_b.py index 0246ccebd..dfe2f464a 100644 --- a/sbi/inference/snpe/snpe_b.py +++ b/sbi/inference/snpe/snpe_b.py @@ -39,7 +39,7 @@ def __init__( kwargs = del_entries(locals(), entries=("self", "__class__")) super().__init__(**kwargs) - def _log_prob_proposal_posterior( + def _loss_proposal_posterior( self, theta: Tensor, x: Tensor, masks: Tensor ) -> Tensor: """ diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index fbe9924cc..ac615216e 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -552,7 +552,7 @@ def build_posterior( return deepcopy(self._posterior) @abstractmethod - def _log_prob_proposal_posterior( + def _loss_proposal_posterior( self, theta: Tensor, x: Tensor, @@ -583,11 +583,11 @@ def _loss( """ if self._round == 0 or force_first_round_loss: # Use posterior log prob (without proposal correction) for first round. - log_prob = self._neural_net.log_prob(theta, x) + loss = self._neural_net.loss(theta, x) else: - log_prob = self._log_prob_proposal_posterior(theta, x, masks, proposal) + loss = self._loss_proposal_posterior(theta, x, masks, proposal) - return -(calibration_kernel(x) * log_prob) + return -(calibration_kernel(x) * loss) def _check_proposal(self, proposal): """ diff --git a/sbi/inference/snpe/snpe_c.py b/sbi/inference/snpe/snpe_c.py index 90145311e..c2632a274 100644 --- a/sbi/inference/snpe/snpe_c.py +++ b/sbi/inference/snpe/snpe_c.py @@ -254,7 +254,7 @@ def _set_maybe_z_scored_prior(self) -> None: else: self._maybe_z_scored_prior = self._prior - def _log_prob_proposal_posterior( + def _loss_proposal_posterior( self, theta: Tensor, x: Tensor, @@ -278,8 +278,13 @@ def _log_prob_proposal_posterior( """ if self.use_non_atomic_loss: + # TODO Add checks for mixture of gaussian return self._log_prob_proposal_posterior_mog(theta, x, proposal) else: + if not hasattr(self._neural_net, "log_prob"): + raise ValueError("The neural estimator must have a log_prob method, for\ + atomic loss. 
It should at best follow the \ + sbi.neural_nets 'DensityEstiamtor' interface.") return self._log_prob_proposal_posterior_atomic(theta, x, masks) def _log_prob_proposal_posterior_atomic( From 3cf6acd6a834c0c1548ea10720b72cc506961f1b Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Mon, 11 Mar 2024 21:32:36 +0100 Subject: [PATCH 19/35] Formatting --- sbi/inference/snpe/snpe_a.py | 2 +- sbi/inference/snpe/snpe_c.py | 15 ++++++++++++--- sbi/samplers/rejection/rejection.py | 5 +---- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index 92a6ce0e5..de878dcac 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -507,7 +507,7 @@ def _sample_approx_posterior_mog( embedded_context = torchutils.repeat_rows( embedded_context, num_reps=num_samples ) - # TODO + theta, _ = self._neural_net.net._transform.inverse( theta, context=embedded_context ) diff --git a/sbi/inference/snpe/snpe_c.py b/sbi/inference/snpe/snpe_c.py index c2632a274..ebaa07857 100644 --- a/sbi/inference/snpe/snpe_c.py +++ b/sbi/inference/snpe/snpe_c.py @@ -278,13 +278,22 @@ def _loss_proposal_posterior( """ if self.use_non_atomic_loss: - # TODO Add checks for mixture of gaussian + if not ( + hasattr(self._neural_net.net, "_distribution") + and isinstance(self._neural_net.net._distribution, mdn) + ): + raise ValueError( + "The density estimator must be a MDNtext for non-atomic loss." + ) + return self._log_prob_proposal_posterior_mog(theta, x, proposal) else: if not hasattr(self._neural_net, "log_prob"): - raise ValueError("The neural estimator must have a log_prob method, for\ + raise ValueError( + "The neural estimator must have a log_prob method, for\ atomic loss. It should at best follow the \ - sbi.neural_nets 'DensityEstiamtor' interface.") + sbi.neural_nets 'DensityEstiamtor' interface." + ) return self._log_prob_proposal_posterior_atomic(theta, x, masks) def _log_prob_proposal_posterior_atomic( diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index c0d672bd0..62d07618a 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -262,13 +262,10 @@ def accept_reject_sample( candidates = proposal.sample( (sampling_batch_size,), **proposal_sampling_kwargs # type: ignore ) # type: ignore - print(candidates.shape) + # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) - print(are_accepted.shape) samples = candidates[are_accepted] - print(samples.shape) - print(are_accepted.shape) accepted.append(samples) # Update. 
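Patches 18 and 19 above move SNPE's first-round objective off a hand-rolled log-probability and onto the estimator's own loss() method, reserving log_prob for the atomic proposal correction. The following sketch is illustrative only — ToyGaussianEstimator is a made-up stand-in, not part of any patch — but it shows the minimal contract that _loss in snpe_base.py now relies on:

    import torch
    from torch import Tensor, nn

    class ToyGaussianEstimator(nn.Module):
        """Hypothetical stand-in for a DensityEstimator: q(theta|x) = N(f(x), I)."""

        def __init__(self, theta_dim: int, x_dim: int) -> None:
            super().__init__()
            self.net = nn.Linear(x_dim, theta_dim)

        def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
            # Log-density up to an additive constant; shape (batch,).
            mean = self.net(condition)
            return -0.5 * ((input - mean) ** 2).sum(dim=-1)

        def loss(self, input: Tensor, condition: Tensor) -> Tensor:
            # Convention adopted by patch 18: a positive per-sample
            # quantity to minimize, i.e. the negative log-likelihood.
            return -self.log_prob(input, condition)

    estimator = ToyGaussianEstimator(theta_dim=2, x_dim=3)
    theta, x = torch.randn(10, 2), torch.randn(10, 3)
    assert estimator.loss(theta, x).shape == (10,)

This sign convention is also what motivates patch 22 further below: with loss() already returning a negative log-probability that the training loop minimizes, the extra leading minus in snpe_base._loss would have trained the network to maximize that loss, so it is removed.
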
From 587321ec1d316a1d2c69eb7196cc655c657e5c03 Mon Sep 17 00:00:00 2001 From: Guy Moss Date: Tue, 12 Mar 2024 11:53:56 +0100 Subject: [PATCH 20/35] expose embedding_net property of density estimators --- sbi/neural_nets/density_estimators/base.py | 5 +++++ sbi/neural_nets/density_estimators/nflows_flow.py | 7 ++++++- sbi/neural_nets/density_estimators/zuko_flow.py | 6 ++++-- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/sbi/neural_nets/density_estimators/base.py b/sbi/neural_nets/density_estimators/base.py index a5174dc97..8dc67987a 100644 --- a/sbi/neural_nets/density_estimators/base.py +++ b/sbi/neural_nets/density_estimators/base.py @@ -31,6 +31,11 @@ def __init__(self, net: nn.Module, condition_shape: torch.Size) -> None: self.net = net self._condition_shape = condition_shape + @property + def embedding_net(self) -> Optional[nn.Module]: + r"""Return the embedding network if it exists.""" + return None + def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: r"""Return the log probabilities of the inputs given a condition or multiple i.e. batched conditions. diff --git a/sbi/neural_nets/density_estimators/nflows_flow.py b/sbi/neural_nets/density_estimators/nflows_flow.py index 9ab94f2a9..aaeb62873 100644 --- a/sbi/neural_nets/density_estimators/nflows_flow.py +++ b/sbi/neural_nets/density_estimators/nflows_flow.py @@ -2,7 +2,7 @@ import torch from pyknos.nflows import flows -from torch import Tensor +from torch import Tensor, nn from sbi.neural_nets.density_estimators.base import DensityEstimator from sbi.types import Shape @@ -15,6 +15,11 @@ class NFlowsFlow(DensityEstimator): wrap them and add the .loss() method. """ + @property + def embedding_net(self) -> nn.Module: + r"""Return the embedding network.""" + return self.net._embedding_net + def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: r"""Return the log probabilities of the inputs given a condition or multiple i.e. batched conditions. 
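The point of routing embedding_net through a property here is that callers no longer need to know where each backend stores the embedding (pyknos/nflows keeps it on the flow as _embedding_net, the zuko wrapper holds it as a separate attribute, and other estimators may have none). A hypothetical, backend-agnostic caller — embed_condition is illustrative and not part of the patch — could then be written as:

    from torch import Tensor

    def embed_condition(estimator, x: Tensor) -> Tensor:
        # `estimator` is any DensityEstimator exposing the new property.
        embedding = estimator.embedding_net
        return x if embedding is None else embedding(x)
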
diff --git a/sbi/neural_nets/density_estimators/zuko_flow.py b/sbi/neural_nets/density_estimators/zuko_flow.py index a1c93970e..9cd933005 100644 --- a/sbi/neural_nets/density_estimators/zuko_flow.py +++ b/sbi/neural_nets/density_estimators/zuko_flow.py @@ -29,8 +29,10 @@ def __init__( super().__init__(net=net, condition_shape=condition_shape) self._embedding_net = embedding_net - def _maybe_z_score(self) -> bool: - return True + @property + def embedding_net(self) -> nn.Module: + r"""Return the embedding network.""" + return self._embedding_net def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: r"""Return the log probabilities of the inputs given a condition or multiple From e19c96e8b19c6d50611209dd0bffa3d752dea3c7 Mon Sep 17 00:00:00 2001 From: Guy Moss Date: Tue, 12 Mar 2024 11:55:19 +0100 Subject: [PATCH 21/35] add test for ZukoFlow for SNPE --- sbi/utils/get_nn_models.py | 4 ++-- tests/linearGaussian_snpe_test.py | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/sbi/utils/get_nn_models.py b/sbi/utils/get_nn_models.py index 6ac836c01..0bd96b6d1 100644 --- a/sbi/utils/get_nn_models.py +++ b/sbi/utils/get_nn_models.py @@ -175,8 +175,6 @@ def build_fn(batch_theta, batch_x): elif model == "mnle": return build_mnle(batch_x=batch_x, batch_y=batch_theta, **kwargs) elif model == "zuko_maf": - del kwargs["num_bins"] - del kwargs["num_components"] return build_zuko_maf(batch_x=batch_x, batch_y=batch_theta, **kwargs) else: raise NotImplementedError @@ -277,6 +275,8 @@ def build_fn(batch_theta, batch_x): return build_maf_rqs(batch_x=batch_theta, batch_y=batch_x, **kwargs) elif model == "nsf": return build_nsf(batch_x=batch_theta, batch_y=batch_x, **kwargs) + elif model == "zuko_maf": + return build_zuko_maf(batch_x=batch_theta, batch_y=batch_x, **kwargs) else: raise NotImplementedError diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py index 001fd1a00..d021a4c96 100644 --- a/tests/linearGaussian_snpe_test.py +++ b/tests/linearGaussian_snpe_test.py @@ -147,8 +147,10 @@ def test_c2st_snpe_on_linearGaussian(snpe_method, num_dim: int, prior_str: str): @pytest.mark.slow -@pytest.mark.parametrize("density_estimtor", ["mdn", "maf", "maf_rqs", "nsf"]) -def test_density_estimators_on_linearGaussian(density_estimtor): +@pytest.mark.parametrize( + "density_estimator", ["mdn", "maf", "maf_rqs", "nsf", "zuko_maf"] +) +def test_density_estimators_on_linearGaussian(density_estimator): """Test SNPE with different density estimators on linear Gaussian example.""" theta_dim = 4 @@ -176,7 +178,7 @@ def test_density_estimators_on_linearGaussian(density_estimtor): prior, ) - inference = SNPE_C(prior, density_estimator=density_estimtor) + inference = SNPE_C(prior, density_estimator=density_estimator) theta, x = simulate_for_sbi( simulator, prior, num_simulations, simulation_batch_size=1000 @@ -190,7 +192,7 @@ def test_density_estimators_on_linearGaussian(density_estimtor): samples = posterior.sample((num_samples,)) # Compute the c2st and assert it is near chance level of 0.5. 
- check_c2st(samples, target_samples, alg=f"snpe_{density_estimtor}") + check_c2st(samples, target_samples, alg=f"snpe_{density_estimator}") def test_c2st_snpe_on_linearGaussian_different_dims(density_estimator="maf"): From 05b78cf5db0eec4360656732d4bcb1b3af2da826 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Mon, 11 Mar 2024 22:43:56 +0100 Subject: [PATCH 22/35] remove sign snpe bug --- sbi/inference/snpe/snpe_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index ac615216e..0a92a188b 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -587,7 +587,7 @@ def _loss( else: loss = self._loss_proposal_posterior(theta, x, masks, proposal) - return -(calibration_kernel(x) * loss) + return (calibration_kernel(x) * loss) def _check_proposal(self, proposal): """ From ecefd8597059bc5378357a6642dd8f877abd2b73 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 12 Mar 2024 15:24:40 +0100 Subject: [PATCH 23/35] New ruff formating --- sbi/analysis/conditional_density.py | 2 +- sbi/inference/posteriors/direct_posterior.py | 4 +--- sbi/inference/potentials/likelihood_based_potential.py | 10 ++++++---- sbi/inference/potentials/posterior_based_potential.py | 6 ++---- sbi/inference/snle/mnle.py | 5 +++-- sbi/inference/snle/snle_base.py | 3 +-- sbi/inference/snpe/snpe_a.py | 2 -- sbi/inference/snpe/snpe_base.py | 2 +- sbi/neural_nets/density_estimators/base.py | 2 +- sbi/neural_nets/density_estimators/nflows_flow.py | 3 +-- sbi/neural_nets/density_estimators/zuko_flow.py | 2 +- sbi/neural_nets/flow.py | 4 +++- sbi/neural_nets/mnle.py | 1 - 13 files changed, 21 insertions(+), 25 deletions(-) diff --git a/sbi/analysis/conditional_density.py b/sbi/analysis/conditional_density.py index 278b6be4c..655a15f17 100644 --- a/sbi/analysis/conditional_density.py +++ b/sbi/analysis/conditional_density.py @@ -7,7 +7,7 @@ import torch import torch.distributions.transforms as torch_tf from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn -from torch import Tensor, nn +from torch import Tensor from torch.distributions import Distribution from sbi.neural_nets.density_estimators.base import DensityEstimator diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py index 5a292c357..6073474ce 100644 --- a/sbi/inference/posteriors/direct_posterior.py +++ b/sbi/inference/posteriors/direct_posterior.py @@ -3,7 +3,6 @@ from typing import Optional, Union import torch -from pyknos.nflows import flows from torch import Tensor, log from torch.distributions import Distribution @@ -14,8 +13,7 @@ from sbi.neural_nets.density_estimators.base import DensityEstimator from sbi.samplers.rejection.rejection import accept_reject_sample from sbi.types import Shape -from sbi.utils import check_prior, match_theta_and_x_batch_shapes, within_support -from sbi.utils.torchutils import ensure_theta_batched +from sbi.utils import check_prior, within_support class DirectPosterior(NeuralPosterior): diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index 8a160df91..31c6c0854 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -1,10 +1,10 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Affero General Public License v3, see . 
-from typing import Any, Callable, Optional, Tuple +from typing import Callable, Optional, Tuple import torch -from torch import Tensor, nn +from torch import Tensor from torch.distributions import Distribution from sbi.inference.potentials.base_potential import BasePotential @@ -133,7 +133,8 @@ def _log_likelihoods_over_trials( ), "x and theta must match in batch shape." assert ( next(estimator.parameters()).device == x.device and x.device == theta.device - ), f"""device mismatch: estimator, x, theta: {next(estimator.parameters()).device}, {x.device}, + ), f"""device mismatch: estimator, x, theta: \ + {next(estimator.parameters()).device}, {x.device}, {theta.device}.""" # Calculate likelihood in one batch. @@ -185,7 +186,8 @@ def __init__( x_o: Optional[Tensor], device: str = "cpu", ): - # TODO Fix pyright issue by making MixedDensityEstimator a subclass of DensityEstimator + # TODO Fix pyright issue by making MixedDensityEstimator a subclass + # of DensityEstimator super().__init__(likelihood_estimator, prior, x_o, device) # type: ignore def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: diff --git a/sbi/inference/potentials/posterior_based_potential.py b/sbi/inference/potentials/posterior_based_potential.py index 14bd4fb41..6648c3949 100644 --- a/sbi/inference/potentials/posterior_based_potential.py +++ b/sbi/inference/potentials/posterior_based_potential.py @@ -4,16 +4,14 @@ from typing import Callable, Optional, Tuple import torch -from pyknos.nflows import flows -from torch import Tensor, nn +from torch import Tensor from torch.distributions import Distribution from sbi.inference.potentials.base_potential import BasePotential from sbi.neural_nets.density_estimators import DensityEstimator from sbi.types import TorchTransform from sbi.utils import mcmc_transform -from sbi.utils.sbiutils import match_theta_and_x_batch_shapes, within_support -from sbi.utils.torchutils import ensure_theta_batched +from sbi.utils.sbiutils import within_support def posterior_estimator_based_potential( diff --git a/sbi/inference/snle/mnle.py b/sbi/inference/snle/mnle.py index 60a249cd1..4b6dc1c1b 100644 --- a/sbi/inference/snle/mnle.py +++ b/sbi/inference/snle/mnle.py @@ -5,7 +5,7 @@ from copy import deepcopy from typing import Any, Callable, Dict, Optional, Union -from torch import Tensor, nn, optim +from torch import Tensor from torch.distributions import Distribution from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior @@ -196,7 +196,8 @@ def build_posterior( return deepcopy(self._posterior) - # Temporary: need to rewrite mixed likelihood estimators as DensityEstimator objects. + # Temporary: need to rewrite mixed likelihood estimators as DensityEstimator + # objects. def _loss(self, theta: Tensor, x: Tensor) -> Tensor: r"""Return loss for SNLE, which is the likelihood of $-\log q(x_i | \theta_i)$. 
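For context on the trial-wise likelihood helper in likelihood_based_potential.py whose device assertion is reformatted above: it evaluates one batch of parameters against several i.i.d. trials and sums the per-trial log-likelihoods. A simplified, illustrative version follows (the real helper matches theta/x batch shapes and evaluates everything in a single batched call, as its "Calculate likelihood in one batch" comment notes):

    import torch

    def log_likelihood_over_trials(log_prob_fn, x_trials, theta):
        # x_trials: (num_trials, x_dim); theta: (batch, theta_dim).
        # Returns (batch,): sum_i log p(x_i | theta), trials treated as i.i.d.
        per_trial = torch.stack([log_prob_fn(x_i, theta) for x_i in x_trials])
        return per_trial.sum(dim=0)

    # Toy check with a hypothetical Gaussian log-likelihood:
    log_prob_fn = lambda x_i, theta: -0.5 * ((x_i - theta) ** 2).sum(dim=-1)
    x_trials, theta = torch.randn(5, 2), torch.randn(7, 2)
    assert log_likelihood_over_trials(log_prob_fn, x_trials, theta).shape == (7,)
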
diff --git a/sbi/inference/snle/snle_base.py b/sbi/inference/snle/snle_base.py index 299f2418e..7cfb713b2 100644 --- a/sbi/inference/snle/snle_base.py +++ b/sbi/inference/snle/snle_base.py @@ -7,8 +7,7 @@ from typing import Any, Callable, Dict, Optional, Union import torch -from pyknos.nflows import flows -from torch import Tensor, nn, optim +from torch import Tensor, optim from torch.distributions import Distribution from torch.nn.utils.clip_grad import clip_grad_norm_ from torch.utils.tensorboard.writer import SummaryWriter diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index d36314bd0..c8bcefc7f 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -7,9 +7,7 @@ from typing import Any, Callable, Dict, Optional, Union import torch -import torch.nn as nn from pyknos.mdn.mdn import MultivariateGaussianMDN -from pyknos.nflows import flows from pyknos.nflows.transforms import CompositeTransform from torch import Tensor from torch.distributions import Distribution, MultivariateNormal diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index dc153d19c..1eefee146 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -7,7 +7,7 @@ from warnings import warn import torch -from torch import Tensor, nn, ones, optim +from torch import Tensor, ones, optim from torch.distributions import Distribution from torch.nn.utils.clip_grad import clip_grad_norm_ from torch.utils.tensorboard.writer import SummaryWriter diff --git a/sbi/neural_nets/density_estimators/base.py b/sbi/neural_nets/density_estimators/base.py index 7f54efbe3..8dc67987a 100644 --- a/sbi/neural_nets/density_estimators/base.py +++ b/sbi/neural_nets/density_estimators/base.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Optional, Tuple import torch from torch import Tensor, nn diff --git a/sbi/neural_nets/density_estimators/nflows_flow.py b/sbi/neural_nets/density_estimators/nflows_flow.py index 83472c631..a79b535fd 100644 --- a/sbi/neural_nets/density_estimators/nflows_flow.py +++ b/sbi/neural_nets/density_estimators/nflows_flow.py @@ -1,8 +1,7 @@ from typing import Tuple import torch -from pyknos.nflows.flows import Flow -from torch import Tensor +from torch import Tensor, nn from sbi.neural_nets.density_estimators.base import DensityEstimator from sbi.types import Shape diff --git a/sbi/neural_nets/density_estimators/zuko_flow.py b/sbi/neural_nets/density_estimators/zuko_flow.py index 9cd933005..7cbb240a9 100644 --- a/sbi/neural_nets/density_estimators/zuko_flow.py +++ b/sbi/neural_nets/density_estimators/zuko_flow.py @@ -2,7 +2,7 @@ import torch from torch import Tensor, nn -from zuko.flows import Flow, LazyComposedTransform +from zuko.flows import Flow from sbi.neural_nets.density_estimators.base import DensityEstimator from sbi.types import Shape diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index 2d8bcc3df..ef12c0f66 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -7,6 +7,7 @@ from warnings import warn import torch +import zuko from pyknos.nflows import distributions as distributions_ from pyknos.nflows import flows, transforms from pyknos.nflows.nn import nets @@ -468,7 +469,8 @@ def build_zuko_maf( check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) y_numel = embedding_net(batch_y[:1]).numel() if x_numel == 1: - warn("In one-dimensional output space, this flow is limited to Gaussians") + warn("In one-dimensional output space, this flow is limited to 
Gaussians" + , stacklevel=1) if isinstance(hidden_features, int): hidden_features = [hidden_features] * num_transforms diff --git a/sbi/neural_nets/mnle.py b/sbi/neural_nets/mnle.py index e62c3144c..3ee0c89bc 100644 --- a/sbi/neural_nets/mnle.py +++ b/sbi/neural_nets/mnle.py @@ -10,7 +10,6 @@ from torch.distributions import Categorical from torch.nn import Sigmoid, Softmax -from sbi.neural_nets.density_estimators.base import DensityEstimator from sbi.neural_nets.flow import build_nsf from sbi.utils.sbiutils import match_theta_and_x_batch_shapes, standardizing_net from sbi.utils.torchutils import atleast_2d From a3325b6d30f40ef98ec79cfd81ea77b5b979abd4 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 12 Mar 2024 15:34:37 +0100 Subject: [PATCH 24/35] ruff-format --- sbi/inference/snpe/snpe_a.py | 1 - sbi/inference/snpe/snpe_base.py | 2 +- sbi/neural_nets/flow.py | 6 ++++-- sbi/samplers/rejection/rejection.py | 3 ++- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index c8bcefc7f..47cc9298b 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -419,7 +419,6 @@ def __init__( self._set_state_for_mog_proposal() def log_prob(self, inputs: Tensor, condition: Tensor, **kwargs) -> Tensor: - inputs, condition = inputs.to(self._device), condition.to(self._device) if not self._apply_correction: diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index 1eefee146..461fbb15d 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -585,7 +585,7 @@ def _loss( else: loss = self._loss_proposal_posterior(theta, x, masks, proposal) - return (calibration_kernel(x) * loss) + return calibration_kernel(x) * loss def _check_proposal(self, proposal): """ diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index ef12c0f66..3e7bd8d56 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -469,8 +469,10 @@ def build_zuko_maf( check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) y_numel = embedding_net(batch_y[:1]).numel() if x_numel == 1: - warn("In one-dimensional output space, this flow is limited to Gaussians" - , stacklevel=1) + warn( + "In one-dimensional output space, this flow is limited to Gaussians", + stacklevel=1, + ) if isinstance(hidden_features, int): hidden_features = [hidden_features] * num_transforms diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index a0c363d55..ae016ea1c 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -263,7 +263,8 @@ def accept_reject_sample( while num_remaining > 0: # Sample and reject. candidates = proposal.sample( - (sampling_batch_size,), **proposal_sampling_kwargs # type: ignore + (sampling_batch_size,), + **proposal_sampling_kwargs, # type: ignore ) # type: ignore # SNPE-style rejection-sampling when the proposal is the neural net. 
From 63ab9a1237b433ec6278309fde00e0d1c5e4bd49 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 12 Mar 2024 15:47:43 +0100 Subject: [PATCH 25/35] Ruff suggestions --- sbi/samplers/rejection/rejection.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index ae016ea1c..59bbbe71f 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -257,6 +257,9 @@ def accept_reject_sample( num_sampled_total, num_remaining = 0, num_samples accepted, acceptance_rate = [], float("Nan") leakage_warning_raised = False + # Ruff suggestion + if proposal_sampling_kwargs is None: + proposal_sampling_kwargs = {} # To cover cases with few samples without leakage: sampling_batch_size = min(num_samples, max_sampling_batch_size) @@ -264,7 +267,7 @@ def accept_reject_sample( # Sample and reject. candidates = proposal.sample( (sampling_batch_size,), - **proposal_sampling_kwargs, # type: ignore + **proposal_sampling_kwargs, ) # type: ignore # SNPE-style rejection-sampling when the proposal is the neural net. From 55117043aa8dd8dfa7f896ae61cce798e8e82cdc Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 12 Mar 2024 15:50:50 +0100 Subject: [PATCH 26/35] pyright fix --- sbi/samplers/rejection/rejection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 59bbbe71f..015b3e66c 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -266,9 +266,9 @@ def accept_reject_sample( while num_remaining > 0: # Sample and reject. candidates = proposal.sample( - (sampling_batch_size,), + (sampling_batch_size,), # type: ignore **proposal_sampling_kwargs, - ) # type: ignore + ) # SNPE-style rejection-sampling when the proposal is the neural net. are_accepted = accept_reject_fn(candidates) From b3d8ba4e8335ca277dedc236ba31cbdf9695b869 Mon Sep 17 00:00:00 2001 From: manuelgloeckler Date: Tue, 12 Mar 2024 15:53:52 +0100 Subject: [PATCH 27/35] Final formatting --- sbi/samplers/rejection/rejection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 015b3e66c..0a437aa80 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -266,9 +266,9 @@ def accept_reject_sample( while num_remaining > 0: # Sample and reject. candidates = proposal.sample( - (sampling_batch_size,), # type: ignore + (sampling_batch_size,), # type: ignore **proposal_sampling_kwargs, - ) + ) # SNPE-style rejection-sampling when the proposal is the neural net. 
are_accepted = accept_reject_fn(candidates) From 56260e11f04a47aace4f10a2a021fba920bdbc27 Mon Sep 17 00:00:00 2001 From: Guy Moss Date: Tue, 12 Mar 2024 18:03:30 +0100 Subject: [PATCH 28/35] add RatioEstimator --- sbi/inference/snre/snre_c.py | 2 +- sbi/neural_nets/ratio_estimators/base.py | 98 +++++++++++++++++++ .../ratio_estimators/classifier.py | 67 +++++++++++++ 3 files changed, 166 insertions(+), 1 deletion(-) create mode 100644 sbi/neural_nets/ratio_estimators/base.py create mode 100644 sbi/neural_nets/ratio_estimators/classifier.py diff --git a/sbi/inference/snre/snre_c.py b/sbi/inference/snre/snre_c.py index fe98cc054..ad99c39ca 100644 --- a/sbi/inference/snre/snre_c.py +++ b/sbi/inference/snre/snre_c.py @@ -147,7 +147,7 @@ def _loss( assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match." batch_size = theta.shape[0] - # We append an contrastive theta to the marginal case because we will remove + # We append a contrastive theta to the marginal case because we will remove # the jointly drawn # sample in the logits_marginal[:, 0] position. That makes the remaining sample # marginally drawn. diff --git a/sbi/neural_nets/ratio_estimators/base.py b/sbi/neural_nets/ratio_estimators/base.py new file mode 100644 index 000000000..4cae2ed00 --- /dev/null +++ b/sbi/neural_nets/ratio_estimators/base.py @@ -0,0 +1,98 @@ +import torch +from torch import Tensor, nn + + +class RatioEstimator(nn.Module): + r"""Base class for ratio estimators. + + The ratio estimator class is a wrapper around neural networks that + allows to evaluate the classifier logits of $\theta,x$ pairs. Here $\theta$ would + be the `input` and $x$ would be the `condition`. + + Note: + We assume that the input to the ratio estimator is a tensor of shape + (batch_size, input_size), where input_size is the dimensionality of the input. + The condition is a tensor of shape (batch_size, *condition_shape), where + condition_shape is the shape of the condition tensor. + + """ + + def __init__(self, net: nn.Module, condition_shape: torch.Size) -> None: + r"""Base class for ratio estimators. + + Args: + net: Neural network. + condition_shape: Shape of the condition. If not provided, it will assume a + 1D input. + """ + super().__init__() + self.net = net + self._condition_shape = condition_shape + + def forward(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: + r"""Return the logits of the batched (input,condition) pairs. + + Args: + input: Inputs to evaluate the log probability on of shape + (*batch_shape, input_size). + condition: Conditions of shape (*batch_shape, *condition_shape). + + Raises: + RuntimeError: If batch_shapes don't match. + + Returns: + Sample-wise logits. + + Note: + This function should support PyTorch's automatic broadcasting. This means + the function should behave as follows for different input and condition + shapes: + - (input_size,) + (batch_size,*condition_shape) -> (batch_size,) + - (batch_size, input_size) + (*condition_shape) -> (batch_size,) + - (batch_size, input_size) + (batch_size, *condition_shape) -> (batch_size,) + - (batch_size1, input_size) + (batch_size2, *condition_shape) + -> RuntimeError i.e. not broadcastable + """ + + raise NotImplementedError + + def loss(self, input: Tensor, condition: Tensor, labels, **kwargs) -> Tensor: + r"""Return the loss for training the ratio estimator. + + Args: + input: Inputs to evaluate the loss on of shape (batch_size, input_size). + condition: Conditions of shape (batch_size, *condition_shape). 
+ labels: Labels of shape (batch_size,). + + Returns: + Loss of shape (batch_size,) + """ + + raise NotImplementedError + + def _check_condition_shape(self, condition: Tensor): + r"""This method checks whether the condition has the correct shape. + + Args: + condition: Conditions of shape (*batch_shape, *condition_shape). + + Raises: + ValueError: If the condition has a dimensionality that does not match + the expected input dimensionality. + ValueError: If the shape of the condition does not match the expected + input dimensionality. + """ + if len(condition.shape) < len(self._condition_shape): + raise ValueError( + f"Dimensionality of condition is to small and does not match the\ + expected input dimensionality {len(self._condition_shape)}, as provided\ + by condition_shape." + ) + else: + condition_shape = condition.shape[-len(self._condition_shape) :] + if tuple(condition_shape) != tuple(self._condition_shape): + raise ValueError( + f"Shape of condition {tuple(condition_shape)} does not match the \ + expected input dimensionality {tuple(self._condition_shape)}, as \ + provided by condition_shape. Please reshape it accordingly." + ) diff --git a/sbi/neural_nets/ratio_estimators/classifier.py b/sbi/neural_nets/ratio_estimators/classifier.py new file mode 100644 index 000000000..5ff8896cc --- /dev/null +++ b/sbi/neural_nets/ratio_estimators/classifier.py @@ -0,0 +1,67 @@ +import torch +from torch import Tensor, nn + +from sbi.neural_nets.ratio_estimators.base import RatioEstimator + + +class ClassifierRatio(RatioEstimator): + r"""classifier- based density ratio estimator.""" + + def forward(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: + r"""Return the logits of the batched (input,condition) pairs. + + Args: + input: Inputs to evaluate the log probability on of shape + (*batch_shape, input_size). + condition: Conditions of shape (*batch_shape, *condition_shape). + + Raises: + RuntimeError: If batch_shapes don't match. + + Returns: + Sample-wise logits. + + Note: + This function should support PyTorch's automatic broadcasting. This means + the function should behave as follows for different input and condition + shapes: + - (batch_size, input_size) + (*condition_shape) -> (batch_size,) + - (batch_size, input_size) + (batch_size, *condition_shape) -> (batch_size,) + - (batch_size1, input_size) + (batch_size2, *condition_shape) + -> RuntimeError i.e. not broadcastable + """ + + self._check_condition_shape(condition) + condition_dims = len(self._condition_shape) + + # PyTorch's automatic broadcasting + batch_shape_in = input.shape[:-1] + batch_shape_cond = condition.shape[:-condition_dims] + assert batch_shape_cond == batch_shape_in or len(batch_shape_cond) == 0, ( + "Batch shapes don't match. " + f"input: {input.shape}, condition: {condition.shape}." + ) + + condition = condition.expand(batch_shape_in + self._condition_shape) + # Flatten required by nflows, but now both have the same batch shape + input = input.reshape(-1, input.shape[-1]) + condition = condition.reshape(-1, *self._condition_shape) + + logits = self.net([input, condition]) + return logits + + def loss(self, input: Tensor, condition: Tensor, labels, **kwargs) -> Tensor: + r"""Return the loss for training the ratio estimator. + + Args: + input: Inputs to evaluate the loss on of shape (batch_size, input_size). + condition: Conditions of shape (batch_size, *condition_shape). + labels: Labels of shape (batch_size,). 
+
+        Returns:
+            Loss of shape (batch_size,)
+        """
+        logits = self.forward(input, condition)
+        likelihood = torch.sigmoid(logits).squeeze()
+        # Binary cross entropy to learn the likelihood
+        return nn.BCELoss()(likelihood, labels)

From 787e24bab8bb39c5d1be24eb8310e9bcd901dc55 Mon Sep 17 00:00:00 2001
From: manuelgloeckler
Date: Wed, 13 Mar 2024 17:27:01 +0100
Subject: [PATCH 29/35] Recommended changes

---
 sbi/analysis/conditional_density.py           |  6 +++--
 sbi/inference/posteriors/direct_posterior.py  | 25 ++++++++++++++++---
 .../potentials/posterior_based_potential.py   | 18 +++++++------
 sbi/inference/snle/mnle.py                    |  1 +
 sbi/inference/snpe/snpe_a.py                  | 24 +++++++++++++++++-
 sbi/inference/snpe/snpe_b.py                  |  2 +-
 sbi/inference/snpe/snpe_base.py               |  4 +--
 sbi/inference/snpe/snpe_c.py                  |  2 +-
 .../density_estimators/nflows_flow.py         |  4 +++
 tests/embedding_net_test.py                   |  4 +--
 10 files changed, 70 insertions(+), 20 deletions(-)

diff --git a/sbi/analysis/conditional_density.py b/sbi/analysis/conditional_density.py
index 655a15f17..7f0f9cde2 100644
--- a/sbi/analysis/conditional_density.py
+++ b/sbi/analysis/conditional_density.py
@@ -186,7 +186,7 @@ def conditional_corrcoeff(
 class ConditionedMDN:
     def __init__(
         self,
-        net: DensityEstimator,  # TODO: Must be MDN!!!
+        gaussianMDN: DensityEstimator,
         x_o: Tensor,
         condition: Tensor,
         dims_to_sample: List[int],
@@ -206,7 +206,9 @@ def __init__(
         """
         condition = atleast_2d_float32_tensor(condition)

-        logits, means, precfs, _ = extract_and_transform_mog(net=net.net, context=x_o)
+        logits, means, precfs, _ = extract_and_transform_mog(
+            net=gaussianMDN.net, context=x_o
+        )
         self.logits, self.means, self.precfs, self.sumlogdiag = condition_mog(
             condition, dims_to_sample, logits, means, precfs
         )
diff --git a/sbi/inference/posteriors/direct_posterior.py b/sbi/inference/posteriors/direct_posterior.py
index 6073474ce..8d454e1c6 100644
--- a/sbi/inference/posteriors/direct_posterior.py
+++ b/sbi/inference/posteriors/direct_posterior.py
@@ -14,6 +14,7 @@
 from sbi.samplers.rejection.rejection import accept_reject_sample
 from sbi.types import Shape
 from sbi.utils import check_prior, within_support
+from sbi.utils.torchutils import ensure_theta_batched


 class DirectPosterior(NeuralPosterior):
@@ -100,7 +101,18 @@ def sample(
         """

         num_samples = torch.Size(sample_shape).numel()
+        condition_shape = self.posterior_estimator._condition_shape
         x = self._x_else_default_x(x)
+
+        try:
+            x = x.reshape(*condition_shape)
+        except RuntimeError as err:
+            raise ValueError(
+                f"Expected a single `x` which should be broadcastable to shape \
+                  {condition_shape}, but got {x.shape}. For batched eval \
+                  see issue #990"
+            ) from err
+
         max_sampling_batch_size = (
             self.max_sampling_batch_size
             if max_sampling_batch_size is None
@@ -123,7 +135,6 @@ def sample(
             proposal_sampling_kwargs={"condition": x},
             alternative_method="build_posterior(..., sample_with='mcmc')",
         )[0]
-        samples = samples.view(sample_shape + (-1,))  # TODO: Why is this necessary?

         return samples

@@ -160,12 +171,20 @@ def log_prob(
             support of the prior, -∞ (corresponding to 0 probability) outside.
         """
         x = self._x_else_default_x(x)
+        condition_shape = self.posterior_estimator._condition_shape
+        try:
+            x = x.reshape(*condition_shape)
+        except RuntimeError as err:
+            raise ValueError(
+                f"Expected a single `x` which should be broadcastable to shape \
+                  {condition_shape}, but got {x.shape}. For batched eval \
+                  see issue #990"
+            ) from err

         # TODO Train exited here, entered after sampling?
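         # eval() switches the density estimator to evaluation mode (e.g., it
         # freezes dropout and batch-norm statistics), so log_prob is deterministic.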
self.posterior_estimator.eval() - # theta = ensure_theta_batched(torch.as_tensor(theta)) - # theta_repeated, x_repeated = match_theta_and_x_batch_shapes(theta, x) + theta = ensure_theta_batched(torch.as_tensor(theta)) with torch.set_grad_enabled(track_gradients): # Evaluate on device, move back to cpu for comparison with prior. diff --git a/sbi/inference/potentials/posterior_based_potential.py b/sbi/inference/potentials/posterior_based_potential.py index 6648c3949..27fcc6324 100644 --- a/sbi/inference/potentials/posterior_based_potential.py +++ b/sbi/inference/potentials/posterior_based_potential.py @@ -12,6 +12,7 @@ from sbi.types import TorchTransform from sbi.utils import mcmc_transform from sbi.utils.sbiutils import within_support +from sbi.utils.torchutils import ensure_theta_batched def posterior_estimator_based_potential( @@ -91,16 +92,17 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: The potential. """ - # NOTE: This is no longer necessary, as the `log_prob` will broadcast - # theta = ensure_theta_batched(torch.as_tensor(theta)) - # theta, x_repeated = match_theta_and_x_batch_shapes(theta, self.x_o) - # theta, x_repeated = theta.to(self.device), x_repeated.to(self.device) - assert self._x_o is not None, "No observed data is available." + if self._x_o is None: + raise ValueError( + "No observed data x_o is available. Please reinitialize \ + the potential or manually set self._x_o." + ) + + theta = ensure_theta_batched(torch.as_tensor(theta)) + theta, x = theta.to(self.device), self.x_o.to(self.device) with torch.set_grad_enabled(track_gradients): - posterior_log_prob = self.posterior_estimator.log_prob( - theta, condition=self._x_o - ) + posterior_log_prob = self.posterior_estimator.log_prob(theta, condition=x) # Force probability to be zero outside prior support. in_prior_support = within_support(self.prior, theta) diff --git a/sbi/inference/snle/mnle.py b/sbi/inference/snle/mnle.py index 4b6dc1c1b..57ac0a425 100644 --- a/sbi/inference/snle/mnle.py +++ b/sbi/inference/snle/mnle.py @@ -198,6 +198,7 @@ def build_posterior( # Temporary: need to rewrite mixed likelihood estimators as DensityEstimator # objects. + # TODO: Fix and merge issue #968 def _loss(self, theta: Tensor, x: Tensor) -> Tensor: r"""Return loss for SNLE, which is the likelihood of $-\log q(x_i | \theta_i)$. diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index 47cc9298b..e010a944d 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -306,7 +306,7 @@ def build_posterior( ) return deepcopy(self._posterior) - def _loss_proposal_posterior( + def _log_prob_proposal_posterior( self, theta: Tensor, x: Tensor, @@ -363,6 +363,9 @@ def _expand_mog(self, eps: float = 1e-5): class SNPE_A_MDN(DensityEstimator): """Generates a posthoc-corrected MDN which approximates the posterior. + TODO: Adapt this class to the new `DensityEstimator` interface. Maybe even to a + common MDN interface. See #989. + This class takes as input the density estimator (abbreviated with `_d` suffix, aka the proposal posterior) and the proposal prior (abbreviated with `_pp` suffix) from which the simulations were drawn. 
It uses the algorithm presented in SNPE-A [1] to @@ -419,6 +422,15 @@ def __init__( self._set_state_for_mog_proposal() def log_prob(self, inputs: Tensor, condition: Tensor, **kwargs) -> Tensor: + """_summary_ + + Args: + inputs: Input values + condition: Condition values + + Returns: + log_prob p(inputs|condition) + """ inputs, condition = inputs.to(self._device), condition.to(self._device) if not self._apply_correction: @@ -444,6 +456,16 @@ def log_prob(self, inputs: Tensor, condition: Tensor, **kwargs) -> Tensor: return log_prob_proposal_posterior # \hat{p} from eq (3) in [1] def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tensor: + """Sample from the approximate posterior. + + Args: + sample_shape: Shape of the samples. + condition: Condition values. + + Returns: + Samples from the approximate posterior. + """ + condition = condition.to(self._device) if not self._apply_correction: diff --git a/sbi/inference/snpe/snpe_b.py b/sbi/inference/snpe/snpe_b.py index dfe2f464a..0246ccebd 100644 --- a/sbi/inference/snpe/snpe_b.py +++ b/sbi/inference/snpe/snpe_b.py @@ -39,7 +39,7 @@ def __init__( kwargs = del_entries(locals(), entries=("self", "__class__")) super().__init__(**kwargs) - def _loss_proposal_posterior( + def _log_prob_proposal_posterior( self, theta: Tensor, x: Tensor, masks: Tensor ) -> Tensor: """ diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index 461fbb15d..17341843c 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -550,7 +550,7 @@ def build_posterior( return deepcopy(self._posterior) @abstractmethod - def _loss_proposal_posterior( + def _log_prob_proposal_posterior( self, theta: Tensor, x: Tensor, @@ -583,7 +583,7 @@ def _loss( # Use posterior log prob (without proposal correction) for first round. loss = self._neural_net.loss(theta, x) else: - loss = self._loss_proposal_posterior(theta, x, masks, proposal) + loss = self._log_prob_proposal_posterior(theta, x, masks, proposal) return calibration_kernel(x) * loss diff --git a/sbi/inference/snpe/snpe_c.py b/sbi/inference/snpe/snpe_c.py index ebaa07857..a368183b5 100644 --- a/sbi/inference/snpe/snpe_c.py +++ b/sbi/inference/snpe/snpe_c.py @@ -254,7 +254,7 @@ def _set_maybe_z_scored_prior(self) -> None: else: self._maybe_z_scored_prior = self._prior - def _loss_proposal_posterior( + def _log_prob_proposal_posterior( self, theta: Tensor, x: Tensor, diff --git a/sbi/neural_nets/density_estimators/nflows_flow.py b/sbi/neural_nets/density_estimators/nflows_flow.py index a79b535fd..ad578d96c 100644 --- a/sbi/neural_nets/density_estimators/nflows_flow.py +++ b/sbi/neural_nets/density_estimators/nflows_flow.py @@ -1,6 +1,7 @@ from typing import Tuple import torch +from pyknos.nflows.flows import Flow from torch import Tensor, nn from sbi.neural_nets.density_estimators.base import DensityEstimator @@ -14,6 +15,9 @@ class NFlowsFlow(DensityEstimator): wrap them and add the .loss() method. 
""" + def __init__(self, net: Flow, condition_shape: torch.Size) -> None: + super().__init__(net, condition_shape) + @property def embedding_net(self) -> nn.Module: r"""Return the embedding network.""" diff --git a/tests/embedding_net_test.py b/tests/embedding_net_test.py index 6d568608c..5b05deb65 100644 --- a/tests/embedding_net_test.py +++ b/tests/embedding_net_test.py @@ -59,7 +59,7 @@ def test_embedding_net_api(method, num_dim: int, embedding_net: str): else: raise NameError - _ = inference.append_simulations(theta, x).train(max_num_epochs=5) + _ = inference.append_simulations(theta, x).train(max_num_epochs=2) posterior = inference.build_posterior( mcmc_method="slice_np_vectorized", mcmc_parameters=dict(num_chains=2, warmup_steps=10, thin=5), @@ -94,7 +94,7 @@ def test_embedding_api_with_multiple_trials(num_trials, num_dim): inference = SNPE(prior, density_estimator=density_estimator) _ = inference.append_simulations(theta, x).train(max_num_epochs=5) - # Increased to 5 as otherwise 99.99 rejection rate for SNPE + posterior = inference.build_posterior().set_default_x(x_o) s = posterior.sample((1,)) From 1dee434bf38eaa544f9b48bce8c1b4363eb5b42e Mon Sep 17 00:00:00 2001 From: Guy Moss Date: Wed, 13 Mar 2024 20:27:30 +0100 Subject: [PATCH 30/35] remaining changes --- .../potentials/likelihood_based_potential.py | 27 +++++-------------- sbi/neural_nets/density_estimators/base.py | 4 +-- tests/density_estimator_test.py | 20 +++++--------- 3 files changed, 16 insertions(+), 35 deletions(-) diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index 31c6c0854..d31ddb6bb 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -12,8 +12,7 @@ from sbi.neural_nets.mnle import MixedDensityEstimator from sbi.types import TorchTransform from sbi.utils import mcmc_transform -from sbi.utils.sbiutils import match_theta_and_x_batch_shapes -from sbi.utils.torchutils import atleast_2d +from sbi.utils.torchutils import ensure_x_batched def likelihood_estimator_based_potential( @@ -122,28 +121,16 @@ def _log_likelihoods_over_trials( batch entries (iid trials) in `x`. """ - # Repeat `x` in case of evaluation on multiple `theta`. This is needed below in - # when calling nflows in order to have matching shapes of theta and context x - # at neural network evaluation time. - theta_repeated, x_repeated = match_theta_and_x_batch_shapes( - theta=atleast_2d(theta), x=atleast_2d(x) - ) - assert ( - x_repeated.shape[0] == theta_repeated.shape[0] - ), "x and theta must match in batch shape." - assert ( - next(estimator.parameters()).device == x.device and x.device == theta.device - ), f"""device mismatch: estimator, x, theta: \ - {next(estimator.parameters()).device}, {x.device}, - {theta.device}.""" + x = ensure_x_batched(x) + theta_batch_size = theta.shape[: -len(estimator._condition_shape)] # Calculate likelihood in one batch. with torch.set_grad_enabled(track_gradients): - log_likelihood_trial_batch = estimator.log_prob(x_repeated, theta_repeated) - # Reshape to (x-trials x parameters), sum over trial-log likelihoods. + log_likelihood_trial_batch = estimator.log_prob(x, condition=theta) + # Reshape to (theta_batch_size * num_trials), sum over trial-log likelihoods. 
log_likelihood_trial_sum = log_likelihood_trial_batch.reshape( - x.shape[0], -1 - ).sum(0) + theta_batch_size + (-1,) + ).sum(-1) return log_likelihood_trial_sum diff --git a/sbi/neural_nets/density_estimators/base.py b/sbi/neural_nets/density_estimators/base.py index 8dc67987a..a2af32bc1 100644 --- a/sbi/neural_nets/density_estimators/base.py +++ b/sbi/neural_nets/density_estimators/base.py @@ -24,8 +24,8 @@ def __init__(self, net: nn.Module, condition_shape: torch.Size) -> None: Args: net: Neural network. - condition_shape: Shape of the input. If not provided, it will assume a 1D - input. + condition_shape: Shape of the condition. If not provided, it will assume a + 1D input. """ super().__init__() self.net = net diff --git a/tests/density_estimator_test.py b/tests/density_estimator_test.py index b73a9fa3f..4d38ebb03 100644 --- a/tests/density_estimator_test.py +++ b/tests/density_estimator_test.py @@ -7,10 +7,9 @@ import torch from torch import eye, zeros from torch.distributions import MultivariateNormal -from zuko.flows import NSF from sbi.neural_nets.density_estimators import NFlowsFlow, ZukoFlow -from sbi.neural_nets.flow import build_nsf +from sbi.neural_nets.flow import build_nsf, build_zuko_maf @pytest.mark.parametrize("density_estimator", (NFlowsFlow, ZukoFlow)) @@ -54,17 +53,12 @@ def forward(self, x): embedding_net=EmbeddingNet(), ) elif density_estimator == ZukoFlow: - # if len(condition_shape) > 1: - # pytest.skip("ZukoFlow does not support multi-dimensional contexts.") - net = NSF( - features=input_dims, - context=condition_shape[-1], - transforms=2, - hidden_features=(10,), - bins=8, - ) - estimator = density_estimator( - net, embedding_net=EmbeddingNet(), condition_shape=condition_shape + estimator = build_zuko_maf( + batch_input, + batch_context, + hidden_features=10, + num_transforms=2, + embedding_net=EmbeddingNet(), ) # Loss is only required to work for batched inputs and contexts From c9b18160ee2fb7d0180b91a0a48533c5a76f5f7b Mon Sep 17 00:00:00 2001 From: Guy Moss Date: Wed, 13 Mar 2024 20:29:09 +0100 Subject: [PATCH 31/35] remove RatioEstimator from this PR --- sbi/neural_nets/ratio_estimators/base.py | 98 ------------------- .../ratio_estimators/classifier.py | 67 ------------- 2 files changed, 165 deletions(-) delete mode 100644 sbi/neural_nets/ratio_estimators/base.py delete mode 100644 sbi/neural_nets/ratio_estimators/classifier.py diff --git a/sbi/neural_nets/ratio_estimators/base.py b/sbi/neural_nets/ratio_estimators/base.py deleted file mode 100644 index 4cae2ed00..000000000 --- a/sbi/neural_nets/ratio_estimators/base.py +++ /dev/null @@ -1,98 +0,0 @@ -import torch -from torch import Tensor, nn - - -class RatioEstimator(nn.Module): - r"""Base class for ratio estimators. - - The ratio estimator class is a wrapper around neural networks that - allows to evaluate the classifier logits of $\theta,x$ pairs. Here $\theta$ would - be the `input` and $x$ would be the `condition`. - - Note: - We assume that the input to the ratio estimator is a tensor of shape - (batch_size, input_size), where input_size is the dimensionality of the input. - The condition is a tensor of shape (batch_size, *condition_shape), where - condition_shape is the shape of the condition tensor. - - """ - - def __init__(self, net: nn.Module, condition_shape: torch.Size) -> None: - r"""Base class for ratio estimators. - - Args: - net: Neural network. - condition_shape: Shape of the condition. If not provided, it will assume a - 1D input. 
- """ - super().__init__() - self.net = net - self._condition_shape = condition_shape - - def forward(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: - r"""Return the logits of the batched (input,condition) pairs. - - Args: - input: Inputs to evaluate the log probability on of shape - (*batch_shape, input_size). - condition: Conditions of shape (*batch_shape, *condition_shape). - - Raises: - RuntimeError: If batch_shapes don't match. - - Returns: - Sample-wise logits. - - Note: - This function should support PyTorch's automatic broadcasting. This means - the function should behave as follows for different input and condition - shapes: - - (input_size,) + (batch_size,*condition_shape) -> (batch_size,) - - (batch_size, input_size) + (*condition_shape) -> (batch_size,) - - (batch_size, input_size) + (batch_size, *condition_shape) -> (batch_size,) - - (batch_size1, input_size) + (batch_size2, *condition_shape) - -> RuntimeError i.e. not broadcastable - """ - - raise NotImplementedError - - def loss(self, input: Tensor, condition: Tensor, labels, **kwargs) -> Tensor: - r"""Return the loss for training the ratio estimator. - - Args: - input: Inputs to evaluate the loss on of shape (batch_size, input_size). - condition: Conditions of shape (batch_size, *condition_shape). - labels: Labels of shape (batch_size,). - - Returns: - Loss of shape (batch_size,) - """ - - raise NotImplementedError - - def _check_condition_shape(self, condition: Tensor): - r"""This method checks whether the condition has the correct shape. - - Args: - condition: Conditions of shape (*batch_shape, *condition_shape). - - Raises: - ValueError: If the condition has a dimensionality that does not match - the expected input dimensionality. - ValueError: If the shape of the condition does not match the expected - input dimensionality. - """ - if len(condition.shape) < len(self._condition_shape): - raise ValueError( - f"Dimensionality of condition is to small and does not match the\ - expected input dimensionality {len(self._condition_shape)}, as provided\ - by condition_shape." - ) - else: - condition_shape = condition.shape[-len(self._condition_shape) :] - if tuple(condition_shape) != tuple(self._condition_shape): - raise ValueError( - f"Shape of condition {tuple(condition_shape)} does not match the \ - expected input dimensionality {tuple(self._condition_shape)}, as \ - provided by condition_shape. Please reshape it accordingly." - ) diff --git a/sbi/neural_nets/ratio_estimators/classifier.py b/sbi/neural_nets/ratio_estimators/classifier.py deleted file mode 100644 index 5ff8896cc..000000000 --- a/sbi/neural_nets/ratio_estimators/classifier.py +++ /dev/null @@ -1,67 +0,0 @@ -import torch -from torch import Tensor, nn - -from sbi.neural_nets.ratio_estimators.base import RatioEstimator - - -class ClassifierRatio(RatioEstimator): - r"""classifier- based density ratio estimator.""" - - def forward(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: - r"""Return the logits of the batched (input,condition) pairs. - - Args: - input: Inputs to evaluate the log probability on of shape - (*batch_shape, input_size). - condition: Conditions of shape (*batch_shape, *condition_shape). - - Raises: - RuntimeError: If batch_shapes don't match. - - Returns: - Sample-wise logits. - - Note: - This function should support PyTorch's automatic broadcasting. 
This means - the function should behave as follows for different input and condition - shapes: - - (batch_size, input_size) + (*condition_shape) -> (batch_size,) - - (batch_size, input_size) + (batch_size, *condition_shape) -> (batch_size,) - - (batch_size1, input_size) + (batch_size2, *condition_shape) - -> RuntimeError i.e. not broadcastable - """ - - self._check_condition_shape(condition) - condition_dims = len(self._condition_shape) - - # PyTorch's automatic broadcasting - batch_shape_in = input.shape[:-1] - batch_shape_cond = condition.shape[:-condition_dims] - assert batch_shape_cond == batch_shape_in or len(batch_shape_cond) == 0, ( - "Batch shapes don't match. " - f"input: {input.shape}, condition: {condition.shape}." - ) - - condition = condition.expand(batch_shape_in + self._condition_shape) - # Flatten required by nflows, but now both have the same batch shape - input = input.reshape(-1, input.shape[-1]) - condition = condition.reshape(-1, *self._condition_shape) - - logits = self.net([input, condition]) - return logits - - def loss(self, input: Tensor, condition: Tensor, labels, **kwargs) -> Tensor: - r"""Return the loss for training the ratio estimator. - - Args: - input: Inputs to evaluate the loss on of shape (batch_size, input_size). - condition: Conditions of shape (batch_size, *condition_shape). - labels: Labels of shape (batch_size,). - - Returns: - Loss of shape (batch_size,) - """ - logits = self.forward(input, condition) - likelihood = torch.sigmoid(logits).squeeze() - # Binary cross entropy to learn the likelihood - return nn.BCELoss()(likelihood, labels) From 763ded1e08305f7631833cdfcc8290137b28065f Mon Sep 17 00:00:00 2001 From: Guy Moss Date: Wed, 13 Mar 2024 21:37:02 +0100 Subject: [PATCH 32/35] fix test failure in likelihood_potential --- .../potentials/likelihood_based_potential.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index d31ddb6bb..3e7a46e7e 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -12,7 +12,6 @@ from sbi.neural_nets.mnle import MixedDensityEstimator from sbi.types import TorchTransform from sbi.utils import mcmc_transform -from sbi.utils.torchutils import ensure_x_batched def likelihood_estimator_based_potential( @@ -120,17 +119,15 @@ def _log_likelihoods_over_trials( log_likelihood_trial_sum: log likelihood for each parameter, summed over all batch entries (iid trials) in `x`. """ - - x = ensure_x_batched(x) - theta_batch_size = theta.shape[: -len(estimator._condition_shape)] + # unsqueeze to ensure that the x-batch dimension is the first dimension for the + # broadcasting of the density estimator. + x = torch.as_tensor(x).reshape(-1, x.shape[-1]).unsqueeze(1) # Calculate likelihood in one batch. with torch.set_grad_enabled(track_gradients): log_likelihood_trial_batch = estimator.log_prob(x, condition=theta) - # Reshape to (theta_batch_size * num_trials), sum over trial-log likelihoods. - log_likelihood_trial_sum = log_likelihood_trial_batch.reshape( - theta_batch_size + (-1,) - ).sum(-1) + # Reshape to (-1, theta_batch_size), sum over trial-log likelihoods. 
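+        # After the unsqueeze above, x has shape (num_trials, 1, x_dim), so the
+        # trial axis broadcasts against the theta batch: log_prob returns shape
+        # (num_trials, theta_batch), and summing over dim 0 accumulates iid trials.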
+        log_likelihood_trial_sum = log_likelihood_trial_batch.sum(0)

     return log_likelihood_trial_sum

From 4cbdaaef388cceb33ea5cd0c5143f6c0049909a0 Mon Sep 17 00:00:00 2001
From: Guy Moss
Date: Thu, 14 Mar 2024 10:54:30 +0100
Subject: [PATCH 33/35] device check in likelihood_based_potential

---
 sbi/inference/potentials/likelihood_based_potential.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py
index 3e7a46e7e..cb431fb4a 100644
--- a/sbi/inference/potentials/likelihood_based_potential.py
+++ b/sbi/inference/potentials/likelihood_based_potential.py
@@ -122,6 +122,11 @@ def _log_likelihoods_over_trials(
     # unsqueeze to ensure that the x-batch dimension is the first dimension for the
     # broadcasting of the density estimator.
     x = torch.as_tensor(x).reshape(-1, x.shape[-1]).unsqueeze(1)
+    assert (
+        next(estimator.parameters()).device == x.device and x.device == theta.device
+    ), f"""device mismatch: estimator, x, theta: \
+        {next(estimator.parameters()).device}, {x.device},
+        {theta.device}."""

     # Calculate likelihood in one batch.

From 65af0f49a6c268ec8c087b7c3faddfdbd6490b86 Mon Sep 17 00:00:00 2001
From: manuelgloeckler
Date: Thu, 14 Mar 2024 12:31:50 +0100
Subject: [PATCH 34/35] Recommended changes

---
 sbi/analysis/conditional_density.py | 14 +++++++-------
 sbi/inference/snpe/snpe_a.py        |  6 +++---
 sbi/inference/snpe/snpe_base.py     |  4 +++-
 3 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/sbi/analysis/conditional_density.py b/sbi/analysis/conditional_density.py
index 7f0f9cde2..3a7f9a010 100644
--- a/sbi/analysis/conditional_density.py
+++ b/sbi/analysis/conditional_density.py
@@ -6,7 +6,7 @@

 import torch
 import torch.distributions.transforms as torch_tf
-from pyknos.mdn.mdn import MultivariateGaussianMDN as mdn
+from pyknos.mdn.mdn import MultivariateGaussianMDN
 from torch import Tensor
 from torch.distributions import Distribution

@@ -186,7 +186,7 @@ def conditional_corrcoeff(
 class ConditionedMDN:
     def __init__(
         self,
-        gaussianMDN: DensityEstimator,
+        mdn: DensityEstimator,
         x_o: Tensor,
         condition: Tensor,
         dims_to_sample: List[int],
@@ -206,9 +206,7 @@ def __init__(
         """
         condition = atleast_2d_float32_tensor(condition)

-        logits, means, precfs, _ = extract_and_transform_mog(
-            net=gaussianMDN.net, context=x_o
-        )
+        logits, means, precfs, _ = extract_and_transform_mog(net=mdn.net, context=x_o)
         self.logits, self.means, self.precfs, self.sumlogdiag = condition_mog(
             condition, dims_to_sample, logits, means, precfs
         )
@@ -216,13 +214,15 @@ def __init__(
     def sample(self, sample_shape: Shape = torch.Size()) -> Tensor:
         num_samples = torch.Size(sample_shape).numel()

-        samples = mdn.sample_mog(num_samples, self.logits, self.means, self.precfs)
+        samples = MultivariateGaussianMDN.sample_mog(
+            num_samples, self.logits, self.means, self.precfs
+        )
         return samples.detach().reshape((*sample_shape, -1))

     def log_prob(self, theta: Tensor) -> Tensor:
         batch_size, dim = theta.shape

-        log_prob = mdn.log_prob_mog(
+        log_prob = MultivariateGaussianMDN.log_prob_mog(
             theta,
             self.logits.repeat(batch_size, 1),
             self.means.repeat(batch_size, 1, 1),
diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py
index e010a944d..36fedcdd3 100644
--- a/sbi/inference/snpe/snpe_a.py
+++ b/sbi/inference/snpe/snpe_a.py
@@ -422,7 +422,7 @@ def __init__(
         self._set_state_for_mog_proposal()

     def log_prob(self, inputs: Tensor, condition: Tensor, **kwargs) -> Tensor:
-        """_summary_
+        """Compute the log-probability of the approximate posterior.

         Args:
             inputs: Input values
             condition: Condition values
@@ -549,9 +549,9 @@ def _posthoc_correction(self, x: Tensor):
         """
         # Evaluate the density estimator.
-        encoded_x = self._neural_net.net._embedding_net(x)
+        embedded_x = self._neural_net.net._embedding_net(x)
         dist = self._neural_net.net._distribution  # defined to avoid black formatting.
-        logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(encoded_x)
+        logits_d, m_d, prec_d, _, _ = dist.get_mixture_components(embedded_x)
         norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True)

         # The following if case is needed because, in the constructor, we call
diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py
index 17341843c..394139e64 100644
--- a/sbi/inference/snpe/snpe_base.py
+++ b/sbi/inference/snpe/snpe_base.py
@@ -583,7 +583,9 @@ def _loss(
             # Use posterior log prob (without proposal correction) for first round.
             loss = self._neural_net.loss(theta, x)
         else:
-            loss = self._log_prob_proposal_posterior(theta, x, masks, proposal)
+            # Currently only works for `DensityEstimator` objects.
+            # Must be extended once other estimators are implemented. See #966.
+            loss = -self._log_prob_proposal_posterior(theta, x, masks, proposal)

         return calibration_kernel(x) * loss

From b4b00ba01e89b71ebd89325fcf565f6f78e7a629 Mon Sep 17 00:00:00 2001
From: manuelgloeckler
Date: Fri, 15 Mar 2024 02:15:02 +0100
Subject: [PATCH 35/35] Integrated last recommendations

---
 sbi/samplers/rejection/rejection.py    |  3 +++
 sbi/utils/conditional_density_utils.py |  6 +++++-
 sbi/utils/sbiutils.py                  | 10 +---------
 3 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py
index 0a437aa80..5c78cd8dd 100644
--- a/sbi/samplers/rejection/rejection.py
+++ b/sbi/samplers/rejection/rejection.py
@@ -276,6 +276,9 @@ def accept_reject_sample(
             accepted.append(samples)

             # Update.
+            # Note: For any condition of shape (*batch_shape, *condition_shape),
+            # the samples will be of shape (*batch_shape, sampling_batch_size, d),
+            # hence we work in dim=-2.
             num_sampled_total += sampling_batch_size
             num_remaining -= samples.shape[-2]
             pbar.update(samples.shape[-2])
diff --git a/sbi/utils/conditional_density_utils.py b/sbi/utils/conditional_density_utils.py
index b2cc536c4..5f8df5d38 100644
--- a/sbi/utils/conditional_density_utils.py
+++ b/sbi/utils/conditional_density_utils.py
@@ -294,12 +294,16 @@ def __init__(
         self.device = self.potential_fn.device
         self.allow_iid_x = allow_iid_x

-    def __call__(self, theta: Tensor, track_gradients: bool = True, x_o=None) -> Tensor:
+    def __call__(
+        self, theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
+    ) -> Tensor:
         r"""
         Returns the conditional potential $\log(p(\theta_i|\theta_j, x))$.

         Args:
             theta: Free parameters $\theta_i$, batch dimension 1.
+            x_o: Unused keyword argument. Only present to match the signature of the
+                `Potential` class.
             track_gradients: Whether to track gradients.

         Returns:
diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py
index 5e5385b1d..4f9bcd980 100644
--- a/sbi/utils/sbiutils.py
+++ b/sbi/utils/sbiutils.py
@@ -572,15 +572,7 @@ def within_support(distribution: Any, samples: Tensor) -> Tensor:
     try:
         sample_check = distribution.support.check(samples)
         return sample_check
-    # Before torch v1.7.0, `support.check()` returned bools for every element.
-    # From v1.8.0 on, it directly considers all dimensions of a sample. E.g.,
-    # for a single sample in 3D, v1.7.0 would return [[True, True, True]] and
-    # v1.8.0 would return [True].
-    # TODO Check
-    # if sample_check.ndim > 1:
-    #     return torch.all(sample_check, dim=-1)
-    # else:
-    #     return sample_check
+
     # Falling back to log prob method of either the NeuralPosterior's net, or of a
     # custom wrapper distribution's.
     except (NotImplementedError, AttributeError):
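
A minimal, self-contained sketch of the iid-trial broadcasting convention that
`_log_likelihoods_over_trials` relies on after patches 32 and 33. The toy
Gaussian "estimator" and all variable names below are illustrative assumptions
for exposition only, not part of the `sbi` API:

    import torch

    def toy_log_prob(x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        # Stand-in for DensityEstimator.log_prob: log N(x | condition, I) up to
        # a constant, broadcasting over leading batch dimensions.
        return -0.5 * ((x - condition) ** 2).sum(-1)

    num_trials, theta_batch, x_dim = 5, 3, 2
    x = torch.randn(num_trials, x_dim)       # iid trials of the observation
    theta = torch.randn(theta_batch, x_dim)  # batch of parameters (conditions)

    # Mirror patch 32: trials become the leading dim; the singleton dim lets
    # the trial axis broadcast against the theta batch.
    x_trials = x.reshape(-1, x.shape[-1]).unsqueeze(1)  # (num_trials, 1, x_dim)
    log_probs = toy_log_prob(x_trials, theta)           # (num_trials, theta_batch)

    # Summing over dim 0 accumulates the iid trial log-likelihoods per theta.
    log_likelihood_trial_sum = log_probs.sum(0)         # (theta_batch,)
    assert log_likelihood_trial_sum.shape == (theta_batch,)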