diff --git a/sbi/neural_nets/embedding_nets.py b/sbi/neural_nets/embedding_nets.py
index a702f8864..5069fd8e2 100644
--- a/sbi/neural_nets/embedding_nets.py
+++ b/sbi/neural_nets/embedding_nets.py
@@ -1,6 +1,8 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.
 
+from typing import List, Tuple, Union
+
 import torch
 from torch import Tensor, nn
 
@@ -44,81 +46,198 @@ def forward(self, x: Tensor) -> Tensor:
         return self.net(x)
 
 
+def calculate_filter_output_size(input_size, padding, dilation, kernel, stride) -> int:
+    """Returns the output size of a filter given the filter arguments.
+
+    Uses the formulas from https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html,
+    e.g., input_size=32, padding=1, dilation=1, kernel=5, stride=1 gives 30.
+    """
+
+    return int(
+        (int(input_size) + 2 * int(padding) - int(dilation) * (int(kernel) - 1) - 1)
+        / int(stride)
+        + 1
+    )
+
+
+def get_new_cnn_output_size(
+    input_shape: Tuple,
+    conv_layer: Union[nn.Conv1d, nn.Conv2d],
+    pool: Union[nn.MaxPool1d, nn.MaxPool2d],
+) -> Union[Tuple[int], Tuple[int, int]]:
+    """Returns the new output size after applying a given convolution and pooling.
+
+    For 1D input, or for 2D input processed with one-dimensional kernels, only a
+    single output dimension has to be calculated (the latter case requires square
+    input). For all other 2D input, both output dimensions are calculated
+    separately.
+
+    Args:
+        input_shape: input shape of the data without the channel dimension, e.g.,
+            (width,) for 1D or (height, width) for 2D.
+        conv_layer: applied convolutional layer.
+        pool: applied pooling layer.
+
+    Returns:
+        New output size of the CNN layer.
+    """
+    assert isinstance(input_shape, Tuple)
+    assert 0 < len(input_shape) < 3
+    assert isinstance(conv_layer.padding, (Tuple, int))
+    assert isinstance(pool.padding, (Tuple, int))
+
+    # For 1D input or one-dimensional kernels, only one dimension applies.
+    if len(input_shape) == 1 or len(conv_layer.kernel_size) == 1:
+
+        if len(input_shape) > 1:
+            assert input_shape[0] == input_shape[1], "this case requires square input."
+        dim_after_conv = calculate_filter_output_size(
+            input_shape[0],
+            conv_layer.padding[0],
+            conv_layer.dilation[0],
+            conv_layer.kernel_size[0],
+            conv_layer.stride[0],
+        )
+        dim_after_pool = calculate_filter_output_size(
+            dim_after_conv, pool.padding, pool.dilation, pool.kernel_size, pool.stride
+        )
+
+        # Return two entries for 2D input.
+        return (
+            (dim_after_pool,)
+            if len(input_shape) == 1
+            else (dim_after_pool, dim_after_pool)
+        )
+    # For rectangular 2D input or kernels, both dimensions have to be calculated.
+    else:
+        assert len(conv_layer.padding) > 1
+        assert len(conv_layer.dilation) > 1
+        assert len(conv_layer.kernel_size) > 1
+        assert len(conv_layer.stride) > 1
+
+        h_out = calculate_filter_output_size(
+            input_shape[0],
+            conv_layer.padding[0],
+            conv_layer.dilation[0],
+            conv_layer.kernel_size[0],
+            conv_layer.stride[0],
+        )
+        w_out = calculate_filter_output_size(
+            input_shape[1],
+            conv_layer.padding[1],
+            conv_layer.dilation[1],
+            conv_layer.kernel_size[1],
+            conv_layer.stride[1],
+        )
+        h_out = calculate_filter_output_size(
+            h_out,
+            pool.padding,
+            pool.dilation,
+            pool.kernel_size,
+            pool.stride,
+        )
+        w_out = calculate_filter_output_size(
+            w_out,
+            pool.padding,
+            pool.dilation,
+            pool.kernel_size,
+            pool.stride,
+        )
+        return (h_out, w_out)
+
+
 class CNNEmbedding(nn.Module):
     def __init__(
         self,
-        input_dim: int,
+        input_shape: Tuple,
+        in_channels: int = 1,
+        out_channels_per_layer: List = [6, 12],
+        num_conv_layers: int = 2,
+        num_linear_layers: int = 2,
+        num_linear_units: int = 50,
         output_dim: int = 20,
-        num_fully_connected: int = 2,
-        num_hiddens: int = 120,
-        out_channels_cnn_1: int = 10,
-        out_channels_cnn_2: int = 16,
         kernel_size: int = 5,
-        pool_size=4,
+        pool_kernel_size: int = 2,
     ):
-        """Multi-layer (C)NN
-        First two layers are convolutional, followed by fully connected layers.
-        Performing 1d convolution and max pooling with preset configs.
+        """Convolutional embedding network.
+
+        A stack of convolutional layers is followed by fully connected layers.
+
+        Automatically infers whether to apply 1D or 2D convolution depending on
+        input_shape.
+        Allows usage of multiple (color) channels by passing in_channels > 1.
 
         Args:
-            input_dim: Dimensionality of input.
-            output_dim: Dimensionality of the output.
-            num_conv: Number of convolutional layers.
-            num_fully_connected: Number fully connected layer, minimum of 2.
-            num_hiddens: Number of hidden dimensions in fully-connected layers.
-            out_channels_cnn_1: Number of oputput channels for the first convolutional
-                layer.
-            out_channels_cnn_2: Number of oputput channels for the second
-                convolutional layer.
-            kernel_size: Kernel size for both convolutional layers.
-            pool_size: pool size for MaxPool1d operation after the convolutional
-                layers.
-
-        Remark: The implementation of the convolutional layers was not tested
-        rigourously. While it works for the default configuration parameters it
-        might cause shape conflicts fot badly chosen parameters.
+            input_shape: Dimensionality of input, e.g., (28,) for 1D, (28, 28) for 2D.
+            in_channels: Number of image channels, default 1.
+            out_channels_per_layer: Number of convolutional out_channels for each
+                layer. Must have as many entries as num_conv_layers.
+            num_conv_layers: Number of convolutional layers.
+            num_linear_layers: Number of fully connected layers.
+            num_linear_units: Number of hidden units in fully connected layers.
+            output_dim: Number of output units of the final layer.
+            kernel_size: Kernel size for all convolutional layers.
+            pool_kernel_size: Kernel size for the max pooling operation after each
+                convolutional layer.
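+
+        Example (output shape follows from the default arguments above):
+            >>> embedding = CNNEmbedding(input_shape=(28, 28), in_channels=1)
+            >>> x = torch.rand(10, 1, 28, 28)
+            >>> embedding(x).shape
+            torch.Size([10, 20])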
""" - super().__init__() - self.input_dim = input_dim - self.output_dim = output_dim - self.num_hiddens = num_hiddens - - # construct convolutional-pooling subnet - pool = nn.MaxPool1d(pool_size) - conv_layers = [ - nn.Conv1d(1, out_channels_cnn_1, kernel_size, padding="same"), - nn.ReLU(), - pool, - nn.Conv1d( - out_channels_cnn_1, out_channels_cnn_2, kernel_size, padding="same" - ), - nn.ReLU(), - pool, - ] - self.conv_subnet = nn.Sequential(*conv_layers) + super(CNNEmbedding, self).__init__() - # construct fully connected layers - input_dim_fc = out_channels_cnn_2 * (int(input_dim / out_channels_cnn_2)) + assert isinstance( + input_shape, Tuple + ), "input_shape must be a Tuple of size 1 or 2, e.g., (width, [height])." + assert ( + 0 < len(input_shape) < 3 + ), """input_shape must be a Tuple of size 1 or 2, e.g., + (width, [height]). Number of input channels are passed separately""" - self.fc_subnet = FCEmbedding( - input_dim=input_dim_fc, + use_2d_cnn = len(input_shape) == 2 + conv_module = nn.Conv2d if use_2d_cnn else nn.Conv1d + pool_module = nn.MaxPool2d if use_2d_cnn else nn.MaxPool1d + + assert ( + len(out_channels_per_layer) == num_conv_layers + ), "out_channels needs as many entries as num_cnn_layers." + + # define input shape with channel + self.input_shape = (in_channels, *input_shape) + + # Construct CNN feature extractor. + cnn_layers = [] + cnn_output_size = input_shape + stride = 1 + padding = 1 + for ii in range(num_conv_layers): + # Defining another 2D convolution layer + conv_layer = conv_module( + in_channels=in_channels if ii == 0 else out_channels_per_layer[ii - 1], + out_channels=out_channels_per_layer[ii], + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + pool = pool_module(kernel_size=pool_kernel_size) + cnn_layers += [conv_layer, nn.ReLU(inplace=True), pool] + # Calculate change of output size of each CNN layer + cnn_output_size = get_new_cnn_output_size(cnn_output_size, conv_layer, pool) + + self.cnn_subnet = nn.Sequential(*cnn_layers) + + # Construct linear post processing net. + self.linear_subnet = FCEmbedding( + input_dim=out_channels_per_layer[-1] + * torch.prod(torch.tensor(cnn_output_size)), output_dim=output_dim, - num_layers=num_fully_connected, - num_hiddens=num_hiddens, + num_layers=num_linear_layers, + num_hiddens=num_linear_units, ) + # Defining the forward pass def forward(self, x: Tensor) -> Tensor: - """Network forward pass. - Args: - x: Input tensor (batch_size, input_dim) - Returns: - Network output (batch_size, output_dim). - """ - x = self.conv_subnet(x.unsqueeze(1)) - x = torch.flatten(x, 1) # flatten all dimensions except batch - embedding = self.fc_subnet(x) + batch_size = x.size(0) - return embedding + # reshape to account for single channel data. + x = self.cnn_subnet(x.view(batch_size, *self.input_shape)) + # flatten for linear layers. 
+        x = x.view(batch_size, -1)
+        x = self.linear_subnet(x)
+        return x
 
 
 class PermutationInvariantEmbedding(nn.Module):
diff --git a/tests/embedding_net_test.py b/tests/embedding_net_test.py
index e05cf742e..cc8f92d95 100644
--- a/tests/embedding_net_test.py
+++ b/tests/embedding_net_test.py
@@ -2,12 +2,15 @@
 
 import pytest
 import torch
-
 from torch import eye, ones, zeros
 
 from sbi import utils as utils
 from sbi.inference import SNLE, SNPE, SNRE
-from sbi.neural_nets.embedding_nets import FCEmbedding, PermutationInvariantEmbedding
+from sbi.neural_nets.embedding_nets import (
+    CNNEmbedding,
+    FCEmbedding,
+    PermutationInvariantEmbedding,
+)
 from sbi.simulators.linear_gaussian import (
     linear_gaussian,
     true_posterior_linear_gaussian_mvn_prior,
@@ -179,3 +182,58 @@ def test_iid_inference(num_trials, num_dim, method):
         check_c2st(samples, reference_samples, alg=method + " permuted")
     else:
         check_c2st(samples, reference_samples, alg=method)
+
+
+@pytest.mark.parametrize(
+    "input_shape",
+    [
+        (32,),
+        (32, 32),
+        (32, 64),
+    ],
+)
+@pytest.mark.parametrize("num_channels", (1, 3))
+def test_1d_and_2d_cnn_embedding_net(input_shape, num_channels):
+    from torch.distributions import MultivariateNormal
+
+    from sbi.utils.get_nn_models import posterior_nn
+
+    estimator_provider = posterior_nn(
+        "mdn",
+        embedding_net=CNNEmbedding(
+            input_shape, in_channels=num_channels, output_dim=20
+        ),
+    )
+
+    num_dim = input_shape[0]
+
+    def simulator2d(theta):
+        x = MultivariateNormal(
+            loc=theta, covariance_matrix=0.5 * torch.eye(num_dim)
+        ).sample()
+        return x.unsqueeze(2).repeat(1, 1, input_shape[1])
+
+    def simulator1d(theta):
+        return torch.rand_like(theta) + theta
+
+    simulator = simulator1d if len(input_shape) == 1 else simulator2d
+    xo = torch.ones(1, num_channels, *input_shape).squeeze(1)
+
+    prior = MultivariateNormal(torch.zeros(num_dim), torch.eye(num_dim))
+
+    num_simulations = 1000
+    theta = prior.sample((num_simulations,))
+    x = simulator(theta)
+    # Add a channel dimension and repeat the data across channels if needed.
+    if num_channels > 1:
+        x = x.unsqueeze(1).repeat(
+            1, num_channels, *[1 for _ in range(len(input_shape))]
+        )
+
+    trainer = SNPE(prior=prior, density_estimator=estimator_provider)
+    trainer.append_simulations(theta, x).train(max_num_epochs=2)
+    posterior = trainer.build_posterior()
+
+    posterior.sample((10,), x=xo)
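
For reference, a minimal usage sketch of the new embedding net (not part of the
diff above; the expected output shape follows from the default settings
out_channels_per_layer=[6, 12], kernel_size=5, pool_kernel_size=2):

import torch
from sbi.neural_nets.embedding_nets import CNNEmbedding

# Three-channel 2D data, e.g., small RGB images.
embedding = CNNEmbedding(input_shape=(32, 32), in_channels=3, output_dim=20)

x = torch.rand(8, 3, 32, 32)  # (batch, channels, height, width)
features = embedding(x)
print(features.shape)  # torch.Size([8, 20])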