Commit

infer cnn shapes in loop, improve test, and docs.

michaeldeistler authored and janfb committed Nov 3, 2022
1 parent aeca010 commit a0757d4
Showing 2 changed files with 24 additions and 64 deletions.
81 changes: 20 additions & 61 deletions sbi/neural_nets/embedding_nets.py
```diff
@@ -66,10 +66,6 @@ def get_new_cnn_output_size(
 ) -> Union[Tuple[int], Tuple[int, int]]:
     """Returns new output size after applying a given convolution and pooling.
-    Assumes quadratic input dimensions of the data and the applied kernels, e.g.,
-    input_dim refers to data of shape (input_dim, input_dim) and all convolutions should
-    use quadratic kernel sizes.
     Args:
         input_shape: tup.
         conv_layer: applied convolutional layers
@@ -79,69 +75,32 @@ def get_new_cnn_output_size(
         new output dimension of the cnn layer.
     """
-    assert isinstance(input_shape, Tuple)
-    assert 0 < len(input_shape) < 3
-    assert isinstance(conv_layer.padding, (Tuple, int))
-    assert isinstance(pool.padding, (Tuple, int))
-
-    # for 1D inputs or quadratic kernels only one dimension applies
-    if len(input_shape) == 1 or len(conv_layer.kernel_size) == 1:
-        if len(input_shape) > 1:
-            assert input_shape[0] == input_shape[1], "this case requires square input."
-        dim_after_conv = calculate_filter_output_size(
-            input_shape[0],
-            conv_layer.padding[0],
-            conv_layer.dilation[0],
-            conv_layer.kernel_size[0],
-            conv_layer.stride[0],
-        )
-        dim_after_pool = calculate_filter_output_size(
-            dim_after_conv, pool.padding, pool.dilation, pool.kernel_size, pool.stride
-        )
-
-        # return two entries of 2D input.
-        return (
-            (dim_after_pool,)
-            if len(input_shape) == 1
-            else (dim_after_pool, dim_after_pool)
-        )
-    # for rectangular 2D input or kernels both dimensions have to be calculated.
-    else:
-        assert len(conv_layer.padding) > 1
-        assert len(conv_layer.dilation) > 1
-        assert len(conv_layer.kernel_size) > 1
-        assert len(conv_layer.stride) > 1
-
-        h_out = calculate_filter_output_size(
-            input_shape[0],
-            conv_layer.padding[0],
-            conv_layer.dilation[0],
-            conv_layer.kernel_size[0],
-            conv_layer.stride[0],
-        )
-        w_out = calculate_filter_output_size(
-            input_shape[1],
-            conv_layer.padding[1],
-            conv_layer.dilation[1],
-            conv_layer.kernel_size[1],
-            conv_layer.stride[1],
-        )
-        h_out = calculate_filter_output_size(
-            h_out,
-            pool.padding,
-            pool.dilation,
-            pool.kernel_size,
-            pool.stride,
-        )
-        w_out = calculate_filter_output_size(
-            w_out,
-            pool.padding,
-            pool.dilation,
-            pool.kernel_size,
-            pool.stride,
-        )
-        return (h_out, w_out)
+    assert isinstance(input_shape, Tuple), "input shape must be Tuple."
+    assert 0 < len(input_shape) < 3, "input shape must be 1 or 2d."
+    assert isinstance(conv_layer.padding, Tuple), "conv layer attributes must be Tuple."
+    assert isinstance(pool.padding, int), "pool layer attributes must be integers."
+
+    out_after_conv = [
+        calculate_filter_output_size(
+            input_shape[i],
+            conv_layer.padding[i],
+            conv_layer.dilation[i],
+            conv_layer.kernel_size[i],
+            conv_layer.stride[i],
+        )
+        for i in range(len(input_shape))
+    ]
+    out_after_pool = [
+        calculate_filter_output_size(
+            out_after_conv[i],
+            pool.padding,
+            pool.dilation,
+            pool.kernel_size,
+            pool.stride,
+        )
+        for i in range(len(input_shape))
+    ]
+    return tuple(out_after_pool)
 
 
 class CNNEmbedding(nn.Module):
```
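To make the change concrete, here is a small sketch (not part of the commit) of what the new loop-based shape inference computes. It assumes `calculate_filter_output_size` implements the standard PyTorch output-size formula, floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1); the helper's body is not shown in this diff, and `get_new_cnn_output_size_sketch` is a renamed stand-in. The cross-check against a real `nn.Conv2d`/`nn.MaxPool2d` forward pass illustrates that rectangular inputs such as (32, 64) are now handled without the old square-input branch.

```python
# Sketch only: reproduce the loop-based shape inference and verify it against a
# real forward pass. Assumes calculate_filter_output_size uses the standard
# PyTorch output-size formula.
import math
from typing import Tuple

import torch
from torch import nn


def calculate_filter_output_size(size, padding, dilation, kernel_size, stride) -> int:
    # Standard formula used by Conv1d/2d and MaxPool1d/2d in PyTorch.
    return math.floor((size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)


def get_new_cnn_output_size_sketch(input_shape: Tuple, conv_layer, pool) -> Tuple:
    # Same logic as the new code: each spatial dimension (one for Conv1d, two,
    # possibly unequal, for Conv2d) is processed independently in a loop.
    out_after_conv = [
        calculate_filter_output_size(
            input_shape[i],
            conv_layer.padding[i],
            conv_layer.dilation[i],
            conv_layer.kernel_size[i],
            conv_layer.stride[i],
        )
        for i in range(len(input_shape))
    ]
    out_after_pool = [
        calculate_filter_output_size(
            out_after_conv[i], pool.padding, pool.dilation, pool.kernel_size, pool.stride
        )
        for i in range(len(input_shape))
    ]
    return tuple(out_after_pool)


if __name__ == "__main__":
    # Rectangular input, which the old branch-based code only partially supported.
    input_shape = (32, 64)
    conv = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5, padding=2)
    pool = nn.MaxPool2d(kernel_size=2)

    predicted = get_new_cnn_output_size_sketch(input_shape, conv, pool)
    actual = pool(conv(torch.zeros(1, 1, *input_shape))).shape[-2:]
    print(predicted, tuple(actual))  # both should be (16, 32)
```

Looping over `range(len(input_shape))` lets one code path serve both 1D and 2D convolutions, which is why the branch-based version and its square-input assertion could be deleted.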
7 changes: 4 additions & 3 deletions tests/embedding_net_test.py
```diff
@@ -192,7 +192,7 @@ def test_iid_inference(num_trials, num_dim, method):
         (32, 64),
     ],
 )
-@pytest.mark.parametrize("num_channels", (1, 3))
+@pytest.mark.parametrize("num_channels", (1, 2, 3))
 def test_1d_and_2d_cnn_embedding_net(input_shape, num_channels):
     import torch
     from torch.distributions import MultivariateNormal
@@ -234,6 +234,7 @@ def simulator1d(theta):
 
     trainer = SNPE(prior=prior, density_estimator=estimator_provider)
     trainer.append_simulations(theta, x).train(max_num_epochs=2)
-    posterior = trainer.build_posterior()
+    posterior = trainer.build_posterior().set_default_x(xo)
 
-    posterior.sample((10,), x=xo)
+    s = posterior.sample((10,))
+    posterior.potential(s)
```
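The last hunk works because `set_default_x(xo)` registers a default observation on the posterior, so the subsequent `sample` and `potential` calls need no explicit `x=` argument. A minimal, self-contained sketch of that pattern (toy prior and data, not the actual test fixture):

```python
# Sketch only: tiny SNPE run illustrating why the updated test calls
# set_default_x(xo) before sampling and evaluating the potential.
import torch
from sbi.inference import SNPE
from sbi.utils import BoxUniform

prior = BoxUniform(low=-torch.ones(2), high=torch.ones(2))
theta = prior.sample((200,))
x = theta + 0.1 * torch.randn_like(theta)  # toy "simulator" output
xo = torch.zeros(1, 2)                     # observation to condition on

trainer = SNPE(prior=prior)
trainer.append_simulations(theta, x).train(max_num_epochs=2)

# With a default observation registered, sample() and potential() can be
# called without passing x explicitly, as in the updated test.
posterior = trainer.build_posterior().set_default_x(xo)
s = posterior.sample((10,))
log_potential = posterior.potential(s)
```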
