diff --git a/setup.py b/setup.py
index e105d53..6177a69 100644
--- a/setup.py
+++ b/setup.py
@@ -28,9 +28,6 @@
         'torchsparse/src/interpolation/devox_deterministic.cpp',
         'torchsparse/src/interpolation/devox_deterministic_gpu.cu',
         'torchsparse/src/interpolation/devox_cpu.cpp',
-        'torchsparse/src/others/convert_neighbor_map.cpp',
-        'torchsparse/src/others/convert_neighbor_map_gpu.cu',
-        'torchsparse/src/others/convert_neighbor_map_cpu.cpp',
         'torchsparse/src/others/count.cpp',
         'torchsparse/src/others/count_gpu.cu',
         'torchsparse/src/others/count_cpu.cpp',
@@ -44,7 +41,6 @@
         'torchsparse/src/hash/hash_cpu.cpp',
         'torchsparse/src/hashmap/hashmap_cpu.cpp',
         'torchsparse/src/interpolation/devox_cpu.cpp',
-        'torchsparse/src/others/convert_neighbor_map_cpu.cpp',
         'torchsparse/src/others/insertion_cpu.cpp',
         'torchsparse/src/others/query_cpu.cpp',
         'torchsparse/src/others/count_cpu.cpp'
diff --git a/torchsparse/nn/functional/__init__.py b/torchsparse/nn/functional/__init__.py
index fd3baaa..94e6d43 100644
--- a/torchsparse/nn/functional/__init__.py
+++ b/torchsparse/nn/functional/__init__.py
@@ -1,6 +1,6 @@
 from .activation import *
 from .conv import *
-from .convert_neighbor_map import *
+from .squeeze_nmap import *
 from .count import *
 from .crop import *
 from .devox import *
diff --git a/torchsparse/nn/functional/activation.py b/torchsparse/nn/functional/activation.py
index 94d9e20..5f6aa64 100644
--- a/torchsparse/nn/functional/activation.py
+++ b/torchsparse/nn/functional/activation.py
@@ -1,7 +1,7 @@
 import functools
 
 from torch.nn import functional as F
-from torchsparse.sparse_tensor import *
+from torchsparse.sparse_tensor import SparseTensor
 
 __all__ = ['spact', 'sprelu', 'spleaky_relu']
 
diff --git a/torchsparse/nn/functional/conv.py b/torchsparse/nn/functional/conv.py
index ac77136..807efc0 100644
--- a/torchsparse/nn/functional/conv.py
+++ b/torchsparse/nn/functional/conv.py
@@ -4,12 +4,12 @@
 import torchsparse_backend
 from torch.autograd import Function
 from torch.cuda.amp import custom_fwd, custom_bwd
 
-from torchsparse import *
-from torchsparse.nn.functional.convert_neighbor_map import *
-from torchsparse.nn.functional.downsample import *
-from torchsparse.nn.functional.hash import *
-from torchsparse.nn.functional.query import *
-from torchsparse.utils.kernel_region import *
+from torchsparse import SparseTensor
+from torchsparse.nn import functional as spF
+from torchsparse.utils.helpers import make_tuple
+from torchsparse.utils.kernel import KernelRegion, KernelMapKey
+
+from typing import Union, List, Tuple, Optional
 
 __all__ = ['conv3d']
@@ -70,8 +70,15 @@
     def backward(ctx, grad_out):
         features, kernel, neighbor_map, neighbor_offset, transpose = ctx.for_backwards
         K, c_in, c_out = kernel.size()
         N_in = features.size(0)
-        grad_features = torch.zeros(N_in, c_in, device=features.device, dtype=features.dtype)
-        grad_kernel = torch.zeros(K, c_in, c_out, device=kernel.device, dtype=features.dtype)
+        grad_features = torch.zeros(N_in,
+                                    c_in,
+                                    device=features.device,
+                                    dtype=features.dtype)
+        grad_kernel = torch.zeros(K,
+                                  c_in,
+                                  c_out,
+                                  device=kernel.device,
+                                  dtype=features.dtype)
 
         if 'cuda' in str(features.device):
             torchsparse_backend.sparseconv_backward(features, grad_features,
@@ -87,18 +94,24 @@
 sparseconv_op = SpConvolution.apply
 
 
-def conv3d(inputs,
-           kernel,
-           kernel_size,
-           bias=None,
-           stride=1,
-           dilation=1,
-           transpose=False):
+def conv3d(inputs: SparseTensor,
+           kernel: torch.Tensor,
+           kernel_size: Union[int, List[int], Tuple[int, int, int]],
+           bias: Optional[torch.Tensor] = None,
+           stride: Union[int, List[int], Tuple[int, int, int]] = 1,
+           dilation: Union[int, List[int], Tuple[int, int, int]] = 1,
+           transpose: bool = False) -> SparseTensor:
     features = inputs.F
     coords = inputs.C
     cur_stride = inputs.s
 
-    if kernel_size == 1 and stride == 1 and dilation == 1:
+    # convert to hashable types
+    kernel_size = make_tuple(kernel_size)
+    stride = make_tuple(stride)
+    dilation = make_tuple(dilation)
+
+    if kernel_size == (1, 1, 1) and stride == (1, 1, 1) and dilation == (1, 1,
+                                                                         1):
         output_features = features.matmul(kernel)
         if bias is not None:
             output_features += bias
@@ -107,34 +120,37 @@ def conv3d(inputs,
         output_tensor.kernel_maps = inputs.kernel_maps
         output_tensor.check()
     elif not transpose:
-        kernel_map = inputs.kernel_maps.get(
-            'k%s_os%d_s%d_d%d' % (kernel_size, cur_stride, stride, dilation),
-            None)
+        kernel_map_key = KernelMapKey(kernel_size, cur_stride, stride,
+                                      dilation)
+        kernel_map = inputs.kernel_maps.get(kernel_map_key, None)
 
-        if stride > 1:
+        if any(x > 1 for x in stride):
             # do downsample
             kRegion = KernelRegion(kernel_size=kernel_size,
                                    tensor_stride=cur_stride)
             kOffset = kRegion.get_kernel_offset().to(features.device)
-            new_coords = spdownsample(coords, stride * cur_stride)
-            hash_query = sphash(new_coords, kOffset)
-            hash_target = sphash(coords)
-            idx_query = sphashquery(hash_query, hash_target)
-            idx_query = list(convert_neighbor_map_gpu(idx_query))
+            new_coords = spF.spdownsample(coords, stride, kernel_size,
+                                          cur_stride)
+            hash_query = spF.sphash(new_coords, kOffset)
+            hash_target = spF.sphash(coords)
+            idx_query = spF.sphashquery(hash_query, hash_target)
+            idx_query = list(spF.squeeze_nmap(idx_query))
             idx_query[1] = idx_query[1].to('cpu')
             sizes = (features.shape[0], new_coords.shape[0])
             output_features = sparseconv_op(features, kernel, idx_query[0],
                                             idx_query[1], sizes, transpose)
             if bias is not None:
                 output_features += bias
-            output_tensor = SparseTensor(output_features, new_coords,
-                                         cur_stride * stride)
+            output_tensor = SparseTensor(
+                output_features, new_coords,
+                [a * b for a, b in zip(cur_stride, stride)])
             output_tensor.coord_maps = copy.deepcopy(inputs.coord_maps)
             output_tensor.check()
+
+            kernel_map_key = KernelMapKey(kernel_size, cur_stride, stride,
+                                          dilation)
             output_tensor.kernel_maps = copy.deepcopy(inputs.kernel_maps)
-            output_tensor.kernel_maps['k%s_os%d_s%d_d%d' %
-                                      (kernel_size, cur_stride, stride,
-                                       dilation)] = idx_query + [sizes]
+            output_tensor.kernel_maps[kernel_map_key] = idx_query + [sizes]
 
         else:
             if kernel_map is None:
@@ -144,10 +160,10 @@ def conv3d(inputs,
                     kRegion = KernelRegion(kernel_size=kernel_size,
                                            tensor_stride=cur_stride)
                     kOffset = kRegion.get_kernel_offset().to(features.device)
                 except:
                     raise
-                hash_query = sphash(coords, kOffset)
-                hash_target = sphash(coords)
-                idx_query = sphashquery(hash_query, hash_target)
-                idx_query = list(convert_neighbor_map_gpu(idx_query))
+                hash_query = spF.sphash(coords, kOffset)
+                hash_target = spF.sphash(coords)
+                idx_query = spF.sphashquery(hash_query, hash_target)
+                idx_query = list(spF.squeeze_nmap(idx_query))
                 idx_query[1] = idx_query[1].to('cpu')
                 sizes = (features.shape[0], features.shape[0])
                 output_features = sparseconv_op(features, kernel, idx_query[0],
@@ -159,9 +175,9 @@ def conv3d(inputs,
                 output_tensor.coord_maps = inputs.coord_maps
                 output_tensor.check()
                 output_tensor.kernel_maps = copy.deepcopy(inputs.kernel_maps)
-                output_tensor.kernel_maps['k%s_os%d_s%d_d%d' %
-                                          (kernel_size, cur_stride, stride,
-                                           dilation)] = idx_query + [sizes]
+                kernel_map_key = KernelMapKey(kernel_size, cur_stride, stride,
+                                              dilation)
+                output_tensor.kernel_maps[kernel_map_key] = idx_query + [sizes]
             else:
                 output_features = sparseconv_op(features, kernel,
                                                 kernel_map[0], kernel_map[1],
@@ -176,17 +192,24 @@ def conv3d(inputs,
     else:
         # do upsample
-        original_stride = int(cur_stride / stride)
-        kernel_map = inputs.kernel_maps.get(
-            'k%s_os%d_s%d_d%d' %
-            (kernel_size, original_stride, stride, dilation), None)
+
+        original_stride = tuple(
+            [int(a / b) for a, b in zip(cur_stride, stride)])
+
+        kernel_map_key = KernelMapKey(kernel_size, original_stride, stride,
+                                      dilation)
+        kernel_map = inputs.kernel_maps.get(kernel_map_key, None)
+        assert kernel_map is not None, f'{kernel_map_key} does not exist.'
 
         output_features = sparseconv_op(features, kernel, kernel_map[0],
                                         kernel_map[1], kernel_map[2],
                                         transpose)
         if bias is not None:
             output_features += bias
-        output_tensor = SparseTensor(output_features,
-                                     inputs.coord_maps[original_stride],
+
+        cur_coords = inputs.coord_maps.get(original_stride, None)
+        assert cur_coords is not None, f'{original_stride} not in coord maps.'
+
+        output_tensor = SparseTensor(output_features, cur_coords,
                                      original_stride)
         output_tensor.coord_maps = inputs.coord_maps
         output_tensor.check()
diff --git a/torchsparse/nn/functional/convert_neighbor_map.py b/torchsparse/nn/functional/convert_neighbor_map.py
deleted file mode 100644
index 8389e17..0000000
--- a/torchsparse/nn/functional/convert_neighbor_map.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import torch
-import torchsparse_backend
-from torch.autograd import Function
-
-
-class ConvertNeighborMap(Function):
-    @staticmethod
-    def forward(ctx, neighbor_map):
-        idx_batch, idx_point = torch.where(neighbor_map != -1)
-        if 'cuda' in str(neighbor_map.device):
-            map_converted = torchsparse_backend.convert_map_forward(
-                neighbor_map.int(), idx_batch.int(), idx_point.int())
-        elif 'cpu' in str(neighbor_map.device):
-            map_converted = torchsparse_backend.cpu_convert_map_forward(
-                neighbor_map.int(), idx_batch.int(), idx_point.int())
-        else:
-            device = neighbor_map.device
-            map_converted = torchsparse_backend.cpu_convert_map_forward(
-                neighbor_map.int().cpu(),
-                idx_batch.int().cpu(),
-                idx_point.int().cpu())
-            map_converted = map_converted.to(device)
-        nmap_offset = torch.sum(neighbor_map != -1, 1)
-        return map_converted.int().contiguous(), nmap_offset.int().contiguous()
-
-
-convert_neighbor_map_gpu = ConvertNeighborMap.apply
diff --git a/torchsparse/nn/functional/downsample.py b/torchsparse/nn/functional/downsample.py
index fb35f81..7d7c31b 100644
--- a/torchsparse/nn/functional/downsample.py
+++ b/torchsparse/nn/functional/downsample.py
@@ -2,48 +2,69 @@
 import torchsparse_backend
 from torch.autograd import Function
 
 from torchsparse.nn.functional.hash import *
-from torchsparse.nn.functional.voxelize import spvoxelize
+from torchsparse.utils.kernel import KernelRegion
+from typing import Tuple, List, Union
 
 __all__ = ['spdownsample']
 
 
-class DownsampleGPU(Function):
-    @staticmethod
-    def forward(ctx, coords, ratio):
-        coords_float = coords[:, :3].float()
-        # following Minkowski engine
-        coords_new = torch.floor(torch.floor(coords_float / ratio) *
-                                 ratio).int()
-        coords_new = torch.cat([coords_new, coords[:, 3].view(-1, 1)], 1)
-        coords_new_hash = sphash(coords_new)
-        uq, inv, cnt = torch.unique(coords_new_hash,
-                                    return_inverse=True,
-                                    return_counts=True)
-        inv = inv.int()
-        cnt = cnt.int()
-        # rounding is necessary
-        # gpu
-        if 'cuda' in str(coords.device):
-            uq_coords = torch.round(spvoxelize(coords_new.float(), inv,
-                                               cnt))
-        elif 'cpu' in str(coords.device):
-            uq_coords = torch.round(
-                torchsparse_backend.cpu_insertion_forward(
-                    coords_new.float(), inv, cnt))
-        else:
-            device = coords.device
-            uq_coords = torch.round(
-                torchsparse_backend.cpu_insertion_forward(
-                    coords_new.float().cpu(), inv.cpu(), cnt.cpu()))
-            uq_coords = uq_coords.to(device)
-        uq_coords = uq_coords.int()
-
-        # Notice: corrds_new_hash cannot be directly used
-        return uq_coords  #, coords_new_hash
-
+def spdownsample(
+        coords: torch.Tensor,
+        ratio: Union[int, List[int], Tuple[int, int, int]] = 2,
+        kernel_size: Union[int, List[int], Tuple[int, int, int]] = 2,
+        tensor_stride: Union[int, List[int], Tuple[int, int, int]] = 1
+) -> torch.Tensor:
 
-downsample_gpu = DownsampleGPU.apply
+    if not isinstance(ratio, int):
+        ratio = torch.IntTensor(ratio).to(coords.device).unsqueeze(0)
+    if not isinstance(tensor_stride, int):
+        tensor_stride = torch.IntTensor(tensor_stride).to(
+            coords.device).unsqueeze(0)
 
+    if isinstance(kernel_size, int) and isinstance(ratio, int):
+        direct_downsample = kernel_size == ratio
+    else:
+        if isinstance(kernel_size, int):
+            # ratio is a permutation of [1, 1, kernel_size]
+            direct_downsample = (kernel_size == ratio.prod().item()) & \
+                (torch.sum(ratio == kernel_size) == 1).item()
+        else:
+            direct_downsample = False
 
-def spdownsample(coords, ratio):
-    return downsample_gpu(coords, ratio)
+    if direct_downsample:
+        _ratio = ratio * tensor_stride
+        new_coords = torch.cat(
+            [coords[:, :3] // _ratio * _ratio, coords[:, 3:]], 1)
+        return torch.unique(new_coords, dim=0)
+    else:
+        kernel_region = KernelRegion(kernel_size, tensor_stride, dilation=1)
+        # kernel volume x 3
+        kernel_offset = kernel_region.get_kernel_offset().to(coords.device)
+        new_coords = coords[:, :3].unsqueeze(1).repeat(
+            1, kernel_offset.size(0), 1) + kernel_offset
+        # (N x kernel volume) x 4
+        new_coords = torch.cat([
+            coords[:, 3:].repeat(1, kernel_offset.size(0)).view(-1, 1),
+            new_coords.view(-1, 3)
+        ],
+                               dim=1)
+        new_ts = tensor_stride * ratio
+        # only keep those coordinates that are multiples of new_ts
+        if isinstance(new_ts, torch.Tensor):
+            new_ts = new_ts[0]
+            new_coords = new_coords[
+                (new_coords[:, 1] % new_ts[0].item() == 0) &
+                (new_coords[:, 2] % new_ts[1].item() == 0) &
+                (new_coords[:, 3] % new_ts[2].item() == 0)]
+        else:
+            new_coords = new_coords[(new_coords[:, 1] % new_ts == 0)
+                                    & (new_coords[:, 2] % new_ts == 0)
+                                    & (new_coords[:, 3] % new_ts == 0)]
+        new_coords = new_coords[(new_coords[:, 1] >= 0)
+                                & (new_coords[:, 2] >= 0) &
+                                (new_coords[:, 3] >= 0)]
+        # filter out duplicates
+        new_coords = torch.unique(new_coords, dim=0)
+        new_coords = new_coords[:, [1, 2, 3, 0]]
+        return new_coords
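For reference, a minimal sketch of the reworked `spdownsample` contract on toy data (hypothetical coordinates; assumes the compiled `torchsparse_backend` extension is importable so `torchsparse.nn.functional` can load). When `kernel_size == ratio`, the direct path floors each spatial coordinate to a multiple of `ratio * tensor_stride` and deduplicates rows:

```python
import torch
from torchsparse.nn.functional import spdownsample

# [x, y, z, batch] coordinates of three points, all in batch 0 (toy data).
coords = torch.IntTensor([[0, 0, 0, 0],
                          [1, 0, 1, 0],
                          [2, 3, 2, 0]])

# kernel_size == ratio == 2 takes the direct path:
# coords[:, :3] // 2 * 2, followed by torch.unique over rows.
out = spdownsample(coords, ratio=2, kernel_size=2, tensor_stride=1)
print(out)  # tensor([[0, 0, 0, 0], [2, 2, 2, 0]], dtype=torch.int32)
```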
diff --git a/torchsparse/nn/functional/squeeze_nmap.py b/torchsparse/nn/functional/squeeze_nmap.py
new file mode 100644
index 0000000..2bd3f17
--- /dev/null
+++ b/torchsparse/nn/functional/squeeze_nmap.py
@@ -0,0 +1,14 @@
+import torch
+from typing import Tuple
+
+__all__ = ['squeeze_nmap']
+
+
+def squeeze_nmap(
+        neighbor_map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    idx_batch, idx_point = torch.where(neighbor_map != -1)
+    map_converted = neighbor_map.view(-1)[idx_batch * neighbor_map.size(1) +
+                                          idx_point]
+    map_converted = torch.stack([map_converted, idx_point], dim=1)
+    nmap_offset = torch.sum(neighbor_map != -1, 1)
+    return map_converted.int().contiguous(), nmap_offset.int().contiguous()
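`squeeze_nmap` replaces the deleted `convert_neighbor_map` extension op with pure PyTorch indexing. A sketch of its behavior on a hypothetical padded neighbor map, where each row corresponds to one kernel offset and `-1` marks "no neighbor":

```python
import torch
from torchsparse.nn.functional import squeeze_nmap

nmap = torch.tensor([[0, -1, 2],
                     [-1, 1, -1]])
pairs, offsets = squeeze_nmap(nmap)
print(pairs)    # tensor([[0, 0], [2, 2], [1, 1]], dtype=torch.int32)
print(offsets)  # tensor([2, 1], dtype=torch.int32)
```

Each row of `pairs` is a (map entry, column index) pair for one valid match, and `offsets` counts the valid matches per kernel offset; `conv3d` feeds these to `sparseconv_op` as the neighbor map and neighbor offset.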
diff --git a/torchsparse/nn/modules/activation.py b/torchsparse/nn/modules/activation.py
index 70dadc0..cdce13d 100644
--- a/torchsparse/nn/modules/activation.py
+++ b/torchsparse/nn/modules/activation.py
@@ -1,9 +1,9 @@
 import functools
 
 from torch import nn
-from torchsparse.sparse_tensor import *
+from torchsparse.sparse_tensor import SparseTensor
 
-from ..functional import *
+from torchsparse.nn import functional as spF
 
 __all__ = ['ReLU', 'LeakyReLU']
@@ -11,7 +11,7 @@ class Activation(nn.Module):
     def __init__(self, inplace: bool = True) -> None:
         super().__init__()
-        self.activation = spact
+        self.activation = spF.spact
         self.inplace = inplace
 
     def forward(self, inputs):
@@ -21,7 +21,7 @@ class ReLU(Activation):
     def __init__(self, inplace: bool = True) -> None:
         super().__init__()
-        self.activation = functools.partial(sprelu, inplace=inplace)
+        self.activation = functools.partial(spF.sprelu, inplace=inplace)
 
     def __repr__(self):
         if self.inplace:
@@ -35,7 +35,7 @@ def __init__(self,
                  negative_slope: float = 0.1,
                  inplace: bool = True) -> None:
         super().__init__()
-        self.activation = functools.partial(spleaky_relu,
+        self.activation = functools.partial(spF.spleaky_relu,
                                             negative_slope=negative_slope,
                                             inplace=inplace)
         self.negative_slope = negative_slope
diff --git a/torchsparse/nn/modules/conv.py b/torchsparse/nn/modules/conv.py
index 011a813..3dcb2af 100644
--- a/torchsparse/nn/modules/conv.py
+++ b/torchsparse/nn/modules/conv.py
@@ -2,12 +2,15 @@
 import torch
 from torch import nn
 
-from torchsparse.sparse_tensor import *
+from torchsparse.sparse_tensor import SparseTensor
+from torchsparse.nn import functional as spF
+from torchsparse.utils.helpers import make_tuple
 
-from ..functional import *
+from typing import Union, List, Tuple
 
 __all__ = [
-    'Conv3d', 'ToBEVConvolution', 'ToBEVReduction', 'ToDenseBEVConvolution'
+    'Conv3d', 'ToBEVConvolution', 'ToBEVReduction', 'ToDenseBEVConvolution',
+    'ToBEVHeightCompression'
 ]
@@ -15,17 +18,24 @@ class Conv3d(nn.Module):
     def __init__(self,
                  in_channels: int,
                  out_channels: int,
-                 kernel_size: int = 3,
-                 stride: int = 1,
+                 kernel_size: Union[int, List[int], Tuple[int, int, int]] = 3,
+                 stride: Union[int, List[int], Tuple[int, int, int]] = 1,
                  dilation: int = 1,
                  bias: bool = False,
                  transpose: bool = False) -> None:
         super().__init__()
         self.in_channels = inc = in_channels
         self.out_channels = outc = out_channels
-        self.kernel_size = kernel_size
-        self.stride = stride
+        if isinstance(kernel_size, list):
+            self.kernel_size = tuple(kernel_size)
+        else:
+            self.kernel_size = kernel_size
+        if isinstance(stride, list):
+            self.stride = tuple(stride)
+        else:
+            self.stride = stride
         self.dilation = dilation
+
         if not isinstance(kernel_size, (list, tuple)):
             self.kernel_volume = self.kernel_size ** 3
             self.kernel = nn.Parameter(
@@ -52,11 +62,11 @@ def __repr__(self):
         if not self.t:
-            return 'Conv3d(in_channels=%d, out_channels=%d, kernel_size=%d, stride=%d, dilation=%d)' % (
+            return 'Conv3d(in_channels={}, out_channels={}, kernel_size={}, stride={}, dilation={})'.format(
                 self.in_channels, self.out_channels, self.kernel_size,
                 self.stride, self.dilation)
         else:
-            return 'Conv3d(in_channels=%d, out_channels=%d, kernel_size=%d, stride=%d, dilation=%d)' % (
+            return 'Conv3dTranspose(in_channels={}, out_channels={}, kernel_size={}, stride={}, dilation={})'.format(
                 self.in_channels, self.out_channels, self.kernel_size,
                 self.stride, self.dilation)
@@ -68,14 +78,14 @@ def reset_parameters(self):
         if self.bias is not None:
             self.bias.data.uniform_(-std, std)
 
-    def forward(self, inputs):
-        return conv3d(inputs,
-                      self.kernel,
-                      kernel_size=self.kernel_size,
-                      bias=self.bias,
-                      stride=self.stride,
-                      dilation=self.dilation,
-                      transpose=self.t)
+    def forward(self, inputs: SparseTensor) -> SparseTensor:
+        return spF.conv3d(inputs,
+                          self.kernel,
+                          kernel_size=self.kernel_size,
+                          bias=self.bias,
+                          stride=self.stride,
+                          dilation=self.dilation,
+                          transpose=self.t)
 
 
 class ToBEVReduction(nn.Module):
@@ -83,10 +93,10 @@ def __init__(self, dim: int = 1) -> None:
         super().__init__()
         self.dim = dim
 
-    def __repr__(self):
-        return 'ToBEVReduction(dim = %d)' % self.dim
+    def extra_repr(self):
+        return 'dim = {}'.format(self.dim)
 
-    def forward(self, inputs: SparseTensor):
+    def forward(self, inputs: SparseTensor) -> SparseTensor:
         coords, feats, stride = inputs.C, inputs.F, inputs.s
 
         coords = coords.clone()
@@ -99,7 +109,86 @@ def forward(self, inputs: SparseTensor):
         return SparseTensor(coords=coords, feats=feats, stride=stride)
 
 
+class ToDenseBEVConvolution(nn.Module):
+    """
+
+    Converts a torchsparse.SparseTensor to a dense BEV feature map.
+    Points with the same z value are grouped together and transformed by
+    the same FC kernel; the results are aggregated by summing up all
+    features that fall into the same BEV grid.
+
+    in_channels: input channels
+    out_channels: output channels
+    shape: shape of the BEV map.
+    dim: dimension index for z (default: 1 for KITTI coords).
+    bias: whether to use bias.
+
+    Warning: usually consumes more memory than ToBEVHeightCompression.
+
+    """
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 shape: Union[List[int], Tuple[int, int, int], torch.Tensor],
+                 offset: List[int] = [0, 0, 0],
+                 dim: int = 1,
+                 bias: bool = False) -> None:
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.register_buffer('offset', torch.IntTensor([list(offset) + [0]]))
+        if isinstance(shape, torch.Tensor):
+            self.register_buffer('shape', shape.int())
+        else:
+            self.register_buffer('shape', torch.IntTensor(shape))
+        self.dim = dim
+        self.n_kernels = int(self.shape[self.dim])
+        self.bev_dims = [i for i in range(3) if i != self.dim]
+        self.bev_shape = self.shape[self.bev_dims]
+        self.kernel = nn.Parameter(
+            torch.zeros(self.n_kernels, in_channels, out_channels))
+        self.bias = nn.Parameter(torch.zeros(1, out_channels)) if bias else 0
+        self.reset_parameters()
+
+    def extra_repr(self):
+        return 'in_channels={}, out_channels={}, n_kernels={}'.format(
+            self.in_channels, self.out_channels, self.n_kernels)
+
+    def reset_parameters(self):
+        std = 1. / math.sqrt(self.in_channels)
+        self.kernel.data.uniform_(-std, std)
+
+    def forward(self, inputs: SparseTensor) -> torch.Tensor:
+        coords, feats, stride = inputs.C, inputs.F, inputs.s
+        if isinstance(stride, tuple):
+            stride = torch.Tensor(stride).unsqueeze(0).to(feats)[:, self.dim]
+
+        kernel = torch.index_select(self.kernel, 0,
+                                    (coords[:, self.dim] // stride).long())
+        feats = (feats.unsqueeze(-1) * kernel).sum(1) + self.bias
+        coords = (coords - self.offset).t()[[3] + self.bev_dims].long()
+        coords[1:] = (coords[1:] // stride).long()
+        indices = coords[0] * int(self.bev_shape.prod()) + coords[1] * int(
+            self.bev_shape[1]) + coords[2]
+        batch_size = coords[0].max().item() + 1
+        outputs = torch.sparse_coo_tensor(
+            indices.unsqueeze(0),
+            feats,
+            torch.Size(
+                [batch_size * int(self.bev_shape.prod()),
+                 feats.size(-1)]),
+        ).to_dense()
+        outputs = outputs.view(batch_size, *self.bev_shape, -1)
+        outputs = outputs.permute(0, 3, 1, 2).contiguous()
+        return outputs
+
+
 class ToBEVConvolution(nn.Module):
+    """
+
+    Sparse version of ToDenseBEVConvolution.
+
+    """
     def __init__(self,
                  in_channels: int,
                  out_channels: int,
@@ -122,16 +211,18 @@ def reset_parameters(self):
         std = 1. / math.sqrt(self.in_channels)
         self.kernel.data.uniform_(-std, std)
 
-    def __repr__(self):
-        return 'ToBEVConvolution(in_channels=%d, out_channels=%d, n_kernels=%d, stride=%d)' % (
+    def extra_repr(self):
+        return 'in_channels={}, out_channels={}, n_kernels={}, stride={}'.format(
             self.in_channels, self.out_channels, self.n_kernels, self.stride)
 
-    def forward(self, inputs):
+    def forward(self, inputs: SparseTensor) -> torch.Tensor:
         coords, feats, stride = inputs.C, inputs.F, inputs.s
-        ratio = stride * self.stride
+        ratio = tuple(s * self.stride for s in stride)
+        if isinstance(stride, tuple):
+            stride = torch.Tensor(stride).unsqueeze(0).to(feats)[:, self.dim]
 
         kernels = torch.index_select(self.kernel, 0,
-                                     coords[:, self.dim].long() / stride)
+                                     (coords[:, self.dim] // stride).long())
         feats = (feats.unsqueeze(-1) * kernels).sum(1) + self.bias
         coords = coords.t().long()
         coords[self.dim, :] = 0
@@ -143,53 +234,68 @@ def forward(self, inputs):
                             flatten.indices().t().int(), ratio)
 
 
-class ToDenseBEVConvolution(nn.Module):
+class ToBEVHeightCompression(nn.Module):
+    """
+
+    Converts a torchsparse.SparseTensor to a dense volumetric tensor and
+    then flattens the z dimension into the channel dimension.
+    E.g., an input of shape [N, C] (assuming batch_size = 1) with spatial
+    size [128, 2, 128] yields an output of shape [1, 2C, 128, 128].
+
+    channels: input channels
+    (Note: output channels = channels x number of unique z values.)
+    shape: shape of the BEV map.
+    dim: dimension index for z (default: 1 for KITTI coords).
+    bias: whether to use bias.
+
+    """
     def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
-                 shape,
-                 offset: list = [0, 0, 0],
+                 channels: int,
+                 shape: Union[List[int], Tuple[int, int, int], torch.Tensor],
+                 offset: List[int] = [0, 0, 0],
                  dim: int = 1,
                  bias: bool = False) -> None:
         super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.offset = torch.cuda.IntTensor([list(offset) + [0]])
+        self.channels = channels
+        self.register_buffer('offset', torch.IntTensor([list(offset) + [0]]))
+        if isinstance(shape, torch.Tensor):
+            self.register_buffer('shape', shape.int())
+        else:
+            self.register_buffer('shape', torch.IntTensor(shape))
         self.dim = dim
-        self.n_kernels = int(shape[self.dim])
         self.bev_dims = [i for i in range(3) if i != self.dim]
-        self.bev_shape = shape[self.bev_dims]
-        self.kernel = nn.Parameter(
-            torch.zeros(self.n_kernels, in_channels, out_channels))
-        self.bias = nn.Parameter(torch.zeros(1, out_channels)) if bias else 0
-        self.reset_parameters()
-
-    def __repr__(self):
-        return 'ToDenseBEVConvolution(in_channels=%d, out_channels=%d, n_kernels=%d)' % (
-            self.in_channels, self.out_channels, self.n_kernels)
+        self.bev_shape = self.shape[self.bev_dims]
 
-    def reset_parameters(self):
-        std = 1. / math.sqrt(self.in_channels)
-        self.kernel.data.uniform_(-std, std)
+    def extra_repr(self):
+        return 'channels={}'.format(self.channels)
 
-    def forward(self, inputs: SparseTensor):
+    def forward(self, inputs: SparseTensor) -> torch.Tensor:
         coords, feats, stride = inputs.C, inputs.F, inputs.s
+        if isinstance(stride, tuple):
+            stride = torch.Tensor(stride).unsqueeze(0).to(feats)
 
-        kernel = torch.index_select(self.kernel, 0,
-                                    (coords[:, self.dim] / stride).long())
-        feats = (feats.unsqueeze(-1) * kernel).sum(1) + self.bias
-        coords = (coords - self.offset).t()[[3] + self.bev_dims].long()
-        coords[1:] = (coords[1:] / stride).long()
-        indices = coords[0] * int(self.bev_shape.prod()) + coords[1] * int(
-            self.bev_shape[1]) + coords[2]
+        # [b, x, y, z]
+        coords = (coords - self.offset).t()[[3] + self.bev_dims +
+                                            [self.dim]].long()
+        shape = self.shape[self.bev_dims + [self.dim]]
+
+        # stride is now a torch.Tensor, since inputs.s is always a tuple
+        stride = stride[:, self.bev_dims + [self.dim]]
+        stride = stride.t()
+
+        coords[1:] = (coords[1:] // stride).long()
+        coords[-1] = torch.clamp(coords[-1], 0, shape[-1] - 1)
+        indices = coords[0] * int(shape.prod()) + coords[1] * int(
+            shape[1:].prod()) + coords[2] * int(shape[2]) + coords[3]
         batch_size = coords[0].max().item() + 1
-        outputs = torch.cuda.sparse.FloatTensor(
+        outputs = torch.sparse_coo_tensor(
             indices.unsqueeze(0),
             feats,
-            torch.Size(
-                [batch_size * int(self.bev_shape.prod()),
-                 feats.size(-1)]),
+            torch.Size([batch_size * int(self.shape.prod()),
+                        feats.size(-1)]),
         ).to_dense()
-        outputs = outputs.view(batch_size, *self.bev_shape, -1)
+        outputs = outputs.view(batch_size, *self.bev_shape.cpu().numpy(), -1)
         outputs = outputs.permute(0, 3, 1, 2).contiguous()
         return outputs
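A shape sketch for the new `ToBEVHeightCompression` module on hypothetical data (it runs on CPU now that the hard-coded `torch.cuda` tensors are gone; the import path follows this file, re-exports may differ):

```python
import torch
from torchsparse import SparseTensor
from torchsparse.nn.modules.conv import ToBEVHeightCompression

# Two voxels in batch 0 on a [128, 2, 128] grid; z is dim 1 (KITTI layout).
coords = torch.IntTensor([[3, 0, 5, 0],
                          [7, 1, 9, 0]])
feats = torch.randn(2, 16)
x = SparseTensor(feats, coords, stride=1)

bev = ToBEVHeightCompression(channels=16, shape=[128, 2, 128], dim=1)
print(bev(x).shape)  # torch.Size([1, 32, 128, 128]): 2 z-slices x 16 channels
```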
diff --git a/torchsparse/nn/modules/norm.py b/torchsparse/nn/modules/norm.py
index eb6f462..7590114 100644
--- a/torchsparse/nn/modules/norm.py
+++ b/torchsparse/nn/modules/norm.py
@@ -1,6 +1,6 @@
 import torch
 from torch import nn
 
-from torchsparse.sparse_tensor import *
+from torchsparse.sparse_tensor import SparseTensor
 
 __all__ = ['BatchNorm', 'GroupNorm']
@@ -39,15 +39,16 @@ def forward(self, inputs):
         # PyTorch's GroupNorm function expects the input to be in (N, C, *) format where
         # N is batch size, and C is number of channels. "feats" is not in that format.
         # So, we extract the feats corresponding to each sample, bring it to the format
-        # expected by PyTorch's GroupNorm function, and invoke it. 
-        batch_size = coords[-1][-1] + 1 
+        # expected by PyTorch's GroupNorm function, and invoke it.
+        batch_size = coords[-1][-1] + 1
         num_channels = feats.shape[1]
         new_feats = torch.zeros_like(feats)
         for sample_idx in range(batch_size):
-            indices = coords[:,-1] == sample_idx
+            indices = coords[:, -1] == sample_idx
             sample_feats = feats[indices]
             sample_feats = torch.transpose(sample_feats, 0, 1)
-            sample_feats = sample_feats.reshape(1, num_channels, -1) # N=1. since we have a single sample here
+            sample_feats = sample_feats.reshape(
+                1, num_channels, -1)  # N=1 since we have a single sample here
             normalized_feats = super().forward(sample_feats)
             normalized_feats = normalized_feats.reshape(num_channels, -1)
             normalized_feats = torch.transpose(normalized_feats, 0, 1)
diff --git a/torchsparse/nn/modules/pooling.py b/torchsparse/nn/modules/pooling.py
index 41bc8a8..947f5e2 100644
--- a/torchsparse/nn/modules/pooling.py
+++ b/torchsparse/nn/modules/pooling.py
@@ -1,15 +1,17 @@
 from torch import nn
 
-from torchsparse.sparse_tensor import *
+from torchsparse.sparse_tensor import SparseTensor
 
-from ..functional import *
+from torchsparse.nn import functional as spF
+
+__all__ = ['GlobalAveragePooling', 'GlobalMaxPooling']
 
 
 class GlobalAveragePooling(nn.Module):
     def forward(self, inputs):
-        return global_avg_pool(inputs)
+        return spF.global_avg_pool(inputs)
 
 
 class GlobalMaxPooling(nn.Module):
     def forward(self, inputs):
-        return global_max_pool(inputs)
+        return spF.global_max_pool(inputs)
diff --git a/torchsparse/sparse_tensor.py b/torchsparse/sparse_tensor.py
index 68accbb..095d4a3 100644
--- a/torchsparse/sparse_tensor.py
+++ b/torchsparse/sparse_tensor.py
@@ -1,13 +1,24 @@
+import numpy as np
 import torch
 
+from typing import Union, List, Tuple
+
 __all__ = ['SparseTensor']
 
 
 class SparseTensor:
-    def __init__(self, feats, coords, stride=1):
+    def __init__(
+            self,
+            feats: Union[np.ndarray, torch.Tensor],
+            coords: Union[np.ndarray, torch.Tensor],
+            stride: Union[int, List[int], Tuple[int, int, int]] = 1) -> None:
         self.F = feats
         self.C = coords
-        self.s = stride
+        if isinstance(stride, int):
+            self.s = (stride, stride, stride)
+        elif isinstance(stride, list):
+            self.s = tuple(stride)
+        else:
+            self.s = stride
         self.coord_maps = {}
         self.kernel_maps = {}
diff --git a/torchsparse/src/interpolation/devox_cpu.cpp b/torchsparse/src/interpolation/devox_cpu.cpp
index 1961f07..4bf18c1 100644
--- a/torchsparse/src/interpolation/devox_cpu.cpp
+++ b/torchsparse/src/interpolation/devox_cpu.cpp
@@ -9,8 +9,6 @@ at::Tensor cpu_devoxelize_forward(
     const at::Tensor indices,
     const at::Tensor weight)
 {
-  //int b = feat.size(0);
-  //printf("%d\n", b);
   int c = feat.size(1);
   int N = indices.size(0);
diff --git a/torchsparse/src/interpolation/devox_deterministic.cpp b/torchsparse/src/interpolation/devox_deterministic.cpp
index 0c79f93..ec785bc 100644
--- a/torchsparse/src/interpolation/devox_deterministic.cpp
+++ b/torchsparse/src/interpolation/devox_deterministic.cpp
@@ -9,8 +9,6 @@ at::Tensor deterministic_devoxelize_forward(
     const at::Tensor indices,
     const at::Tensor weight)
 {
-  //int b = feat.size(0);
-  //printf("%d\n", b);
   int c = feat.size(1);
   int N = indices.size(0);
@@ -32,7 +30,6 @@ at::Tensor deterministic_devoxelize_backward(
   deterministic_devoxelize_grad_wrapper(N, n, c, indices.data_ptr<int>(), weight.data_ptr<float>(), top_grad.data_ptr<float>(), bottom_grad_int.data_ptr<long>());
   at::Tensor bottom_grad = bottom_grad_int.to(at::ScalarType::Double);
-  //std::cout << torch::mean(bottom_grad) << std::endl;
   bottom_grad /= 1e10;
   return bottom_grad.to(at::ScalarType::Float);
 }
diff --git a/torchsparse/src/others/convert_neighbor_map.cpp b/torchsparse/src/others/convert_neighbor_map.cpp
deleted file mode 100644
index 66f6d0b..0000000
--- a/torchsparse/src/others/convert_neighbor_map.cpp
+++ /dev/null
@@ -1,18 +0,0 @@
-#include <torch/torch.h>
-#include "convert_neighbor_map_gpu.h"
-#include <vector>
-
-at::Tensor convert_map_forward(
-    const at::Tensor nmap,
-    const at::Tensor idx_batch,
-    const at::Tensor idx_point)
-{
-    //return group_point_forward_gpu(points, indices);
-
-    int N = nmap.size(1);
-    int k = nmap.size(0);
-    int N_nonzero = idx_point.size(0);
-    at::Tensor out = torch::zeros({N_nonzero, 2}, at::device(nmap.device()).dtype(at::ScalarType::Int));
-    convert_map_wrapper(k, N, N_nonzero, nmap.data_ptr<int>(), idx_batch.data_ptr<int>(), idx_point.data_ptr<int>(), out.data_ptr<int>());
-    return out;
-}
diff --git a/torchsparse/src/others/convert_neighbor_map_cpu.cpp b/torchsparse/src/others/convert_neighbor_map_cpu.cpp
deleted file mode 100644
index 2cbc9f6..0000000
--- a/torchsparse/src/others/convert_neighbor_map_cpu.cpp
+++ /dev/null
@@ -1,30 +0,0 @@
-#include <torch/torch.h>
-#include "convert_neighbor_map_cpu_header.h"
-#include <vector>
-
-void cpu_convert_map_wrapper(int k, int N, int N_nonzero, const int *nmap, const int *idx_batch, const int *idx_point, int *out)
-{
-#pragma omp parallel for
-    for (int index = 0; index < N_nonzero; index++)
-    {
-        int i = idx_batch[index];
-        int j = idx_point[index];
-        out[index << 1] = nmap[i * N + j];
-        out[(index << 1) + 1] = j;
-    }
-}
-
-at::Tensor cpu_convert_map_forward(
-    const at::Tensor nmap,
-    const at::Tensor idx_batch,
-    const at::Tensor idx_point)
-{
-    //return group_point_forward_gpu(points, indices);
-
-    int N = nmap.size(1);
-    int k = nmap.size(0);
-    int N_nonzero = idx_point.size(0);
-    at::Tensor out = torch::zeros({N_nonzero, 2}, at::device(nmap.device()).dtype(at::ScalarType::Int));
-    cpu_convert_map_wrapper(k, N, N_nonzero, nmap.data_ptr<int>(), idx_batch.data_ptr<int>(), idx_point.data_ptr<int>(), out.data_ptr<int>());
-    return out;
-}
diff --git a/torchsparse/src/others/convert_neighbor_map_cpu_header.h b/torchsparse/src/others/convert_neighbor_map_cpu_header.h
deleted file mode 100644
index 7bd9dc5..0000000
--- a/torchsparse/src/others/convert_neighbor_map_cpu_header.h
+++ /dev/null
@@ -1,15 +0,0 @@
-#ifndef _CONVERT_NEIGHBOR_MAP_CPU
-#define _CONVERT_NEIGHBOR_MAP_CPU
-#include <torch/torch.h>
-#include <vector>
-
-
-//CUDA forward declarations
-void cpu_convert_map_wrapper(int k, int N, int N_nonzero, const int *nmap, const int *idx_batch, const int *idx_point, int *out);
-at::Tensor cpu_convert_map_forward(
-    const at::Tensor nmap,
-    const at::Tensor idx_batch,
-    const at::Tensor idx_point
-);
-
-#endif
\ No newline at end of file
diff --git a/torchsparse/src/others/convert_neighbor_map_gpu.cu b/torchsparse/src/others/convert_neighbor_map_gpu.cu
deleted file mode 100644
index 15c510c..0000000
--- a/torchsparse/src/others/convert_neighbor_map_gpu.cu
+++ /dev/null
@@ -1,22 +0,0 @@
-#include <stdio.h>
-#include <stdlib.h>
-#include <cmath>
-
-//hashing
-//input N*F float tensor, pointer to output N'*F int64 tensor, N*1 count tensor, N*1 index tensor
-__global__ void convert_map_kernel(int k, int N, int N_nonzero, const int *__restrict__ nmap, const int *__restrict__ idx_batch, const int *__restrict__ idx_point, int *__restrict__ out){
-    int index = blockDim.x * blockIdx.x + threadIdx.x;
-
-    if(index < N_nonzero){
-        int i = idx_batch[index];
-        int j = idx_point[index];
-        out[index << 1] = nmap[i * N + j];
-        out[(index << 1) + 1] = j;
-    }
-}
-
-
-
-void convert_map_wrapper(int k, int N, int N_nonzero, const int *nmap, const int *idx_batch, const int *idx_point, int *out){
-    convert_map_kernel<<<ceil((double)N_nonzero / 256), 256>>>(k, N, N_nonzero, nmap, idx_batch, idx_point, out);
-}
diff --git a/torchsparse/src/others/convert_neighbor_map_gpu.h b/torchsparse/src/others/convert_neighbor_map_gpu.h
deleted file mode 100644
index ecadadc..0000000
--- a/torchsparse/src/others/convert_neighbor_map_gpu.h
+++ /dev/null
@@ -1,15 +0,0 @@
-#ifndef _CONVERT_NEIGHBOR_MAP
-#define _CONVERT_NEIGHBOR_MAP
-#include <torch/torch.h>
-#include <vector>
-
-
-//CUDA forward declarations
-void convert_map_wrapper(int k, int N, int N_nonzero, const int *nmap, const int *idx_batch, const int *idx_point, int *out);
-at::Tensor convert_map_forward(
-    const at::Tensor nmap,
-    const at::Tensor idx_batch,
-    const at::Tensor idx_point
-);
-
-#endif
\ No newline at end of file
diff --git a/torchsparse/src/torchsparse_bindings.cpp b/torchsparse/src/torchsparse_bindings.cpp
index 79e5a8d..d12f5f3 100644
--- a/torchsparse/src/torchsparse_bindings.cpp
+++ b/torchsparse/src/torchsparse_bindings.cpp
@@ -4,7 +4,6 @@
 #include "convolution/convolution_cpu_header.h"
 #include "hash/hash_cpu_header.h"
 #include "interpolation/devox_cpu_header.h"
-#include "others/convert_neighbor_map_cpu_header.h"
 #include "others/insertion_cpu_header.h"
 #include "others/query_cpu_header.h"
 #include "others/count_cpu_header.h"
@@ -15,7 +14,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
   m.def("sparseconv_cpu_backward", &ConvolutionBackwardCPU, "point cloud convolution backward (CPU)");
   m.def("cpu_hash_forward", &cpu_hash_forward, "Hashing forward (CPU)");
   m.def("cpu_kernel_hash_forward", &cpu_kernel_hash_forward, "Kernel Hashing forward (CPU)");
-  m.def("cpu_convert_map_forward", &cpu_convert_map_forward, "Convert neighbor map forward (CPU)");
  m.def("cpu_insertion_forward", &cpu_insertion_forward, "Insertion forward (CPU)");
   m.def("cpu_insertion_backward", &cpu_insertion_backward, "Insertion backward (CPU)");
   m.def("cpu_devoxelize_forward", &cpu_devoxelize_forward, "Devoxelization forward (CPU)");
diff --git a/torchsparse/src/torchsparse_bindings_gpu.cpp b/torchsparse/src/torchsparse_bindings_gpu.cpp
index fdbd7cf..e27977f 100644
--- a/torchsparse/src/torchsparse_bindings_gpu.cpp
+++ b/torchsparse/src/torchsparse_bindings_gpu.cpp
@@ -3,14 +3,12 @@
 #include <torch/torch.h>
 #include "convolution/convolution_cpu_header.h"
 #include "hash/hash_cpu_header.h"
-#include "others/convert_neighbor_map_cpu_header.h"
 #include "others/insertion_cpu_header.h"
 #include "others/query_cpu_header.h"
 #include "convolution/convolution_gpu.h"
 #include "hash/hash_gpu.h"
 #include "interpolation/devox_gpu.h"
 #include "interpolation/devox_cpu_header.h"
-#include "others/convert_neighbor_map_gpu.h"
 #include "others/count_gpu.h"
 #include "others/insertion_gpu.h"
 #include "others/insertion_cpu_header.h"
@@ -22,7 +20,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
   m.def("sparseconv_cpu_forward", &ConvolutionForwardCPU, "point cloud convolution forward (CPU)");
   m.def("sparseconv_cpu_backward", &ConvolutionBackwardCPU, "point cloud convolution backward (CPU)");
   m.def("cpu_kernel_hash_forward", &cpu_kernel_hash_forward, "Kernel Hashing forward (CPU)");
-  m.def("cpu_convert_map_forward", &cpu_convert_map_forward, "Convert neighbor map forward (CPU)");
   m.def("cpu_insertion_forward", &cpu_insertion_forward, "Insertion forward (CPU)");
   m.def("cpu_insertion_backward", &cpu_insertion_backward, "Insertion backward (CPU)");
   m.def("cpu_query_forward", &cpu_query_forward, "hash query forward (CPU)");
@@ -44,5 +41,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
   m.def("cpu_insertion_forward", &cpu_insertion_forward, "Insertion forward (CPU)");
   m.def("cpu_insertion_backward", &cpu_insertion_backward, "Insertion backward (CPU)");
   m.def("query_forward", &query_forward, "hash query forward (CUDA)");
-  m.def("convert_map_forward", &convert_map_forward, "Convert neighbor map forward (CUDA)");
 }
diff --git a/torchsparse/utils/__init__.py b/torchsparse/utils/__init__.py
index b697302..4215378 100644
--- a/torchsparse/utils/__init__.py
+++ b/torchsparse/utils/__init__.py
@@ -1,2 +1,2 @@
 from .helpers import *
-from .kernel_region import *
\ No newline at end of file
+from .kernel import *
\ No newline at end of file
diff --git a/torchsparse/utils/helpers.py b/torchsparse/utils/helpers.py
index dd8b3b9..0830295 100644
--- a/torchsparse/utils/helpers.py
+++ b/torchsparse/utils/helpers.py
@@ -4,6 +4,11 @@
 import torch
 
 from torchsparse import SparseTensor
 
+__all__ = [
+    'ravel_hash_vec', 'sparse_quantize', 'sparse_collate', 'sparse_collate_fn',
+    'sparse_collate_tensors', 'make_tuple'
+]
+
 
 def ravel_hash_vec(arr):
     assert arr.ndim == 2
@@ -151,12 +156,16 @@ def sparse_collate(coords,
 
         if not coord_float:
             coords_batch.append(
-                torch.cat((coord, torch.ones((num_points, 1), device=coord.device).int() * batch_id),
-                          1))
+                torch.cat(
+                    (coord, torch.ones(
+                        (num_points, 1), device=coord.device).int() *
+                     batch_id), 1))
         else:
             coords_batch.append(
                 torch.cat(
-                    (coord, torch.ones((num_points, 1), device=coord.device).float() * batch_id), 1))
+                    (coord, torch.ones(
+                        (num_points, 1), device=coord.device).float() *
+                     batch_id), 1))
 
         # Features
         feats_batch.append(feat)
@@ -233,3 +242,22 @@ def sparse_collate_fn(batch):
         else:
             ans_dict += [sample[i] for sample in batch],
         return ans_dict
+
+
+def make_tuple(inputs, dimension=3):
+    # tensor strides (e.g. as produced in spdownsample) are unwrapped first
+    if isinstance(inputs, torch.Tensor):
+        inputs = inputs.view(-1).tolist()
+    if isinstance(inputs, int):
+        outputs = tuple()
+        for d in range(dimension):
+            outputs += inputs,
+        return outputs
+    elif isinstance(inputs, list):
+        assert len(inputs) == dimension, 'Input length and dimension mismatch'
+        return tuple(inputs)
+    elif isinstance(inputs, tuple):
+        assert len(inputs) == dimension, 'Input length and dimension mismatch'
+        return inputs
+    else:
+        raise TypeError('Cannot convert {} to a tuple'.format(type(inputs)))
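`make_tuple` is what makes kernel sizes, strides, and dilations hashable so they can become fields of a `KernelMapKey`. Expected behavior, for reference:

```python
from torchsparse.utils.helpers import make_tuple

print(make_tuple(3))          # (3, 3, 3)
print(make_tuple([1, 2, 3]))  # (1, 2, 3)
print(make_tuple((2, 2, 2)))  # (2, 2, 2), returned unchanged
```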
diff --git a/torchsparse/utils/kernel_region.py b/torchsparse/utils/kernel.py
similarity index 56%
rename from torchsparse/utils/kernel_region.py
rename to torchsparse/utils/kernel.py
index 51278cb..29c18d1 100644
--- a/torchsparse/utils/kernel_region.py
+++ b/torchsparse/utils/kernel.py
@@ -1,18 +1,30 @@
+from collections import namedtuple
 import numpy as np
 import torch
+from typing import Union, List, Tuple
+from torchsparse.utils import make_tuple
 
-__all__ = ['KernelRegion']
+__all__ = ['KernelRegion', 'KernelMapKey']
+
+KernelMapKey = namedtuple('KernelMapKey',
+                          ['kernel_size', 'cur_stride', 'stride', 'dilation'])
 
 
 class KernelRegion:
     def __init__(self,
-                 kernel_size: int = 3,
-                 tensor_stride: int = 1,
-                 dilation: int = 1,
-                 dim=[0, 1, 2]) -> None:
+                 kernel_size: Union[int, List[int], Tuple[int, int, int]] = 3,
+                 tensor_stride: Union[int, List[int], Tuple[int, int, int],
+                                      torch.Tensor] = 1,
+                 dilation: Union[int, List[int], Tuple[int, int, int]] = 1,
+                 dim: List[int] = [0, 1, 2]) -> None:
         self.kernel_size = kernel_size
-        self.tensor_stride = tensor_stride
-        self.dilation = dilation
+        self.tensor_stride = make_tuple(tensor_stride)
+        self.dilation = make_tuple(dilation)
+        assert len(self.tensor_stride) == 3, 'Wrong tensor_stride'
+        assert len(self.dilation) == 3, 'Wrong dilation'
+
+        ts = self.tensor_stride
+        d = self.dilation
 
         if not isinstance(kernel_size, (list, tuple)):
             if kernel_size % 2 == 0:
@@ -24,13 +36,15 @@ def __init__(self,
         self.region_type = region_type
 
-        single_offset = (
+        x_offset = (
             np.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1) *
-            tensor_stride * dilation).tolist()
-
-        x_offset = single_offset if 0 in dim else [0]
-        y_offset = single_offset if 1 in dim else [0]
-        z_offset = single_offset if 2 in dim else [0]
+            ts[0] * d[0]).tolist() if 0 in dim else [0]
+        y_offset = (
+            np.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1) *
+            ts[1] * d[1]).tolist() if 1 in dim else [0]
+        z_offset = (
+            np.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1) *
+            ts[2] * d[2]).tolist() if 2 in dim else [0]
 
         if self.region_type == 1:
             kernel_offset = [[x, y, z] for z in z_offset for y in y_offset
@@ -47,14 +61,14 @@ def __init__(self,
             kernel_z_size = kernel_size[2]
 
             x_offset = (np.arange(-kernel_x_size // 2 + 1,
-                                  kernel_x_size // 2 + 1) * tensor_stride *
-                        dilation).tolist()
+                                  kernel_x_size // 2 + 1) * ts[0] *
+                        d[0]).tolist()
             y_offset = (np.arange(-kernel_y_size // 2 + 1,
-                                  kernel_y_size // 2 + 1) * tensor_stride *
-                        dilation).tolist()
+                                  kernel_y_size // 2 + 1) * ts[1] *
+                        d[1]).tolist()
             z_offset = (np.arange(-kernel_z_size // 2 + 1,
-                                  kernel_z_size // 2 + 1) * tensor_stride *
-                        dilation).tolist()
+                                  kernel_z_size // 2 + 1) * ts[2] *
+                        d[2]).tolist()
 
             kernel_offset = [[x, y, z] for x in x_offset for y in y_offset
                              for z in z_offset]
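To close the loop, a sketch of the per-axis offsets and the new cache keys (hypothetical values; assumes the odd-kernel branch enumerates the full cube, as the code above suggests):

```python
import torch
from torchsparse.utils.kernel import KernelMapKey, KernelRegion

# With an anisotropic tensor_stride, y offsets now scale independently.
region = KernelRegion(kernel_size=3, tensor_stride=(1, 2, 1), dilation=1)
offsets = region.get_kernel_offset()
print(offsets.size(0))                      # 27 offsets for a 3x3x3 kernel
print(sorted(set(offsets[:, 1].tolist())))  # [-2, 0, 2]

# Kernel maps are cached under hashable named tuples rather than
# 'k%s_os%d_s%d_d%d' format strings, so tuple-valued strides just work.
key = KernelMapKey(kernel_size=(3, 3, 3), cur_stride=(1, 1, 1),
                   stride=(2, 2, 2), dilation=(1, 1, 1))
```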