From fdd6c05c9ee96fc97d92decbd0255099119fb43a Mon Sep 17 00:00:00 2001
From: Guy Moss
Date: Thu, 7 Mar 2024 15:08:14 +0100
Subject: [PATCH] snle_a uses DensityEstimator instead of nn.Module

---
 sbi/inference/base.py                            | 10 +++----
 .../potentials/likelihood_based_potential.py     | 23 ++++++++--------
 sbi/inference/snle/mnle.py                       | 10 +++++++
 sbi/inference/snle/snle_base.py                  | 27 ++++++++++---------
 4 files changed, 41 insertions(+), 29 deletions(-)

diff --git a/sbi/inference/base.py b/sbi/inference/base.py
index a7353465d..d7cc01506 100644
--- a/sbi/inference/base.py
+++ b/sbi/inference/base.py
@@ -122,7 +122,7 @@ def __init__(
 
         self._prior = prior
         self._posterior = None
-        self._neural_net = None
+        self._density_estimator = None
         self._x_shape = None
 
         self._show_progress_bars = show_progress_bars
@@ -349,20 +349,20 @@ def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
         """
         converged = False
 
-        assert self._neural_net is not None
-        neural_net = self._neural_net
+        assert self._density_estimator is not None
+        density_estimator = self._density_estimator
 
         # (Re)-start the epoch count with the first epoch or any improvement.
         if epoch == 0 or self._val_log_prob > self._best_val_log_prob:
             self._best_val_log_prob = self._val_log_prob
             self._epochs_since_last_improvement = 0
-            self._best_model_state_dict = deepcopy(neural_net.state_dict())
+            self._best_model_state_dict = deepcopy(density_estimator.state_dict())
         else:
             self._epochs_since_last_improvement += 1
 
         # If no validation improvement over many epochs, stop training.
         if self._epochs_since_last_improvement > stop_after_epochs - 1:
-            neural_net.load_state_dict(self._best_model_state_dict)
+            density_estimator.load_state_dict(self._best_model_state_dict)
             converged = True
 
         return converged
diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py
index d3aa5e51d..d337ee98f 100644
--- a/sbi/inference/potentials/likelihood_based_potential.py
+++ b/sbi/inference/potentials/likelihood_based_potential.py
@@ -13,10 +13,11 @@
 from sbi.utils import mcmc_transform
 from sbi.utils.sbiutils import match_theta_and_x_batch_shapes
 from sbi.utils.torchutils import atleast_2d
+from sbi.neural_nets.density_estimators import DensityEstimator
 
 
 def likelihood_estimator_based_potential(
-    likelihood_estimator: nn.Module,
+    likelihood_estimator: DensityEstimator,
     prior: Distribution,
     x_o: Optional[Tensor],
     enable_transform: bool = True,
@@ -27,7 +28,7 @@
     unconstrained space.
 
     Args:
-        likelihood_estimator: The neural network modelling the likelihood.
+        likelihood_estimator: The density estimator modelling the likelihood.
         prior: The prior distribution.
         x_o: The observed data at which to evaluate the likelihood.
         enable_transform: Whether to transform parameters to unconstrained space.
@@ -55,7 +56,7 @@ class LikelihoodBasedPotential(BasePotential):
 
     def __init__(
         self,
-        likelihood_estimator: nn.Module,
+        likelihood_estimator: DensityEstimator,
         prior: Distribution,
         x_o: Optional[Tensor],
         device: str = "cpu",
@@ -63,7 +64,7 @@ def __init__(
         r"""Returns the potential function for likelihood-based methods.
 
         Args:
-            likelihood_estimator: The neural network modelling the likelihood.
+            likelihood_estimator: The density estimator modelling the likelihood.
             prior: The prior distribution.
             x_o: The observed data at which to evaluate the likelihood.
             device: The device to which parameters and data are moved before evaluating
@@ -92,7 +93,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
         log_likelihood_trial_sum = _log_likelihoods_over_trials(
             x=self.x_o,
             theta=theta.to(self.device),
-            net=self.likelihood_estimator,
+            estimator=self.likelihood_estimator,
             track_gradients=track_gradients,
         )
 
@@ -100,7 +101,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
 
 
 def _log_likelihoods_over_trials(
-    x: Tensor, theta: Tensor, net: Any, track_gradients: bool = False
+    x: Tensor, theta: Tensor, estimator: DensityEstimator, track_gradients: bool = False
 ) -> Tensor:
     r"""Return log likelihoods summed over iid trials of `x`.
 
@@ -112,8 +113,8 @@ def _log_likelihoods_over_trials(
 
     Args:
         x: batch of iid data.
-        theta: batch of parameters
-        net: neural net with .log_prob()
+        theta: batch of parameters.
+        estimator: DensityEstimator with a `.log_prob()` method.
         track_gradients: Whether to track gradients.
 
     Returns:
@@ -131,13 +132,13 @@ def _log_likelihoods_over_trials(
         x_repeated.shape[0] == theta_repeated.shape[0]
     ), "x and theta must match in batch shape."
     assert (
-        next(net.parameters()).device == x.device and x.device == theta.device
-    ), f"""device mismatch: net, x, theta: {next(net.parameters()).device}, {x.device},
+        next(estimator.parameters()).device == x.device and x.device == theta.device
+    ), f"""device mismatch: estimator, x, theta: {next(estimator.parameters()).device}, {x.device},
         {theta.device}."""
 
     # Calculate likelihood in one batch.
     with torch.set_grad_enabled(track_gradients):
-        log_likelihood_trial_batch = net.log_prob(x_repeated, theta_repeated)
+        log_likelihood_trial_batch = estimator.log_prob(x_repeated, theta_repeated)
         # Reshape to (x-trials x parameters), sum over trial-log likelihoods.
         log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
             x.shape[0], -1
diff --git a/sbi/inference/snle/mnle.py b/sbi/inference/snle/mnle.py
index 33b4ffe1e..57f0906a2 100644
--- a/sbi/inference/snle/mnle.py
+++ b/sbi/inference/snle/mnle.py
@@ -6,6 +6,7 @@
 from typing import Any, Callable, Dict, Optional, Union
 
 from torch.distributions import Distribution
+from torch import Tensor, nn, optim
 
 from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
 from sbi.inference.potentials import mixed_likelihood_estimator_based_potential
@@ -193,3 +194,12 @@ def build_posterior(
         self._model_bank.append(deepcopy(self._posterior))
 
         return deepcopy(self._posterior)
+
+    # Temporary: need to rewrite mixed likelihood estimators as DensityEstimator objects.
+    def _loss(self, theta: Tensor, x: Tensor) -> Tensor:
+        r"""Return loss for MNLE: the negative log-likelihood $-\log q(x_i | \theta_i)$.
+
+        Returns:
+            Negative log prob.
+        """
+        return -self._density_estimator.log_prob(x, context=theta)
diff --git a/sbi/inference/snle/snle_base.py b/sbi/inference/snle/snle_base.py
index f4a27a6d0..4f44bea00 100644
--- a/sbi/inference/snle/snle_base.py
+++ b/sbi/inference/snle/snle_base.py
@@ -18,6 +18,7 @@
 from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
 from sbi.inference.potentials import likelihood_estimator_based_potential
 from sbi.utils import check_estimator_arg, check_prior, x_shape_from_simulation
+from sbi.neural_nets.density_estimators import DensityEstimator
 
 
 class LikelihoodEstimator(NeuralInference, ABC):
@@ -126,7 +127,7 @@ def train(
         retrain_from_scratch: bool = False,
         show_train_summary: bool = False,
         dataloader_kwargs: Optional[Dict] = None,
-    ) -> flows.Flow:
+    ) -> DensityEstimator:
         r"""Train the density estimator to learn the distribution $p(x|\theta)$.
 
         Args:
@@ -165,11 +166,11 @@ def train(
         # arguments, which will build the neural network
         # This is passed into NeuralPosterior, to create a neural posterior which
         # can `sample()` and `log_prob()`. The network is accessible via `.net`.
-        if self._neural_net is None or retrain_from_scratch:
+        if self._density_estimator is None or retrain_from_scratch:
             # Get theta,x to initialize NN
             theta, x, _ = self.get_simulations(starting_round=start_idx)
             # Use only training data for building the neural net (z-scoring transforms)
-            self._neural_net = self._build_neural_net(
+            self._density_estimator = self._build_neural_net(
                 theta[self.train_indices].to("cpu"),
                 x[self.train_indices].to("cpu"),
             )
@@ -179,10 +180,10 @@ def train(
                 len(self._x_shape) < 3
             ), "SNLE cannot handle multi-dimensional simulator output."
 
-        self._neural_net.to(self._device)
+        self._density_estimator.to(self._device)
         if not resume_training:
             self.optimizer = optim.Adam(
-                list(self._neural_net.parameters()),
+                list(self._density_estimator.parameters()),
                 lr=learning_rate,
             )
             self.epoch, self._val_log_prob = 0, float("-Inf")
@@ -191,7 +192,7 @@ def train(
             self.epoch, stop_after_epochs
         ):
             # Train for a single epoch.
-            self._neural_net.train()
+            self._density_estimator.train()
             train_log_probs_sum = 0
             for batch in train_loader:
                 self.optimizer.zero_grad()
@@ -207,7 +208,7 @@ def train(
                 train_loss.backward()
                 if clip_max_norm is not None:
                     clip_grad_norm_(
-                        self._neural_net.parameters(),
+                        self._density_estimator.parameters(),
                         max_norm=clip_max_norm,
                     )
                 self.optimizer.step()
@@ -220,7 +221,7 @@ def train(
             self._summary["training_log_probs"].append(train_log_prob_average)
 
             # Calculate validation performance.
-            self._neural_net.eval()
+            self._density_estimator.eval()
             val_log_prob_sum = 0
             with torch.no_grad():
                 for batch in val_loader:
@@ -256,13 +257,13 @@ def train(
 
         # Avoid keeping the gradients in the resulting network, which can
        # cause memory leakage when benchmarking.
-        self._neural_net.zero_grad(set_to_none=True)
+        self._density_estimator.zero_grad(set_to_none=True)
 
-        return deepcopy(self._neural_net)
+        return deepcopy(self._density_estimator)
 
     def build_posterior(
         self,
-        density_estimator: Optional[nn.Module] = None,
+        density_estimator: Optional[DensityEstimator] = None,
         prior: Optional[Distribution] = None,
         sample_with: str = "mcmc",
         mcmc_method: str = "slice_np",
@@ -311,7 +312,7 @@ def build_posterior(
 
         check_prior(prior)
         if density_estimator is None:
-            likelihood_estimator = self._neural_net
+            likelihood_estimator = self._density_estimator
             # If internal net is used device is defined.
             device = self._device
         else:
@@ -367,4 +368,4 @@ def _loss(self, theta: Tensor, x: Tensor) -> Tensor:
         Returns:
             Negative log prob.
         """
-        return -self._neural_net.log_prob(x, context=theta)
+        return self._density_estimator.loss(x, condition=theta)
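
Editor's note (not part of the patch): the refactor above relies only on a handful of methods on the new DensityEstimator type: `log_prob(x, condition)`, `loss(x, condition)`, and the usual `nn.Module` machinery (`parameters()`, `state_dict()`, `load_state_dict()`, `to()`, `train()`, `eval()`, `zero_grad()`). The sketch below is a hypothetical stub of that assumed interface, derived purely from the calls made in this diff; the actual class in `sbi.neural_nets.density_estimators` may differ, and the placeholder standard-normal density is used only to make the example executable.

# Hypothetical sketch of the interface this patch assumes; not the real sbi class.
import torch
from torch import Tensor, nn


class DensityEstimatorSketch(nn.Module):
    """Conditional density estimator exposing the methods used in the diff above."""

    def __init__(self, net: nn.Module) -> None:
        # Inheriting from nn.Module provides parameters(), state_dict(),
        # load_state_dict(), to(), train(), eval() and zero_grad(), all of
        # which the refactored training loop calls on self._density_estimator.
        super().__init__()
        self.net = net

    def log_prob(self, x: Tensor, condition: Tensor) -> Tensor:
        # Called as estimator.log_prob(x_repeated, theta_repeated) in
        # _log_likelihoods_over_trials; returns one log-density per batch entry.
        # Placeholder: a real estimator would condition on `condition`; a
        # standard normal is used here only so the sketch runs end to end.
        return torch.distributions.Normal(0.0, 1.0).log_prob(x).sum(dim=-1)

    def loss(self, x: Tensor, condition: Tensor) -> Tensor:
        # Used by LikelihoodEstimator._loss: per-sample negative log-likelihood.
        return -self.log_prob(x, condition)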