Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor simulate_for_sbi location #1253

Merged
merged 4 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sbi/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
NeuralInference, # noqa: F401
check_if_proposal_has_default_x,
infer,
simulate_for_sbi,
)
from sbi.inference.trainers.fmpe import FMPE
from sbi.inference.trainers.nle import MNLE, NLE_A
Expand Down Expand Up @@ -48,3 +47,4 @@
posterior_estimator_based_potential,
ratio_estimator_based_potential,
)
from sbi.utils.simulation_utils import simulate_for_sbi
115 changes: 3 additions & 112 deletions sbi/inference/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,12 @@
from typing import Any, Callable, Dict, Optional, Tuple, Union
from warnings import warn

import numpy as np
import torch
from joblib import Parallel, delayed
from numpy import ndarray
from torch import Tensor, float32
from torch import Tensor
from torch.distributions import Distribution
from torch.utils import data
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm.auto import tqdm

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.utils import (
Expand All @@ -29,7 +25,8 @@
validate_theta_and_x,
warn_if_zscoring_changes_data,
)
from sbi.utils.sbiutils import get_simulations_since_round, seed_all_backends
from sbi.utils.sbiutils import get_simulations_since_round
from sbi.utils.simulation_utils import simulate_for_sbi
from sbi.utils.torchutils import check_if_prior_on_device, process_device
from sbi.utils.user_input_checks import (
check_sbi_inputs,
Expand Down Expand Up @@ -568,112 +565,6 @@ def __setstate__(self, state_dict: Dict):
self.__dict__ = state_dict


# Refactoring following #1175. tl;dr: letting joblib iterate over numpy arrays
# allows for a roughly 10x performance gain. The resulting need for casting
# (cf. user_input_checks.wrap_as_joblib_efficient_simulator) introduces
# considerable overhead. The simulating pipeline should, therefore, be further
# restructured in the future (PR #1188).
def simulate_for_sbi(
    simulator: Callable,
    proposal: Any,
    num_simulations: int,
    num_workers: int = 1,
    simulation_batch_size: Union[int, None] = 1,
    seed: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Tuple[Tensor, Tensor]:
    r"""Returns ($\theta, x$) pairs obtained from sampling the proposal and simulating.

    This function performs two steps:

    - Sample parameters $\theta$ from the `proposal`.
    - Simulate these parameters to obtain $x$.

    Args:
        simulator: A function that takes parameters $\theta$ and maps them to
            simulations, or observations, `x`, $\text{sim}(\theta)\to x$. Any
            regular Python callable (i.e. function or class with `__call__` method)
            can be used. Note that the simulator should be able to handle numpy
            arrays for efficient parallelization. You can use
            `process_simulator` to ensure this.
        proposal: Probability distribution that the parameters $\theta$ are sampled
            from.
        num_simulations: Number of simulations that are run.
        num_workers: Number of parallel workers to use for simulations.
        simulation_batch_size: Number of parameter sets of shape
            (simulation_batch_size, parameter_dimension) that the simulator
            receives per call. If None, we set
            simulation_batch_size=num_simulations and simulate all parameter
            sets with one call. Otherwise, we construct batches of at most this
            many parameter sets and distribute them among num_workers.
        seed: Seed for reproducibility.
        show_progress_bar: Whether to show a progress bar for simulating. This will not
            affect whether there will be a progressbar while drawing samples from the
            proposal.

    Returns: Sampled parameters $\theta$ and simulation-outputs $x$.

    Raises:
        ValueError: If `simulation_batch_size` is smaller than 1.
        TypeError: If a `TypeError` occurs while simulating with
            `num_workers != 1`, e.g. because the simulator cannot handle numpy
            arrays.
    """

    if num_simulations == 0:
        # Nothing to sample or simulate: return two empty float32 tensors.
        return torch.tensor([], dtype=float32), torch.tensor([], dtype=float32)

    # Parse the simulation_batch_size logic.
    if simulation_batch_size is None:
        simulation_batch_size = num_simulations
    elif simulation_batch_size < 1:
        # Fail early with a clear message instead of a ZeroDivisionError or an
        # opaque split error further down.
        raise ValueError(
            "simulation_batch_size must be a positive integer or None, got "
            f"{simulation_batch_size}."
        )
    else:
        simulation_batch_size = min(simulation_batch_size, num_simulations)

    seed_all_backends(seed)
    theta = proposal.sample((num_simulations,))

    if num_workers != 1:
        # For multiprocessing, we switch to numpy arrays for better joblib
        # performance (see #1175). np.array_split takes the number of batches,
        # not their size, so round up: no batch then exceeds
        # simulation_batch_size.
        num_batches = -(-num_simulations // simulation_batch_size)
        # detach/cpu are no-ops for plain CPU tensors but keep `.numpy()` from
        # failing on tensors that require grad or live on another device.
        batches = np.array_split(theta.detach().cpu().numpy(), num_batches, axis=0)
        batch_seeds = np.random.randint(low=0, high=1_000_000, size=(len(batches),))

        # Re-seed right before each call so every batch gets its own seed.
        def simulator_seeded(theta_batch: ndarray, batch_seed: int) -> Tensor:
            seed_all_backends(batch_seed)
            return simulator(theta_batch)

        try:  # catch TypeError to give more informative error message
            simulation_outputs: list[Tensor] = list(  # pyright: ignore
                tqdm(
                    Parallel(return_as="generator", n_jobs=num_workers)(
                        delayed(simulator_seeded)(batch, batch_seed)
                        for batch, batch_seed in zip(batches, batch_seeds)
                    ),
                    # One tick per finished batch, so the total is the number
                    # of batches rather than the number of simulations.
                    total=len(batches),
                    disable=not show_progress_bar,
                )
            )
        except TypeError as err:
            raise TypeError(
                "For multiprocessing, we switch to numpy arrays. Make sure to "
                "preprocess your simulator with `process_simulator` to handle numpy"
                " arrays."
            ) from err

    else:
        # Single worker: simulate the torch batches one after another.
        batches = torch.split(theta, simulation_batch_size)
        simulation_outputs = [
            simulator(batch) for batch in tqdm(batches, disable=not show_progress_bar)
        ]

    # Correctly format the output
    x = torch.cat(simulation_outputs, dim=0)
    theta = torch.as_tensor(theta, dtype=float32)

    return theta, x


def check_if_proposal_has_default_x(proposal: Any):
"""Check for validity of the provided proposal distribution.

Expand Down
119 changes: 119 additions & 0 deletions sbi/utils/simulation_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Any, Callable, Optional, Tuple, Union

import numpy as np
import torch
from joblib import Parallel, delayed
from numpy import ndarray
from torch import Tensor, float32
from tqdm.auto import tqdm

from sbi.utils.sbiutils import seed_all_backends


# Refactoring following #1175. tl;dr: letting joblib iterate over numpy arrays
# allows for a roughly 10x performance gain. The resulting need for casting
# (cf. user_input_checks.wrap_as_joblib_efficient_simulator) introduces
# considerable overhead. The simulating pipeline should, therefore, be further
# restructured in the future (PR #1188).
def simulate_for_sbi(
    simulator: Callable,
    proposal: Any,
    num_simulations: int,
    num_workers: int = 1,
    simulation_batch_size: Union[int, None] = 1,
    seed: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Tuple[Tensor, Tensor]:
    r"""Returns ($\theta, x$) pairs obtained from sampling the proposal and simulating.

    This function performs two steps:

    - Sample parameters $\theta$ from the `proposal`.
    - Simulate these parameters to obtain $x$.

    Args:
        simulator: A function that takes parameters $\theta$ and maps them to
            simulations, or observations, `x`, $\text{sim}(\theta)\to x$. Any
            regular Python callable (i.e. function or class with `__call__` method)
            can be used. Note that the simulator should be able to handle numpy
            arrays for efficient parallelization. You can use
            `process_simulator` to ensure this.
        proposal: Probability distribution that the parameters $\theta$ are sampled
            from.
        num_simulations: Number of simulations that are run.
        num_workers: Number of parallel workers to use for simulations.
        simulation_batch_size: Number of parameter sets of shape
            (simulation_batch_size, parameter_dimension) that the simulator
            receives per call. If None, we set
            simulation_batch_size=num_simulations and simulate all parameter
            sets with one call. Otherwise, we construct batches of at most this
            many parameter sets and distribute them among num_workers.
        seed: Seed for reproducibility.
        show_progress_bar: Whether to show a progress bar for simulating. This will not
            affect whether there will be a progressbar while drawing samples from the
            proposal.

    Returns: Sampled parameters $\theta$ and simulation-outputs $x$.

    Raises:
        ValueError: If `simulation_batch_size` is smaller than 1.
        TypeError: If a `TypeError` occurs while simulating with
            `num_workers != 1`, e.g. because the simulator cannot handle numpy
            arrays.
    """

    if num_simulations == 0:
        # Nothing to sample or simulate: return two empty float32 tensors.
        return torch.tensor([], dtype=float32), torch.tensor([], dtype=float32)

    # Parse the simulation_batch_size logic.
    if simulation_batch_size is None:
        simulation_batch_size = num_simulations
    elif simulation_batch_size < 1:
        # Fail early with a clear message instead of a ZeroDivisionError or an
        # opaque split error further down.
        raise ValueError(
            "simulation_batch_size must be a positive integer or None, got "
            f"{simulation_batch_size}."
        )
    else:
        simulation_batch_size = min(simulation_batch_size, num_simulations)

    seed_all_backends(seed)
    theta = proposal.sample((num_simulations,))

    if num_workers != 1:
        # For multiprocessing, we switch to numpy arrays for better joblib
        # performance (see #1175). np.array_split takes the number of batches,
        # not their size, so round up: no batch then exceeds
        # simulation_batch_size.
        num_batches = -(-num_simulations // simulation_batch_size)
        # detach/cpu are no-ops for plain CPU tensors but keep `.numpy()` from
        # failing on tensors that require grad or live on another device.
        batches = np.array_split(theta.detach().cpu().numpy(), num_batches, axis=0)
        batch_seeds = np.random.randint(low=0, high=1_000_000, size=(len(batches),))

        # Re-seed right before each call so every batch gets its own seed.
        def simulator_seeded(theta_batch: ndarray, batch_seed: int) -> Tensor:
            seed_all_backends(batch_seed)
            return simulator(theta_batch)

        try:  # catch TypeError to give more informative error message
            simulation_outputs: list[Tensor] = list(  # pyright: ignore
                tqdm(
                    Parallel(return_as="generator", n_jobs=num_workers)(
                        delayed(simulator_seeded)(batch, batch_seed)
                        for batch, batch_seed in zip(batches, batch_seeds)
                    ),
                    # One tick per finished batch, so the total is the number
                    # of batches rather than the number of simulations.
                    total=len(batches),
                    disable=not show_progress_bar,
                )
            )
        except TypeError as err:
            raise TypeError(
                "For multiprocessing, we switch to numpy arrays. Make sure to "
                "preprocess your simulator with `process_simulator` to handle numpy"
                " arrays."
            ) from err

    else:
        # Single worker: simulate the torch batches one after another.
        batches = torch.split(theta, simulation_batch_size)
        simulation_outputs = [
            simulator(batch) for batch in tqdm(batches, disable=not show_progress_bar)
        ]

    # Correctly format the output
    x = torch.cat(simulation_outputs, dim=0)
    theta = torch.as_tensor(theta, dtype=float32)

    return theta, x
72 changes: 72 additions & 0 deletions tests/user_input_checks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Callable, Tuple

import numpy as np
import pytest
import torch
from pyknos.mdn.mdn import MultivariateGaussianMDN
Expand All @@ -13,6 +14,7 @@

from sbi.inference import NPE_A, NPE_C, simulate_for_sbi
from sbi.inference.posteriors.direct_posterior import DirectPosterior
from sbi.simulators import linear_gaussian
from sbi.simulators.linear_gaussian import diagonal_linear_gaussian
from sbi.utils import mcmc_transform, within_support
from sbi.utils.torchutils import BoxUniform
Expand Down Expand Up @@ -501,3 +503,73 @@ def test_passing_custom_density_estimator(arg):
density_estimator = arg
prior = MultivariateNormal(torch.zeros(2), torch.eye(2))
_ = NPE_C(prior=prior, density_estimator=density_estimator)


@pytest.mark.parametrize(
    "num_simulations, simulation_batch_size, num_workers, "
    "use_process_simulator",
    [
        (0, None, 1, True),
        (10, None, 1, True),
        (100, 10, 1, True),
        (100, None, 2, True),
        (1000, 50, 4, True),
        (100, 10, 2, False),
    ],
)
def test_simulate_for_sbi(
    num_simulations, simulation_batch_size, num_workers, use_process_simulator
):
    """Exercise simulate_for_sbi across batch-size and worker settings.

    Covers the empty case (no simulations), the single-worker path, the
    multi-worker path, and the error raised when a simulator that rejects numpy
    arrays is run with several workers without `process_simulator`.
    """
    num_dim = 3
    likelihood_shift = -1.0 * ones(num_dim)
    likelihood_cov = 0.3 * eye(num_dim)
    prior = BoxUniform(-2.0 * ones(num_dim), 2.0 * ones(num_dim))

    # Linear-Gaussian simulator that deliberately rejects numpy input.
    def failing_simulator(theta):
        if isinstance(theta, np.ndarray):
            raise TypeError("This simulator does not support numpy arrays.")
        return linear_gaussian(theta, likelihood_shift, likelihood_cov)

    if use_process_simulator:
        simulator = process_simulator(failing_simulator, prior, False)
    else:
        simulator = failing_simulator

    # All scenarios call simulate_for_sbi with the same keyword arguments.
    call_kwargs = {
        "simulator": simulator,
        "proposal": prior,
        "num_simulations": num_simulations,
        "simulation_batch_size": simulation_batch_size,
        "num_workers": num_workers,
    }

    # No simulations requested: both outputs must be empty.
    if num_simulations == 0:
        theta, x = simulate_for_sbi(**call_kwargs)
        assert (
            theta.numel() == 0
        ), "Theta should be an empty tensor when num_simulations=0"
        assert x.numel() == 0, "x should be an empty tensor when num_simulations=0"
        return

    # Unprocessed simulator with several workers: expect the informative error.
    if not use_process_simulator and num_workers > 1:
        with pytest.raises(TypeError, match="For multiprocessing"):
            simulate_for_sbi(**call_kwargs)
        return

    # Regular run: check the shapes of the returned parameters and simulations.
    theta, x = simulate_for_sbi(**call_kwargs)
    assert (
        theta.shape[0] == num_simulations
    ), "Theta should have num_simulations rows"
    assert x.shape[0] == num_simulations, "x should have num_simulations rows"
    assert theta.shape[1] == num_dim, "Theta should have num_dim columns"
    assert x.shape[1] == num_dim, "x should have num_dim columns"