snle_a uses DensityEstimator instead of nn.Module
gmoss13 committed Mar 7, 2024
1 parent 5440446 commit fdd6c05
Showing 4 changed files with 41 additions and 29 deletions.
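
At a glance, the commit swaps the raw nn.Module flow stored on the inference object for a DensityEstimator that exposes log_prob and loss. A minimal mock of that interface, restricted to the calls visible in the diffs below (the real class in sbi.neural_nets.density_estimators is richer; the toy Gaussian and argument names here are purely illustrative), might look like this:

import torch
from torch import Tensor, nn


class ToyDensityEstimator(nn.Module):
    """Hypothetical stand-in for the DensityEstimator interface used in the diffs below."""

    def __init__(self, x_dim: int = 2, theta_dim: int = 2) -> None:
        super().__init__()
        # Toy conditional Gaussian: x | theta ~ N(W theta + b, diag(sigma^2)).
        self.mean_net = nn.Linear(theta_dim, x_dim)
        self.log_std = nn.Parameter(torch.zeros(x_dim))

    def log_prob(self, x: Tensor, condition: Tensor) -> Tensor:
        dist = torch.distributions.Normal(self.mean_net(condition), self.log_std.exp())
        return dist.log_prob(x).sum(-1)

    def loss(self, x: Tensor, condition: Tensor) -> Tensor:
        # Training loss is the negative log-likelihood per (x, theta) pair.
        return -self.log_prob(x, condition)

Because the training loop only touches this surface (plus the usual nn.Module methods such as parameters(), state_dict(), train() and eval()), replacing self._neural_net with self._density_estimator is largely a rename, as the diffs below show.
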
10 changes: 5 additions & 5 deletions sbi/inference/base.py
@@ -122,7 +122,7 @@ def __init__(
self._prior = prior

self._posterior = None
self._neural_net = None
self._density_estimator = None
self._x_shape = None

self._show_progress_bars = show_progress_bars
@@ -349,20 +349,20 @@ def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
"""
converged = False

assert self._neural_net is not None
neural_net = self._neural_net
assert self._density_estimator is not None
density_estimator = self._density_estimator

# (Re)-start the epoch count with the first epoch or any improvement.
if epoch == 0 or self._val_log_prob > self._best_val_log_prob:
self._best_val_log_prob = self._val_log_prob
self._epochs_since_last_improvement = 0
self._best_model_state_dict = deepcopy(neural_net.state_dict())
self._best_model_state_dict = deepcopy(density_estimator.state_dict())
else:
self._epochs_since_last_improvement += 1

# If no validation improvement over many epochs, stop training.
if self._epochs_since_last_improvement > stop_after_epochs - 1:
neural_net.load_state_dict(self._best_model_state_dict)
density_estimator.load_state_dict(self._best_model_state_dict)
converged = True

return converged
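
The early-stopping rule in _converged above can be mirrored by a small standalone helper (a hypothetical sketch, not part of the commit): keep the best state_dict seen so far and restore it once validation has not improved for stop_after_epochs epochs.

from copy import deepcopy
from torch import nn


def converged_after(val_log_probs, estimator: nn.Module, stop_after_epochs: int) -> bool:
    # Toy re-implementation of the bookkeeping in _converged.
    best, since_improvement, best_state = float("-inf"), 0, None
    for epoch, val_log_prob in enumerate(val_log_probs):
        if epoch == 0 or val_log_prob > best:
            best, since_improvement = val_log_prob, 0
            best_state = deepcopy(estimator.state_dict())
        else:
            since_improvement += 1
        if since_improvement > stop_after_epochs - 1:
            estimator.load_state_dict(best_state)
            return True
    return False


# E.g. converged_after([-3.2, -3.0, -3.1, -3.1, -3.1], nn.Linear(2, 2), stop_after_epochs=2) -> True
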
23 changes: 12 additions & 11 deletions sbi/inference/potentials/likelihood_based_potential.py
@@ -13,10 +13,11 @@
from sbi.utils import mcmc_transform
from sbi.utils.sbiutils import match_theta_and_x_batch_shapes
from sbi.utils.torchutils import atleast_2d
from sbi.neural_nets.density_estimators import DensityEstimator


def likelihood_estimator_based_potential(
likelihood_estimator: nn.Module,
likelihood_estimator: DensityEstimator,
prior: Distribution,
x_o: Optional[Tensor],
enable_transform: bool = True,
@@ -27,7 +28,7 @@ def likelihood_estimator_based_potential(
unconstrained space.
Args:
likelihood_estimator: The neural network modelling the likelihood.
likelihood_estimator: The density estimator modelling the likelihood.
prior: The prior distribution.
x_o: The observed data at which to evaluate the likelihood.
enable_transform: Whether to transform parameters to unconstrained space.
@@ -55,15 +56,15 @@ class LikelihoodBasedPotential(BasePotential):

def __init__(
self,
likelihood_estimator: nn.Module,
likelihood_estimator: DensityEstimator,
prior: Distribution,
x_o: Optional[Tensor],
device: str = "cpu",
):
r"""Returns the potential function for likelihood-based methods.
Args:
likelihood_estimator: The neural network modelling the likelihood.
likelihood_estimator: The density estimator modelling the likelihood.
prior: The prior distribution.
x_o: The observed data at which to evaluate the likelihood.
device: The device to which parameters and data are moved before evaluating
@@ -92,15 +93,15 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
log_likelihood_trial_sum = _log_likelihoods_over_trials(
x=self.x_o,
theta=theta.to(self.device),
net=self.likelihood_estimator,
estimator=self.likelihood_estimator,
track_gradients=track_gradients,
)

return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore


def _log_likelihoods_over_trials(
x: Tensor, theta: Tensor, net: Any, track_gradients: bool = False
x: Tensor, theta: Tensor, estimator: DensityEstimator, track_gradients: bool = False
) -> Tensor:
r"""Return log likelihoods summed over iid trials of `x`.
@@ -112,8 +113,8 @@ def _log_likelihoods_over_trials(
Args:
x: batch of iid data.
theta: batch of parameters
net: neural net with .log_prob()
theta: batch of parameters.
estimator: DensityEstimator.
track_gradients: Whether to track gradients.
Returns:
@@ -131,13 +132,13 @@ def _log_likelihoods_over_trials(
x_repeated.shape[0] == theta_repeated.shape[0]
), "x and theta must match in batch shape."
assert (
next(net.parameters()).device == x.device and x.device == theta.device
), f"""device mismatch: net, x, theta: {next(net.parameters()).device}, {x.device},
next(estimator.parameters()).device == x.device and x.device == theta.device
), f"""device mismatch: estimator, x, theta: {next(estimator.parameters()).device}, {x.device},
{theta.device}."""

# Calculate likelihood in one batch.
with torch.set_grad_enabled(track_gradients):
log_likelihood_trial_batch = net.log_prob(x_repeated, theta_repeated)
log_likelihood_trial_batch = estimator.log_prob(x_repeated, theta_repeated)
# Reshape to (x-trials x parameters), sum over trial-log likelihoods.
log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
x.shape[0], -1
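
The repeat, evaluate, reshape, sum-over-trials pattern in _log_likelihoods_over_trials can be checked with a small standalone example. The toy Gaussian likelihood and the hand-rolled repeats are invented for illustration (they are not the match_theta_and_x_batch_shapes helper), and summing over the trial axis is an assumption, since the final line of the function is cut off above.

import torch

n_trials, n_theta, dim = 3, 4, 2
x = torch.randn(n_trials, dim)       # iid observations
theta = torch.randn(n_theta, dim)    # candidate parameters

x_rep = x.repeat_interleave(n_theta, dim=0)   # (n_trials * n_theta, dim)
theta_rep = theta.repeat(n_trials, 1)         # (n_trials * n_theta, dim)

# Toy likelihood: x ~ N(theta, I).
log_probs = torch.distributions.Normal(theta_rep, 1.0).log_prob(x_rep).sum(-1)

# Reshape to (trials, parameters), then sum over the trial axis -> one value per theta,
# mirroring reshape(x.shape[0], -1) in the diff above.
per_theta = log_probs.reshape(n_trials, n_theta).sum(0)
assert per_theta.shape == (n_theta,)
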
10 changes: 10 additions & 0 deletions sbi/inference/snle/mnle.py
@@ -6,6 +6,7 @@
from typing import Any, Callable, Dict, Optional, Union

from torch.distributions import Distribution
from torch import Tensor, nn, optim

from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
from sbi.inference.potentials import mixed_likelihood_estimator_based_potential
@@ -193,3 +194,12 @@ def build_posterior(
self._model_bank.append(deepcopy(self._posterior))

return deepcopy(self._posterior)

# Temporary: need to rewrite mixed likelihood estimators as DensityEstimator objects.
def _loss(self, theta: Tensor, x: Tensor) -> Tensor:
r"""Return loss for SNLE, which is the likelihood of $-\log q(x_i | \theta_i)$.
Returns:
Negative log prob.
"""
return -self._density_estimator.log_prob(x, context=theta)
27 changes: 14 additions & 13 deletions sbi/inference/snle/snle_base.py
@@ -18,6 +18,7 @@
from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
from sbi.inference.potentials import likelihood_estimator_based_potential
from sbi.utils import check_estimator_arg, check_prior, x_shape_from_simulation
from sbi.neural_nets.density_estimators import DensityEstimator


class LikelihoodEstimator(NeuralInference, ABC):
@@ -126,7 +127,7 @@ def train(
retrain_from_scratch: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
) -> flows.Flow:
) -> DensityEstimator:
r"""Train the density estimator to learn the distribution $p(x|\theta)$.
Args:
@@ -165,11 +166,11 @@ def train(
# arguments, which will build the neural network
# This is passed into NeuralPosterior, to create a neural posterior which
# can `sample()` and `log_prob()`. The network is accessible via `.net`.
if self._neural_net is None or retrain_from_scratch:
if self._density_estimator is None or retrain_from_scratch:
# Get theta,x to initialize NN
theta, x, _ = self.get_simulations(starting_round=start_idx)
# Use only training data for building the neural net (z-scoring transforms)
self._neural_net = self._build_neural_net(
self._density_estimator = self._build_neural_net(
theta[self.train_indices].to("cpu"),
x[self.train_indices].to("cpu"),
)
@@ -179,10 +180,10 @@
len(self._x_shape) < 3
), "SNLE cannot handle multi-dimensional simulator output."

self._neural_net.to(self._device)
self._density_estimator.to(self._device)
if not resume_training:
self.optimizer = optim.Adam(
list(self._neural_net.parameters()),
list(self._density_estimator.parameters()),
lr=learning_rate,
)
self.epoch, self._val_log_prob = 0, float("-Inf")
@@ -191,7 +192,7 @@
self.epoch, stop_after_epochs
):
# Train for a single epoch.
self._neural_net.train()
self._density_estimator.train()
train_log_probs_sum = 0
for batch in train_loader:
self.optimizer.zero_grad()
@@ -207,7 +208,7 @@
train_loss.backward()
if clip_max_norm is not None:
clip_grad_norm_(
self._neural_net.parameters(),
self._density_estimator.parameters(),
max_norm=clip_max_norm,
)
self.optimizer.step()
@@ -220,7 +221,7 @@
self._summary["training_log_probs"].append(train_log_prob_average)

# Calculate validation performance.
self._neural_net.eval()
self._density_estimator.eval()
val_log_prob_sum = 0
with torch.no_grad():
for batch in val_loader:
@@ -256,13 +257,13 @@ def train(

# Avoid keeping the gradients in the resulting network, which can
# cause memory leakage when benchmarking.
self._neural_net.zero_grad(set_to_none=True)
self._density_estimator.zero_grad(set_to_none=True)

return deepcopy(self._neural_net)
return deepcopy(self._density_estimator)

def build_posterior(
self,
density_estimator: Optional[nn.Module] = None,
density_estimator: Optional[DensityEstimator] = None,
prior: Optional[Distribution] = None,
sample_with: str = "mcmc",
mcmc_method: str = "slice_np",
@@ -311,7 +312,7 @@ def build_posterior(
check_prior(prior)

if density_estimator is None:
likelihood_estimator = self._neural_net
likelihood_estimator = self._density_estimator
# If internal net is used device is defined.
device = self._device
else:
@@ -367,4 +368,4 @@ def _loss(self, theta: Tensor, x: Tensor) -> Tensor:
Returns:
Negative log prob.
"""
return -self._neural_net.log_prob(x, context=theta)
return self._density_estimator.loss(x, condition=theta)
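
To see the refactor end to end, here is a compressed, hedged version of the loop in train() above, reusing the ToyDensityEstimator sketched after the file list at the top of this page (the data, learning rate, epoch count, and clipping norm are invented for illustration; this is not the sbi training code):

import torch
from torch import optim
from torch.nn.utils import clip_grad_norm_

# Fake "simulations": x = theta + noise.
theta = torch.randn(512, 2)
x = theta + 0.1 * torch.randn_like(theta)

estimator = ToyDensityEstimator(x_dim=2, theta_dim=2)  # mock defined near the top
optimizer = optim.Adam(estimator.parameters(), lr=5e-4)

for epoch in range(200):
    estimator.train()
    optimizer.zero_grad()
    train_loss = estimator.loss(x, condition=theta).mean()
    train_loss.backward()
    clip_grad_norm_(estimator.parameters(), max_norm=5.0)  # gradient clipping, as in the diff above
    optimizer.step()

estimator.eval()
with torch.no_grad():
    print("final mean log prob:", estimator.log_prob(x, condition=theta).mean().item())
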
