Commit

docstrings and cleanup
gmoss13 committed Feb 14, 2025
1 parent 0d29c8a commit d6e4133
Showing 4 changed files with 110 additions and 76 deletions.
53 changes: 5 additions & 48 deletions sbi/inference/posteriors/score_posterior.py
@@ -1,7 +1,6 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from functools import partial
from typing import Dict, Optional, Union

import torch
@@ -10,6 +9,7 @@

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.potentials.score_based_potential import (
CallableDifferentiablePotentialFunction,
PosteriorScoreBasedPotential,
score_estimator_based_potential,
)
@@ -113,7 +113,6 @@ def sample(
Args:
sample_shape: Shape of the samples to be drawn.
x: Deprecated - use `.set_default_x()` prior to `.sample()`.
predictor: The predictor for the diffusion-based sampler. Can be a string or
a custom predictor following the API in `sbi.samplers.score.predictors`.
Currently, only `euler_maruyama` is implemented.
@@ -150,9 +149,7 @@ def sample(
num_samples=num_samples,
show_progress_bars=show_progress_bars,
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs={"x": x},
)[0]
samples = samples.reshape(sample_shape + self.score_estimator.input_shape)
elif self.sample_with == "sde":
proposal_sampling_kwargs = {
"predictor": predictor,
@@ -172,14 +169,13 @@ def sample(
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs=proposal_sampling_kwargs,
)[0]
samples = samples.reshape(sample_shape + self.score_estimator.input_shape)

samples = samples.reshape(sample_shape + self.score_estimator.input_shape)
return samples

def _sample_via_diffusion(
self,
sample_shape: Shape = torch.Size(),
x: Optional[Tensor] = None,
predictor: Union[str, Predictor] = "euler_maruyama",
corrector: Optional[Union[str, Corrector]] = None,
predictor_params: Optional[Dict] = None,
@@ -193,7 +189,6 @@ def _sample_via_diffusion(
Args:
sample_shape: Shape of the samples to be drawn.
x: Deprecated - use `.set_default_x()` prior to `.sample()`.
predictor: The predictor for the diffusion-based sampler. Can be a string or
a custom predictor following the API in `sbi.samplers.score.predictors`.
Currently, only `euler_maruyama` is implemented.
@@ -249,7 +244,6 @@ def _sample_via_diffusion(
def sample_via_zuko(
self,
sample_shape: Shape = torch.Size(),
x: Optional[Tensor] = None,
) -> Tensor:
r"""Return samples from posterior distribution with probability flow ODE.
@@ -265,10 +259,9 @@
"""
num_samples = torch.Size(sample_shape).numel()

x = self._x_else_default_x(x)
x = reshape_to_batch_event(x, self.score_estimator.condition_shape)

flow = self.potential_fn.get_continuous_normalizing_flow(condition=x)
flow = self.potential_fn.get_continuous_normalizing_flow(
condition=self.potential_fn.x_o
)
samples = flow.sample(torch.Size((num_samples,)))

return samples
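
For orientation, the sampling entry points touched by these hunks can be used roughly as follows. This is an illustrative sketch, not part of the commit: `posterior` stands for a trained score-based posterior from this module and `x_o` for a matching observation.

# Hedged usage sketch of the API shown above (`posterior` and `x_o` are assumed names).
posterior.set_default_x(x_o)  # replaces the deprecated x= argument
sde_samples = posterior.sample((1000,), predictor="euler_maruyama")  # diffusion (SDE) sampler
ode_samples = posterior.sample_via_zuko((1000,))  # probability flow ODE via the zuko flow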
@@ -468,39 +461,3 @@ def map(
)[0]

return self._map
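
The `map()` method shown in this hunk is typically called after conditioning on an observation; a minimal, hedged sketch (default arguments assumed, `posterior` and `x_o` as in the sketch above):

# Hedged sketch: gradient-based MAP estimate, relying on the method's default arguments.
posterior.set_default_x(x_o)
theta_map = posterior.map()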


class DifferentiablePotentialFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, call_function, gradient_function):
# Save the methods as callables
ctx.call_function = call_function
ctx.gradient_function = gradient_function
ctx.save_for_backward(input)

# Perform the forward computation
output = call_function(input)
return output

@staticmethod
def backward(ctx, grad_output):
(input,) = ctx.saved_tensors
grad = ctx.gradient_function(input)
while len(grad_output.shape) < len(grad.shape):
grad_output = grad_output.unsqueeze(-1)
grad_input = grad_output * grad
return grad_input, None, None


# Wrapper class to manage state
class CallableDifferentiablePotentialFunction:
def __init__(self, posterior_score_based_potential):
self.posterior_score_based_potential = posterior_score_based_potential

def __call__(self, input):
prepared_potential = partial(
self.posterior_score_based_potential.__call__, rebuild_flow=False
)
return DifferentiablePotentialFunction.apply(
input, prepared_potential, self.posterior_score_based_potential.gradient
)
112 changes: 90 additions & 22 deletions sbi/inference/potentials/score_based_potential.py
@@ -1,6 +1,6 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from functools import partial
from typing import Optional, Tuple

import torch
@@ -80,21 +80,20 @@ def set_x(
x_is_iid: Optional[bool] = False,
rebuild_flow: Optional[bool] = True,
):
"""
Set the observed data and whether it is IID.
Args:
x_o: The observed data.
x_is_iid: Whether the observed data is IID (if batch_dim>1).
rebuild_flow: Whether to save (overwrite) a high-tolerance flow model, useful if
the flow needs to be evaluated many times (e.g. for MAP calculation).
"""
super().set_x(x_o, x_is_iid)
if rebuild_flow and self._x_o is not None:
x_density_estimator = reshape_to_batch_event(
self.x_o, event_shape=self.score_estimator.condition_shape
)
assert x_density_estimator.shape[0] == 1 or not self.x_is_iid, (
"PosteriorScoreBasedPotential does not support IID observations`."
)
# For large number of evals, we want a high-tolerance flow.
# By default, we want a high-tolerance flow.
# This flow will be used mainly for MAP calculations, hence we want to save
# it instead of rebuilding it every time.
flow = self.get_continuous_normalizing_flow(
condition=x_density_estimator, atol=1e-2, rtol=1e-3, exact=True
)
self.flow = flow
self.flow = self.rebuild_flow(atol=1e-2, rtol=1e-3, exact=True)

def __call__(
self,
@@ -123,17 +122,10 @@ def __call__(
theta, theta.shape[1:], leading_is_sample=True
)
self.score_estimator.eval()
# use rebuild_flow to evaluate log_prob with better precision, without
# overwriting self.flow
if rebuild_flow or self.flow is None:
x_density_estimator = reshape_to_batch_event(
self.x_o, event_shape=self.score_estimator.condition_shape
)
assert x_density_estimator.shape[0] == 1, (
"PosteriorScoreBasedPotential supports only x batchsize of 1`."
)

flow = self.get_continuous_normalizing_flow(
condition=x_density_estimator, atol=atol, rtol=rtol, exact=exact
)
flow = self.rebuild_flow(atol=atol, rtol=rtol, exact=exact)
else:
flow = self.flow

@@ -209,6 +201,30 @@ def get_continuous_normalizing_flow(
# Use zuko to build the normalizing flow.
return NormalizingFlow(transform, base=base_density)

def rebuild_flow(
self, atol: float = 1e-5, rtol: float = 1e-6, exact: bool = True
) -> NormalizingFlow:
"""
Rebuilds the continuous normalizing flow. This is used when
a new default x is set, or to evaluate the log probs at higher precision.
"""
if self._x_o is None:
raise ValueError(
"No observed data x_o is available. Please reinitialize \
the potential or manually set self._x_o."
)
x_density_estimator = reshape_to_batch_event(
self.x_o, event_shape=self.score_estimator.condition_shape
)
assert x_density_estimator.shape[0] == 1, (
"PosteriorScoreBasedPotential supports only x batchsize of 1`."
)

flow = self.get_continuous_normalizing_flow(
condition=x_density_estimator, atol=atol, rtol=rtol, exact=exact
)
return flow
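
The division of labour between `set_x()` (which caches a coarse flow for many cheap evaluations) and `__call__()`/`rebuild_flow()` (which can build a tighter flow on demand) can be pictured with a short sketch; `potential`, `theta`, and `x_o` are assumed placeholders and the keyword arguments follow the snippets above, not a guaranteed API:

# Hedged sketch only; not part of the commit.
potential.set_x(x_o, rebuild_flow=True)  # caches a coarse flow (atol=1e-2, rtol=1e-3)
log_prob_fast = potential(theta)  # repeated evaluations reuse the cached flow
log_prob_precise = potential(theta, rebuild_flow=True)  # one-off rebuild at tighter tolerances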


def build_freeform_jacobian_transform(
score_estimator: ConditionalScoreEstimator,
@@ -249,3 +265,55 @@ def f(t, x):
exact=exact,
)
return transform


class DifferentiablePotentialFunction(torch.autograd.Function):
"""
A wrapper of PosteriorScoreBasedPotential with a custom autograd function to compute
the gradient of log_prob with respect to theta. Instead of backpropagating through
the continuous normalizing flow, we use the gradient of the score estimator.
"""

@staticmethod
def forward(ctx, input, call_function, gradient_function):
"""
Computes the potential normally.
"""
# Save the methods as callables
ctx.call_function = call_function
ctx.gradient_function = gradient_function
ctx.save_for_backward(input)

# Perform the forward computation
output = call_function(input)
return output

@staticmethod
def backward(ctx, grad_output):
(input,) = ctx.saved_tensors
grad = ctx.gradient_function(input)

# Match dims
while len(grad_output.shape) < len(grad.shape):
grad_output = grad_output.unsqueeze(-1)
grad_input = grad_output * grad
return grad_input, None, None


class CallableDifferentiablePotentialFunction:
"""
This class handles the forward and backward functions from the potential function
that can be passed to DifferentiablePotentialFunction, as torch.autograd.Function
only supports static methods, and so it can't be given the potential class directly.
"""

def __init__(self, posterior_score_based_potential):
self.posterior_score_based_potential = posterior_score_based_potential

def __call__(self, input):
prepared_potential = partial(
self.posterior_score_based_potential.__call__, rebuild_flow=False
)
return DifferentiablePotentialFunction.apply(
input, prepared_potential, self.posterior_score_based_potential.gradient
)
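
To make the role of these two classes concrete, here is a minimal, hypothetical usage sketch; the variable names and the gradient-ascent step are illustrative assumptions, not code from this commit:

# Hedged sketch: route the potential's own gradient into autograd so theta can be
# optimized (e.g. for MAP search) without backpropagating through the normalizing flow.
diff_potential = CallableDifferentiablePotentialFunction(potential)  # `potential` assumed configured
theta = theta_init.clone().requires_grad_(True)  # `theta_init` is a placeholder tensor
log_prob = diff_potential(theta)  # forward: plain potential evaluation
log_prob.sum().backward()  # backward: uses potential.gradient instead of autograd through the flow
theta = theta + 0.01 * theta.grad  # one gradient-ascent step on the log-probability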
15 changes: 11 additions & 4 deletions sbi/inference/trainers/npse/npse.py
@@ -193,7 +193,11 @@ def train(
dataloader_kwargs: Optional[dict] = None,
) -> ConditionalScoreEstimator:
r"""Returns a score estimator that approximates the score
$\nabla_\theta \log p(\theta|x)$.
$\nabla_\theta \log p(\theta|x)$. The denoising score matching loss has a high
variance, which makes it more difficult to detect convergence. To reduce this
variance, we evaluate the validation loss at a fixed set of times. We also use
the exponential moving average of the training and validation losses, as opposed
to the other `trainer` classes, which track the loss directly.
Args:
training_batch_size: Training batch size.
@@ -358,6 +362,12 @@ def default_calibration_kernel(x):
# moving average of the training loss.
if len(self._summary["training_loss"]) == 0:
self._summary["training_loss"].append(train_loss_average)
else:
previous_loss = self._summary["training_loss"][-1]
self._summary["training_loss"].append(
(1.0 - ema_loss_decay) * previous_loss
+ ema_loss_decay * train_loss_average
)

# Calculate validation performance.
self._neural_net.eval()
@@ -400,16 +410,13 @@ def default_calibration_kernel(x):
force_first_round_loss=force_first_round_loss,
)

# print("val_losses: ", val_losses.shape)
val_loss_sum += val_losses.sum().item()

# Take mean over all validation samples.
val_loss = val_loss_sum / (
len(val_loader) * val_loader.batch_size * times_batch # type: ignore
)

# NOTE: Due to the inherently noisy nature we do instead log a exponential
# moving average of the validation loss.
if len(self._summary["validation_loss"]) == 0:
val_loss_ema = val_loss
else:
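
A tiny, self-contained sketch of the exponential-moving-average loss tracking described in the train() docstring above; the numbers are hypothetical, and the update rule mirrors the committed code:

ema_loss_decay = 0.1  # hypothetical decay value
previous = 1.00  # last tracked (smoothed) training loss
current = 2.00  # noisy loss of the new epoch
smoothed = (1.0 - ema_loss_decay) * previous + ema_loss_decay * current
print(smoothed)  # 1.10 -- the tracked loss moves slowly, damping the variance of the raw loss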
6 changes: 4 additions & 2 deletions sbi/samplers/rejection/rejection.py
@@ -214,8 +214,10 @@ def accept_reject_sample(
density during evaluation of the posterior.
Args:
posterior_nn: Neural net representing the posterior.
accept_reject_fn: Function that evaluatuates which samples are accepted or
proposal: A callable that takes `sample_shape` as arguments (and kwargs as
needed). Returns samples from the proposal distribution with shape
(*sample_shape, event_dim).
accept_reject_fn: Function that evaluates which samples are accepted or
rejected. Must take a batch of parameters and return a boolean tensor which
indicates which parameters get accepted.
num_samples: Desired number of samples.
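
The `proposal` interface documented above can be illustrated with a short, hedged sketch; the 3-dimensional standard normal is a made-up example, not taken from the commit:

import torch

def proposal(sample_shape, **kwargs):
    # Must return samples of shape (*sample_shape, event_dim); here event_dim = 3.
    return torch.randn(*sample_shape, 3)

samples = proposal((100,))  # shape (100, 3)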
