Skip to content

Commit

Permalink
Move neural nets factory out of utils (#994)
Browse files Browse the repository at this point in the history
* Export classes and function

* Reworked factory imports

* Debugging ruff

* Debugging ruff

* Added links to the new location with deprecation warnings

* Forgot build_zuko_maf

* Added tests for the deprecated functions

* Renaming

* Doc strings now only contain the deprecation warning

* Updated the deprecation message

* Updated the deprecation message

* Ruff once more
  • Loading branch information
famura authored Mar 20, 2024
1 parent db840ee commit bae6994
Show file tree
Hide file tree
Showing 20 changed files with 439 additions and 269 deletions.
6 changes: 3 additions & 3 deletions docs/docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,17 @@

## Models

::: sbi.utils.get_nn_models.posterior_nn
::: sbi.neural_nets.factory.posterior_nn
rendering:
show_root_heading: true
show_object_full_path: true

::: sbi.utils.get_nn_models.likelihood_nn
::: sbi.neural_nets.factory.likelihood_nn
rendering:
show_root_heading: true
show_object_full_path: true

::: sbi.utils.get_nn_models.classifier_nn
::: sbi.neural_nets.factory.classifier_nn
rendering:
show_root_heading: true
show_object_full_path: true
Expand Down
5 changes: 2 additions & 3 deletions sbi/inference/snle/snle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils.tensorboard.writer import SummaryWriter

from sbi import utils as utils
from sbi.inference import NeuralInference
from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
from sbi.inference.potentials import likelihood_estimator_based_potential
from sbi.neural_nets.density_estimators import DensityEstimator
from sbi.neural_nets import DensityEstimator, likelihood_nn
from sbi.utils import check_estimator_arg, check_prior, x_shape_from_simulation


Expand Down Expand Up @@ -64,7 +63,7 @@ def __init__(
# potentially for z-scoring.
check_estimator_arg(density_estimator)
if isinstance(density_estimator, str):
self._build_neural_net = utils.likelihood_nn(model=density_estimator)
self._build_neural_net = likelihood_nn(model=density_estimator)
else:
self._build_neural_net = density_estimator

Expand Down
4 changes: 2 additions & 2 deletions sbi/inference/snpe/snpe_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.potentials import posterior_estimator_based_potential
from sbi.neural_nets.density_estimators import DensityEstimator
from sbi.neural_nets import DensityEstimator, posterior_nn
from sbi.utils import (
RestrictedPrior,
check_estimator_arg,
Expand Down Expand Up @@ -77,7 +77,7 @@ def __init__(
# potentially for z-scoring.
check_estimator_arg(density_estimator)
if isinstance(density_estimator, str):
self._build_neural_net = utils.posterior_nn(model=density_estimator)
self._build_neural_net = posterior_nn(model=density_estimator)
else:
self._build_neural_net = density_estimator

Expand Down
4 changes: 2 additions & 2 deletions sbi/inference/snpe/snpe_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def __init__(
In this codebase, we will automatically switch to the non-atomic loss if the
following criteria are fulfilled:<br/>
- proposal is a `DirectPosterior` with density_estimator `mdn`, as built
with `utils.sbi.posterior_nn()`.<br/>
with `sbi.neural_nets.posterior_nn()`.<br/>
- the density estimator is a `mdn`, as built with
`utils.sbi.posterior_nn()`.<br/>
`sbi.neural_nets.posterior_nn()`.<br/>
- `isinstance(prior, MultivariateNormal)` (from `torch.distributions`) or
`isinstance(prior, sbi.utils.BoxUniform)`
Expand Down
3 changes: 2 additions & 1 deletion sbi/inference/snre/snre_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sbi.inference.base import NeuralInference
from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
from sbi.inference.potentials import ratio_estimator_based_potential
from sbi.neural_nets import classifier_nn
from sbi.utils import (
check_estimator_arg,
check_prior,
Expand Down Expand Up @@ -74,7 +75,7 @@ def __init__(
# potentially for z-scoring.
check_estimator_arg(classifier)
if isinstance(classifier, str):
self._build_neural_net = utils.classifier_nn(model=classifier)
self._build_neural_net = classifier_nn(model=classifier)
else:
self._build_neural_net = classifier

Expand Down
23 changes: 23 additions & 0 deletions sbi/neural_nets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from sbi.neural_nets.classifier import (
StandardizeInputs,
build_input_layer,
build_linear_classifier,
build_mlp_classifier,
build_resnet_classifier,
)
from sbi.neural_nets.density_estimators import DensityEstimator, NFlowsFlow
from sbi.neural_nets.embedding_nets import (
CNNEmbedding,
FCEmbedding,
PermutationInvariantEmbedding,
)
from sbi.neural_nets.factory import classifier_nn, likelihood_nn, posterior_nn
from sbi.neural_nets.flow import (
build_made,
build_maf,
build_maf_rqs,
build_nsf,
build_zuko_maf,
)
from sbi.neural_nets.mdn import build_mdn
from sbi.neural_nets.mnle import MixedDensityEstimator, build_mnle
Loading

0 comments on commit bae6994

Please sign in to comment.