diff --git a/sbi/neural_nets/embedding_nets.py b/sbi/neural_nets/embedding_nets.py
index a702f8864..dfd3e823f 100644
--- a/sbi/neural_nets/embedding_nets.py
+++ b/sbi/neural_nets/embedding_nets.py
@@ -1,7 +1,8 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.
 
-import torch
+from typing import List, Tuple, Union
+
 from torch import Tensor, nn
 
 
@@ -44,81 +45,140 @@ def forward(self, x: Tensor) -> Tensor:
         return self.net(x)
 
 
-class CNNEmbedding(nn.Module):
+def get_new_cnn_output_size(
+    input_dim: int,
+    conv_layer: Union[nn.Conv1d, nn.Conv2d],
+    pool: Union[nn.MaxPool1d, nn.MaxPool2d],
+) -> int:
+    """Returns new output size after applying a given convolution and pooling.
+
+    Assumes square input dimensions of the data and of the applied kernels, i.e.,
+    input_dim refers to data of shape (input_dim, input_dim) and all convolutions
+    use square kernel sizes.
+
+    Uses formulas from https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html.
+
+    Args:
+        input_dim: input size, e.g., output size of the previous layer.
+        conv_layer: applied convolutional layer.
+        pool: applied pooling layer.
+
+    Returns:
+        new output dimension of the cnn layer.
+
+    """
+    dim_after_conv = (
+        input_dim
+        + 2 * conv_layer.padding[0]
+        - conv_layer.dilation[0] * (conv_layer.kernel_size[0] - 1)
+        - 1
+    ) / conv_layer.stride[0] + 1
+    dim_after_pool = (
+        dim_after_conv + 2 * pool.padding - pool.dilation * (pool.kernel_size - 1) - 1
+    ) / pool.stride + 1
+    return int(dim_after_pool)
+
+
+class CNNEmbedding2D(nn.Module):
     def __init__(
         self,
-        input_dim: int,
+        input_shape: Tuple,
+        in_channels: int = 1,
+        out_channels_per_layer: List = [6, 12],
+        num_conv_layers: int = 2,
+        num_linear_layers: int = 2,
+        num_linear_units: int = 50,
         output_dim: int = 20,
-        num_fully_connected: int = 2,
-        num_hiddens: int = 120,
-        out_channels_cnn_1: int = 10,
-        out_channels_cnn_2: int = 16,
         kernel_size: int = 5,
-        pool_size=4,
+        pool_kernel_size: int = 2,
     ):
-        """Multi-layer (C)NN
-        First two layers are convolutional, followed by fully connected layers.
-        Performing 1d convolution and max pooling with preset configs.
+        """Convolutional embedding network.
+        Convolutional layers are applied first, followed by fully connected layers.
+
+        Automatically infers whether to apply 1D or 2D convolution depending on
+        input_shape.
+        Allows usage of multiple (color) channels by passing in_channels > 1.
 
         Args:
-            input_dim: Dimensionality of input.
-            output_dim: Dimensionality of the output.
-            num_conv: Number of convolutional layers.
-            num_fully_connected: Number fully connected layer, minimum of 2.
-            num_hiddens: Number of hidden dimensions in fully-connected layers.
-            out_channels_cnn_1: Number of oputput channels for the first convolutional
-                layer.
-            out_channels_cnn_2: Number of oputput channels for the second
-                convolutional layer.
+            input_shape: Dimensionality of input, e.g., (28,) for 1D, (28, 28) for 2D.
+            in_channels: Number of image channels, default 1.
+            out_channels_per_layer: Number of out_channels for each convolutional
+                layer. Must have as many entries as num_conv_layers.
+            num_conv_layers: Number of convolutional layers.
+            num_linear_layers: Number of fully connected layers.
+            num_linear_units: Number of hidden units in fully-connected layers.
+            output_dim: Number of output units of the final layer.
-            kernel_size: Kernel size for both convolutional layers.
-            pool_size: pool size for MaxPool1d operation after the convolutional
-                layers.
-
-        Remark: The implementation of the convolutional layers was not tested
-            rigourously. While it works for the default configuration parameters it
-            might cause shape conflicts fot badly chosen parameters.
+            kernel_size: Kernel size for all convolutional layers.
+            pool_kernel_size: Kernel size for the maxpool layer applied after
+                each convolutional layer.
         """
         super().__init__()
-        self.input_dim = input_dim
-        self.output_dim = output_dim
-        self.num_hiddens = num_hiddens
-
-        # construct convolutional-pooling subnet
-        pool = nn.MaxPool1d(pool_size)
-        conv_layers = [
-            nn.Conv1d(1, out_channels_cnn_1, kernel_size, padding="same"),
-            nn.ReLU(),
-            pool,
-            nn.Conv1d(
-                out_channels_cnn_1, out_channels_cnn_2, kernel_size, padding="same"
-            ),
-            nn.ReLU(),
-            pool,
-        ]
-        self.conv_subnet = nn.Sequential(*conv_layers)
-
-        # construct fully connected layers
-        input_dim_fc = out_channels_cnn_2 * (int(input_dim / out_channels_cnn_2))
-
-        self.fc_subnet = FCEmbedding(
-            input_dim=input_dim_fc,
+
+        assert isinstance(
+            input_shape, tuple
+        ), "input_shape must be a Tuple of size 1 or 2, e.g., (width, [height])."
+        assert (
+            0 < len(input_shape) < 3
+        ), """input_shape must be a Tuple of size 1 or 2, e.g.,
+            (width, [height]). The number of input channels is passed separately."""
+
+        use_2d_cnn = len(input_shape) == 2
+        if use_2d_cnn:
+            conv_module = nn.Conv2d
+            pool_module = nn.MaxPool2d
+            assert (
+                input_shape[0] == input_shape[1]
+            ), "input_shape must be square, e.g., (32, 32)."
+        else:
+            conv_module = nn.Conv1d
+            pool_module = nn.MaxPool1d
+        assert (
+            len(out_channels_per_layer) == num_conv_layers
+        ), "out_channels_per_layer must have as many entries as num_conv_layers."
+
+        # define input shape with channel
+        self.input_shape = (in_channels, *input_shape)
+
+        # Construct CNN feature extractor.
+        cnn_layers = []
+        cnn_output_size = self.input_shape[1]
+        stride = 1
+        padding = 1
+        for ii in range(num_conv_layers):
+            # Define a convolution layer (1D or 2D, selected via conv_module above).
+            conv_layer = conv_module(
+                in_channels=in_channels if ii == 0 else out_channels_per_layer[ii - 1],
+                out_channels=out_channels_per_layer[ii],
+                kernel_size=kernel_size,
+                stride=stride,
+                padding=padding,
+            )
+            pool = pool_module(kernel_size=pool_kernel_size)
+            cnn_layers += [conv_layer, nn.ReLU(inplace=True), pool]
+            # Track the output size across each conv-pool block.
+            cnn_output_size = get_new_cnn_output_size(cnn_output_size, conv_layer, pool)
+
+        self.cnn_subnet = nn.Sequential(*cnn_layers)
+
+        # Construct fully connected post-processing net.
+        self.linear_subnet = FCEmbedding(
+            input_dim=out_channels_per_layer[-1] * cnn_output_size * cnn_output_size
+            if use_2d_cnn
+            else out_channels_per_layer[-1] * cnn_output_size,
             output_dim=output_dim,
-            num_layers=num_fully_connected,
-            num_hiddens=num_hiddens,
+            num_layers=num_linear_layers,
+            num_hiddens=num_linear_units,
         )
 
     def forward(self, x: Tensor) -> Tensor:
-        """Network forward pass.
-        Args:
-            x: Input tensor (batch_size, input_dim)
-        Returns:
-            Network output (batch_size, output_dim).
-        """
-        x = self.conv_subnet(x.unsqueeze(1))
-        x = torch.flatten(x, 1)  # flatten all dimensions except batch
-        embedding = self.fc_subnet(x)
+        batch_size = x.size(0)
 
-        return embedding
+        # reshape to account for single channel data.
+        x = self.cnn_subnet(x.view(batch_size, *self.input_shape))
+        # flatten for linear layers.
+        x = x.view(batch_size, -1)
+        x = self.linear_subnet(x)
+        return x
 
 
 class PermutationInvariantEmbedding(nn.Module):
diff --git a/tests/embedding_net_test.py b/tests/embedding_net_test.py
index e05cf742e..fc8b35530 100644
--- a/tests/embedding_net_test.py
+++ b/tests/embedding_net_test.py
@@ -2,12 +2,16 @@
 
 import pytest
 import torch
-
 from torch import eye, ones, zeros
+from torch.distributions import MultivariateNormal
 
 from sbi import utils as utils
 from sbi.inference import SNLE, SNPE, SNRE
-from sbi.neural_nets.embedding_nets import FCEmbedding, PermutationInvariantEmbedding
+from sbi.neural_nets.embedding_nets import (
+    CNNEmbedding2D,
+    FCEmbedding,
+    PermutationInvariantEmbedding,
+)
 from sbi.simulators.linear_gaussian import (
     linear_gaussian,
     true_posterior_linear_gaussian_mvn_prior,
@@ -179,3 +183,52 @@ def test_iid_inference(num_trials, num_dim, method):
         check_c2st(samples, reference_samples, alg=method + " permuted")
     else:
         check_c2st(samples, reference_samples, alg=method)
+
+
+@pytest.mark.parametrize(
+    "input_shape",
+    [(32,), (32, 32)],
+)
+@pytest.mark.parametrize("num_channels", (1, 3))
+def test_1d_and_2d_cnn_embedding_net(input_shape, num_channels):
+    """Test CNN embedding net with SNPE on 1D and 2D data with multiple channels."""
+    estimator_provider = utils.posterior_nn(
+        "mdn",
+        embedding_net=CNNEmbedding2D(
+            input_shape, in_channels=num_channels, output_dim=20
+        ),
+    )
+
+    num_dim = input_shape[0]
+
+    def simulator2d(theta):
+        theta2d = theta.unsqueeze(2).tile(1, 1, theta.shape[1])
+        return MultivariateNormal(
+            loc=theta2d, covariance_matrix=0.5 * torch.eye(num_dim)
+        ).sample()
+
+    def simulator1d(theta):
+        return torch.rand_like(theta) + theta
+
+    if len(input_shape) == 1:
+        simulator = simulator1d
+        xo = torch.ones(1, num_channels, num_dim).squeeze(1)
+    else:
+        simulator = simulator2d
+        xo = torch.ones(1, num_channels, num_dim, num_dim).squeeze(1)
+
+    prior = MultivariateNormal(torch.zeros(num_dim), torch.eye(num_dim))
+
+    num_simulations = 1000
+    theta = prior.sample((num_simulations,))
+    x = simulator(theta)
+    if num_channels > 1:
+        x = x.unsqueeze(1).repeat(
+            1, num_channels, *[1 for _ in range(len(input_shape))]
+        )
+
+    trainer = SNPE(prior=prior, density_estimator=estimator_provider)
+    trainer.append_simulations(theta, x).train(max_num_epochs=2)
+    posterior = trainer.build_posterior()
+
+    posterior.sample((10,), x=xo)
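The size bookkeeping in get_new_cnn_output_size is the standard Conv/MaxPool output formula applied twice. A quick sanity check against the class defaults (kernel_size=5, stride=1, padding=1, pool_kernel_size=2), assuming this branch is installed:

```python
import torch
from torch import nn

from sbi.neural_nets.embedding_nets import get_new_cnn_output_size

# Defaults used by CNNEmbedding2D: kernel 5, stride 1, padding 1, pool kernel 2.
conv = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, stride=1, padding=1)
pool = nn.MaxPool2d(kernel_size=2)

# Conv: (32 + 2*1 - 1*(5-1) - 1) / 1 + 1 = 30; pool: (30 - 1*(2-1) - 1) / 2 + 1 = 15.
print(get_new_cnn_output_size(32, conv, pool))  # 15

# Matches the spatial size torch actually produces.
print(pool(conv(torch.randn(1, 3, 32, 32))).shape)  # torch.Size([1, 6, 15, 15])
```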
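For reviewers, a minimal usage sketch of the new embedding net outside of an SNPE run (input shapes mirror the test above; with the defaults, two conv-pool blocks shrink a 32-unit input to 15 and then 6 before the fully connected layers):

```python
import torch

from sbi.neural_nets.embedding_nets import CNNEmbedding2D

# 2D data with 3 channels, e.g., 32x32 color images.
embedding_2d = CNNEmbedding2D((32, 32), in_channels=3, output_dim=20)
x_2d = torch.randn(10, 3, 32, 32)
print(embedding_2d(x_2d).shape)  # torch.Size([10, 20])

# 1D data with a single channel, e.g., time series of length 32.
embedding_1d = CNNEmbedding2D((32,), output_dim=20)
x_1d = torch.randn(10, 32)
print(embedding_1d(x_1d).shape)  # torch.Size([10, 20])
```

Single-channel inputs may be passed without an explicit channel dimension, since forward reshapes to (batch, in_channels, width[, height]) internally.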