npse MAP
gmoss13 committed Jan 18, 2025
1 parent 0ab3a14 commit 53bd4f9
Showing 6 changed files with 161 additions and 50 deletions.
6 changes: 3 additions & 3 deletions sbi/inference/posteriors/direct_posterior.py
@@ -132,7 +132,7 @@ def sample(
)

samples = rejection.accept_reject_sample(
proposal=self.posterior_estimator,
proposal=self.posterior_estimator.sample,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
@@ -176,7 +176,7 @@ def sample_batched(
)

samples = rejection.accept_reject_sample(
proposal=self.posterior_estimator,
proposal=self.posterior_estimator.sample,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
@@ -373,7 +373,7 @@ def leakage_correction(
def acceptance_at(x: Tensor) -> Tensor:
# [1:] to remove batch-dimension for `reshape_to_batch_event`.
return rejection.accept_reject_sample(
proposal=self.posterior_estimator,
proposal=self.posterior_estimator.sample,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_rejection_samples,
show_progress_bars=show_progress_bars,
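
Note: the change in all three hunks above (mirrored in the rejection.py and restriction_estimator.py hunks below) is the same: accept_reject_sample now duck-types its proposal as a plain sampling callable, so callers pass the bound `.sample` method instead of the estimator object itself. A minimal sketch of the new calling convention, with toy Gaussians standing in for the trained estimator (illustrative, not code from this commit):

import torch
from torch.distributions import MultivariateNormal

from sbi.samplers.rejection import rejection
from sbi.utils.sbiutils import within_support

prior = MultivariateNormal(torch.zeros(2), torch.eye(2))
estimator = MultivariateNormal(torch.ones(2), 2.0 * torch.eye(2))

# `proposal` is now any callable taking a sample shape (plus optional
# kwargs) and returning a Tensor, e.g. a bound `.sample` method.
samples, acceptance_rate = rejection.accept_reject_sample(
    proposal=estimator.sample,
    accept_reject_fn=lambda theta: within_support(prior, theta),
    num_samples=1000,
)
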
129 changes: 112 additions & 17 deletions sbi/inference/posteriors/score_posterior.py
@@ -1,6 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from functools import partial
from typing import Dict, Optional, Union

import torch
@@ -16,9 +17,11 @@
from sbi.neural_nets.estimators.shape_handling import (
reshape_to_batch_event,
)
from sbi.samplers.rejection import rejection
from sbi.samplers.score import Corrector, Diffuser, Predictor
from sbi.sbi_types import Shape
from sbi.utils import check_prior
from sbi.utils.sbiutils import gradient_ascent, within_support
from sbi.utils.torchutils import ensure_theta_batched


@@ -136,8 +139,18 @@ def sample(
x = reshape_to_batch_event(x, self.score_estimator.condition_shape)
self.potential_fn.set_x(x)

num_samples = torch.Size(sample_shape).numel()

if self.sample_with == "ode":
samples = self.sample_via_zuko(sample_shape=sample_shape, x=x)
samples = rejection.accept_reject_sample(
proposal=self.sample_via_zuko,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs={"x": x},
)[0]
samples = samples.reshape(sample_shape + self.score_estimator.input_shape)
elif self.sample_with == "sde":
samples = self._sample_via_diffusion(
sample_shape=sample_shape,
@@ -150,6 +163,25 @@
max_sampling_batch_size=max_sampling_batch_size,
show_progress_bars=show_progress_bars,
)
proposal_sampling_kwargs = {
"predictor": predictor,
"corrector": corrector,
"predictor_params": predictor_params,
"corrector_params": corrector_params,
"steps": steps,
"ts": ts,
"max_sampling_batch_size": max_sampling_batch_size,
"show_progress_bars": show_progress_bars,
}
samples = rejection.accept_reject_sample(
proposal=self._sample_via_diffusion,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs=proposal_sampling_kwargs,
)[0]
samples = samples.reshape(sample_shape + self.score_estimator.input_shape)

return samples

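Note: the helpers wrapped by accept_reject_sample (sample_via_zuko here, _sample_via_diffusion below) now return flat (num_samples, *input_shape) tensors, since the rejection loop batches and concatenates candidates, and sample() restores the requested sample_shape once at the end. A toy sketch of that reshape contract (names are illustrative):

import torch

sample_shape = torch.Size((10, 5))  # as requested by the user
input_shape = torch.Size((2,))      # event shape of theta

# The rejection loop works with flat draws ...
num_samples = torch.Size(sample_shape).numel()  # 50
flat_samples = torch.randn(num_samples, *input_shape)  # stand-in for accepted samples

# ... and the caller reshapes once, as at the end of `sample()`.
samples = flat_samples.reshape(sample_shape + input_shape)
assert samples.shape == torch.Size((10, 5, 2))
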
@@ -220,12 +252,12 @@ def _sample_via_diffusion(
)
samples = torch.cat(samples, dim=0)[:num_samples]

return samples.reshape(sample_shape + self.score_estimator.input_shape)
return samples

def sample_via_zuko(
self,
x: Tensor,
sample_shape: Shape = torch.Size(),
x: Optional[Tensor] = None,
) -> Tensor:
r"""Return samples from posterior distribution with probability flow ODE.
@@ -241,10 +273,13 @@ def sample_via_zuko(
"""
num_samples = torch.Size(sample_shape).numel()

x = self._x_else_default_x(x)
x = reshape_to_batch_event(x, self.score_estimator.condition_shape)

flow = self.potential_fn.get_continuous_normalizing_flow(condition=x)
samples = flow.sample(torch.Size((num_samples,)))

return samples.reshape(sample_shape + self.score_estimator.input_shape)
return samples

def log_prob(
self,
@@ -301,7 +336,7 @@ def map(
x: Optional[Tensor] = None,
num_iter: int = 1000,
num_to_optimize: int = 1000,
learning_rate: float = 1e-5,
learning_rate: float = 0.01,
init_method: Union[str, Tensor] = "posterior",
num_init_samples: int = 1000,
save_best_every: int = 1000,
@@ -349,17 +384,77 @@
Returns:
The MAP estimate.
"""
raise NotImplementedError(
"MAP estimation is currently not working accurately for ScorePosterior."
if x is not None:
raise ValueError(
"Passing `x` directly to `.map()` has been deprecated."
"Use `.self_default_x()` to set `x`, and then run `.map()` "
)

if self.default_x is None:
raise ValueError(
"Default `x` has not been set."
"To set the default, use the `.set_default_x()` method."
)

if self._map is None or force_update:
self.potential_fn.set_x(self.default_x)
callable_potential_fn = CallableDifferentiablePotentialFunction(
self.potential_fn
)
if init_method == "posterior":
inits = self.sample((num_init_samples,))
elif init_method == "proposal":
inits = self.proposal.sample((num_init_samples,)) # type: ignore
elif isinstance(init_method, Tensor):
inits = init_method
else:
raise ValueError(f"Invalid init_method: {init_method}.")

self._map = gradient_ascent(
potential_fn=callable_potential_fn,
inits=inits,
theta_transform=self.theta_transform,
num_iter=num_iter,
num_to_optimize=num_to_optimize,
learning_rate=learning_rate,
save_best_every=save_best_every,
show_progress_bars=show_progress_bars,
)[0]

return self._map

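Note: with the NotImplementedError gone, MAP estimation follows the usual sbi flow: set a default observation, then run gradient ascent on the differentiable potential. A usage sketch, assuming a trained ScorePosterior `posterior` and an observation `x_o`:

# `x` can no longer be passed to `.map()` directly; set it first.
posterior.set_default_x(x_o)

# Gradient ascent with the defaults above (1000 iterations, learning
# rate 0.01, 1000 posterior samples as initializations).
theta_map = posterior.map(show_progress_bars=True)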

class DifferentiablePotentialFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, call_function, gradient_function):
# Save the methods as callables
ctx.call_function = call_function
ctx.gradient_function = gradient_function
ctx.save_for_backward(input)

# Perform the forward computation
output = call_function(input)
return output

@staticmethod
def backward(ctx, grad_output):
(input,) = ctx.saved_tensors
grad = ctx.gradient_function(input)
while len(grad_output.shape) < len(grad.shape):
grad_output = grad_output.unsqueeze(-1)
grad_input = grad_output * grad
return grad_input, None, None


# Wrapper class to manage state
class CallableDifferentiablePotentialFunction:
def __init__(self, posterior_score_based_potential):
self.posterior_score_based_potential = posterior_score_based_potential

def __call__(self, input):
prepared_potential = partial(
self.posterior_score_based_potential.__call__, rebuild_flow=False
)
return super().map(
x=x,
num_iter=num_iter,
num_to_optimize=num_to_optimize,
learning_rate=learning_rate,
init_method=init_method,
num_init_samples=num_init_samples,
save_best_every=save_best_every,
show_progress_bars=show_progress_bars,
force_update=force_update,
return DifferentiablePotentialFunction.apply(
input, prepared_potential, self.posterior_score_based_potential.gradient
)
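
Note: the pattern above in isolation: forward evaluates a black-box callable, while backward substitutes a user-supplied analytic gradient, so autograd never needs to trace the forward computation. A self-contained toy (not from this commit) with f(x) = x**2 and its known gradient 2x:

import torch

class ToyPotential(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, call_fn, grad_fn):
        ctx.grad_fn = grad_fn
        ctx.save_for_backward(input)
        return call_fn(input)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Chain rule with the supplied analytic gradient; `call_fn`
        # itself is never traced by autograd.
        return grad_output * ctx.grad_fn(input), None, None

x = torch.tensor(3.0, requires_grad=True)
y = ToyPotential.apply(x, lambda t: t**2, lambda t: 2 * t)
y.backward()
assert torch.allclose(x.grad, torch.tensor(6.0))
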
57 changes: 39 additions & 18 deletions sbi/inference/potentials/score_based_potential.py
@@ -25,7 +25,7 @@ def score_estimator_based_potential(
score_estimator: ConditionalScoreEstimator,
prior: Optional[Distribution],
x_o: Optional[Tensor],
enable_transform: bool = False,
enable_transform: bool = True,
) -> Tuple["PosteriorScoreBasedPotential", TorchTransform]:
r"""Returns the potential function gradient for score estimators.
@@ -41,10 +41,6 @@
score_estimator, prior, x_o, device=device
)

assert (
enable_transform is False
), "Transforms are not yet supported for score estimators."

if prior is not None:
theta_transform = mcmc_transform(
prior, device=device, enable_transform=enable_transform
@@ -74,16 +70,38 @@ def __init__(
`iid_bridge` as proposed in Geffner et al. is implemented.
device: The device on which to evaluate the potential.
"""

super().__init__(prior, x_o, device=device)
self.score_estimator = score_estimator
self.score_estimator.eval()
self.iid_method = iid_method
super().__init__(prior, x_o, device=device)

def set_x(
self,
x_o: Optional[Tensor],
x_is_iid: Optional[bool] = False,
rebuild_flow: Optional[bool] = True,
):
super().set_x(x_o, x_is_iid)
if rebuild_flow and self._x_o is not None:
x_density_estimator = reshape_to_batch_event(
self.x_o, event_shape=self.score_estimator.condition_shape
)
assert (
x_density_estimator.shape[0] == 1
), "PosteriorScoreBasedPotential supports only x batchsize of 1`."
# For a large number of evals, we want a high-tolerance flow.
# This flow will be used mainly for MAP calculations, hence we want to save
# it instead of rebuilding it every time.
flow = self.get_continuous_normalizing_flow(
condition=x_density_estimator, atol=1e-2, rtol=1e-3, exact=True
)
self.flow = flow

def __call__(
self,
theta: Tensor,
track_gradients: bool = True,
rebuild_flow: bool = True,
atol: float = 1e-5,
rtol: float = 1e-6,
exact: bool = True,
@@ -93,6 +111,7 @@
Args:
theta: The parameters at which to evaluate the potential.
track_gradients: Whether to track gradients.
rebuild_flow: Whether to rebuild the CNF for accurate log_prob evaluation.
atol: Absolute tolerance for the ODE solver.
rtol: Relative tolerance for the ODE solver.
exact: Whether to use the exact ODE solver.
@@ -104,18 +123,20 @@
theta_density_estimator = reshape_to_sample_batch_event(
theta, theta.shape[1:], leading_is_sample=True
)
x_density_estimator = reshape_to_batch_event(
self.x_o, event_shape=self.score_estimator.condition_shape
)
assert (
x_density_estimator.shape[0] == 1
), "PosteriorScoreBasedPotential supports only x batchsize of 1`."

self.score_estimator.eval()
if rebuild_flow or self.flow is None:
x_density_estimator = reshape_to_batch_event(
self.x_o, event_shape=self.score_estimator.condition_shape
)
assert (
x_density_estimator.shape[0] == 1
), "PosteriorScoreBasedPotential supports only x batchsize of 1`."

flow = self.get_continuous_normalizing_flow(
condition=x_density_estimator, atol=atol, rtol=rtol, exact=exact
)
flow = self.get_continuous_normalizing_flow(
condition=x_density_estimator, atol=atol, rtol=rtol, exact=exact
)
else:
flow = self.flow

with torch.set_grad_enabled(track_gradients):
log_probs = flow.log_prob(theta_density_estimator).squeeze(-1)
@@ -135,7 +156,7 @@ def gradient(
r"""Returns the potential function gradient for score-based methods.
Args:
theta: The parameters at which to evaluate the potential.
theta: The parameters at which to evaluate the potential gradient.
time: The diffusion time. If None, then `t_min` of the
self.score_estimator is used (i.e. we evaluate the gradient of the
actual data distribution).
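
Note: the rebuild_flow split above lets the high-tolerance CNF be built once per observation and reused across the many potential evaluations of a MAP run. A call-pattern sketch, assuming a PosteriorScoreBasedPotential instance `potential` and tensors `x_o` and `thetas`:

# Build and cache the flow once for this observation.
potential.set_x(x_o, rebuild_flow=True)

# Repeated cheap evaluations reuse the cached flow, e.g. inside
# gradient ascent for the MAP estimate.
for theta in thetas:
    log_prob = potential(theta, rebuild_flow=False)
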
9 changes: 4 additions & 5 deletions sbi/samplers/rejection/rejection.py
@@ -1,11 +1,10 @@
import logging
import warnings
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple

import torch
import torch.distributions.transforms as torch_tf
from torch import Tensor, as_tensor, nn
from torch.distributions import Distribution
from torch import Tensor, as_tensor
from tqdm.auto import tqdm

from sbi.utils.sbiutils import gradient_ascent
@@ -188,7 +187,7 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor:

@torch.no_grad()
def accept_reject_sample(
proposal: Union[nn.Module, Distribution],
proposal: Callable,
accept_reject_fn: Callable,
num_samples: int,
show_progress_bars: bool = False,
@@ -278,7 +277,7 @@ def accept_reject_sample(
num_samples_possible = 0
while num_remaining > 0:
# Sample and reject.
candidates = proposal.sample(
candidates = proposal(
(sampling_batch_size,), # type: ignore
**proposal_sampling_kwargs,
)
2 changes: 1 addition & 1 deletion sbi/utils/restriction_estimator.py
@@ -685,7 +685,7 @@ def sample(

if sample_with == "rejection":
samples, acceptance_rate = accept_reject_sample(
proposal=self._prior,
proposal=self._prior.sample,
accept_reject_fn=self._accept_reject_fn,
num_samples=num_samples,
show_progress_bars=show_progress_bars,
8 changes: 2 additions & 6 deletions tests/linearGaussian_npse_test.py
@@ -159,7 +159,7 @@ def simulator(theta):

@pytest.mark.xfail(
reason="iid_bridge not working.",
raises=NotImplementedError,
raises=AssertionError,
strict=True,
match="Score accumulation*",
)
@@ -203,10 +203,6 @@ def test_npse_iid_inference(num_trials):


@pytest.mark.slow
@pytest.mark.xfail(
raises=NotImplementedError,
reason="MAP optimization via score not working accurately.",
)
def test_npse_map():
num_dim = 2
x_o = zeros(num_dim)
@@ -234,4 +230,4 @@ def test_npse_map():

map_ = posterior.map(show_progress_bars=True)

assert torch.allclose(map_, gt_posterior.mean, atol=0.2), "MAP is not close to GT."
assert torch.allclose(map_, gt_posterior.mean, atol=0.4), "MAP is not close to GT."
