Commit
build functions return density estimators
gmoss13 committed Mar 7, 2024
1 parent 17f3033 commit 5440446
Showing 2 changed files with 17 additions and 12 deletions.
22 changes: 13 additions & 9 deletions sbi/neural_nets/flow.py
@@ -19,7 +19,7 @@
)
from sbi.utils.torchutils import create_alternating_binary_mask
from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device

from sbi.neural_nets.density_estimators import NFlowsFlow

def build_made(
batch_x: Tensor,
@@ -30,7 +30,7 @@ def build_made(
num_mixture_components: int = 10,
embedding_net: nn.Module = nn.Identity(),
**kwargs,
) -> nn.Module:
) -> NFlowsFlow:
"""Builds MADE p(x|y).
Args:
@@ -90,8 +90,9 @@ def build_made(
)

neural_net = flows.Flow(transform, distribution, embedding_net)
flow = NFlowsFlow(neural_net, condition_shape=embedding_net(batch_y[0]).shape)

return neural_net
return flow


def build_maf(
@@ -106,7 +107,7 @@ def build_maf(
dropout_probability: float = 0.0,
use_batch_norm: bool = False,
**kwargs,
) -> nn.Module:
) -> NFlowsFlow:
"""Builds MAF p(x|y).
Args:
@@ -176,8 +177,9 @@ def build_maf(

distribution = distributions_.StandardNormal((x_numel,))
neural_net = flows.Flow(transform, distribution, embedding_net)
flow = NFlowsFlow(neural_net, condition_shape=embedding_net(batch_y[0]).shape)

return neural_net
return flow


def build_maf_rqs(
@@ -198,7 +200,7 @@ def build_maf_rqs(
min_bin_height: float = rational_quadratic.DEFAULT_MIN_BIN_HEIGHT,
min_derivative: float = rational_quadratic.DEFAULT_MIN_DERIVATIVE,
**kwargs,
) -> nn.Module:
) -> NFlowsFlow:
"""Builds MAF p(x|y), where the diffeomorphisms are rational-quadratic
splines (RQS).
@@ -286,8 +288,9 @@ def build_maf_rqs(

distribution = distributions_.StandardNormal((x_numel,))
neural_net = flows.Flow(transform, distribution, embedding_net)
flow = NFlowsFlow(neural_net, condition_shape=embedding_net(batch_y[0]).shape)

return neural_net
return flow


def build_nsf(
@@ -305,7 +308,7 @@ def build_nsf(
dropout_probability: float = 0.0,
use_batch_norm: bool = False,
**kwargs,
) -> nn.Module:
) -> NFlowsFlow:
"""Builds NSF p(x|y).
Args:
@@ -407,8 +410,9 @@ def mask_in_layer(i):
# Combine transforms.
transform = transforms.CompositeTransform(transform_list)
neural_net = flows.Flow(transform, distribution, embedding_net)
flow = NFlowsFlow(neural_net, condition_shape=embedding_net(batch_y[0]).shape)

return neural_net
return flow


class ContextSplineMap(nn.Module):
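Every builder in flow.py now ends with the same wrapping step: the assembled nflows Flow is handed to NFlowsFlow together with the shape of one embedded conditioning example. A minimal sketch of that shared pattern, factored into a helper for illustration only (wrap_as_density_estimator is a hypothetical name, not part of this commit; the pyknos.nflows import path is assumed to match the existing imports in flow.py):

from pyknos.nflows import flows
from torch import Tensor, nn

from sbi.neural_nets.density_estimators import NFlowsFlow


def wrap_as_density_estimator(
    neural_net: flows.Flow,
    batch_y: Tensor,
    embedding_net: nn.Module,
) -> NFlowsFlow:
    # The condition shape is taken from a single embedded conditioning example,
    # mirroring the lines added to build_made/build_maf/build_maf_rqs/build_nsf.
    condition_shape = embedding_net(batch_y[0]).shape
    return NFlowsFlow(neural_net, condition_shape=condition_shape)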
7 changes: 4 additions & 3 deletions sbi/neural_nets/mdn.py
@@ -9,7 +9,7 @@

import sbi.utils as utils
from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device

from sbi.neural_nets.density_estimators import NFlowsFlow

def build_mdn(
batch_x: Tensor,
@@ -20,7 +20,7 @@ def build_mdn(
num_components: int = 10,
embedding_net: nn.Module = nn.Identity(),
**kwargs,
) -> nn.Module:
) -> NFlowsFlow:
"""Builds MDN p(x|y).
Args:
@@ -80,5 +80,6 @@ def build_mdn(
)

neural_net = flows.Flow(transform, distribution, embedding_net)
flow = NFlowsFlow(neural_net, condition_shape=embedding_net(batch_y[0]).shape)

return neural_net
return flow
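
With both files changed, callers of the build functions receive an sbi density estimator rather than a bare nflows Flow or nn.Module. A minimal usage sketch with toy data, assuming an NFlowsFlow interface with log_prob(input, condition) and sample(sample_shape, condition), which this diff does not show:

import torch

from sbi.neural_nets.flow import build_maf

# Toy batches: batch_x is what the flow models, batch_y is the conditioning variable.
batch_x = torch.randn(100, 3)
batch_y = torch.randn(100, 2)

density_estimator = build_maf(batch_x, batch_y)  # now an NFlowsFlow

# Assumed NFlowsFlow interface (not part of this diff):
log_probs = density_estimator.log_prob(batch_x, condition=batch_y)
samples = density_estimator.sample(torch.Size((10,)), condition=batch_y)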
