Add generalized sparse convolution (#77)
Co-authored-by: Zhijian Liu <[email protected]>
kentang-mit and Zhijian Liu authored Jun 21, 2021
1 parent 63a67ed commit 313c9b0
Showing 24 changed files with 393 additions and 322 deletions.
4 changes: 0 additions & 4 deletions setup.py
@@ -28,9 +28,6 @@
'torchsparse/src/interpolation/devox_deterministic.cpp',
'torchsparse/src/interpolation/devox_deterministic_gpu.cu',
'torchsparse/src/interpolation/devox_cpu.cpp',
'torchsparse/src/others/convert_neighbor_map.cpp',
'torchsparse/src/others/convert_neighbor_map_gpu.cu',
'torchsparse/src/others/convert_neighbor_map_cpu.cpp',
'torchsparse/src/others/count.cpp',
'torchsparse/src/others/count_gpu.cu',
'torchsparse/src/others/count_cpu.cpp',
@@ -44,7 +41,6 @@
'torchsparse/src/hash/hash_cpu.cpp',
'torchsparse/src/hashmap/hashmap_cpu.cpp',
'torchsparse/src/interpolation/devox_cpu.cpp',
'torchsparse/src/others/convert_neighbor_map_cpu.cpp',
'torchsparse/src/others/insertion_cpu.cpp',
'torchsparse/src/others/query_cpu.cpp',
'torchsparse/src/others/count_cpu.cpp'
2 changes: 1 addition & 1 deletion torchsparse/nn/functional/__init__.py
@@ -1,6 +1,6 @@
from .activation import *
from .conv import *
from .convert_neighbor_map import *
from .squeeze_nmap import *
from .count import *
from .crop import *
from .devox import *
2 changes: 1 addition & 1 deletion torchsparse/nn/functional/activation.py
@@ -1,7 +1,7 @@
import functools

from torch.nn import functional as F
from torchsparse.sparse_tensor import *
from torchsparse.sparse_tensor import SparseTensor

__all__ = ['spact', 'sprelu', 'spleaky_relu']

109 changes: 66 additions & 43 deletions torchsparse/nn/functional/conv.py
@@ -4,12 +4,12 @@
import torchsparse_backend
from torch.autograd import Function
from torch.cuda.amp import custom_fwd, custom_bwd
from torchsparse import *
from torchsparse.nn.functional.convert_neighbor_map import *
from torchsparse.nn.functional.downsample import *
from torchsparse.nn.functional.hash import *
from torchsparse.nn.functional.query import *
from torchsparse.utils.kernel_region import *
from torchsparse import SparseTensor
from torchsparse.nn import functional as spF
from torchsparse.utils.helpers import make_tuple
from torchsparse.utils.kernel import KernelRegion, KernelMapKey

from typing import Union, List, Tuple, Optional

__all__ = ['conv3d']

@@ -70,8 +75,15 @@ def backward(ctx, grad_out):
features, kernel, neighbor_map, neighbor_offset, transpose = ctx.for_backwards
K, c_in, c_out = kernel.size()
N_in = features.size(0)
grad_features = torch.zeros(N_in, c_in, device=features.device, dtype=features.dtype)
grad_kernel = torch.zeros(K, c_in, c_out, device=kernel.device, dtype=features.dtype)
grad_features = torch.zeros(N_in,
c_in,
device=features.device,
dtype=features.dtype)
grad_kernel = torch.zeros(K,
c_in,
c_out,
device=kernel.device,
dtype=features.dtype)

if 'cuda' in str(features.device):
torchsparse_backend.sparseconv_backward(features, grad_features,
@@ -87,18 +94,24 @@ def backward(ctx, grad_out):
sparseconv_op = SpConvolution.apply


def conv3d(inputs,
kernel,
kernel_size,
bias=None,
stride=1,
dilation=1,
transpose=False):
def conv3d(inputs: SparseTensor,
kernel: torch.Tensor,
kernel_size: Union[int, List[int], Tuple[int, int, int]],
bias: Optional[torch.Tensor] = None,
stride: Union[int, List[int], Tuple[int, int, int]] = 1,
dilation: Union[int, List[int], Tuple[int, int, int]] = 1,
transpose: bool = False) -> SparseTensor:
features = inputs.F
coords = inputs.C
cur_stride = inputs.s

if kernel_size == 1 and stride == 1 and dilation == 1:
# convert to hashable types
kernel_size = make_tuple(kernel_size)
stride = make_tuple(stride)
dilation = make_tuple(dilation)

if kernel_size == (1, 1, 1) and stride == (1, 1, 1) and dilation == (1, 1,
1):
output_features = features.matmul(kernel)
if bias is not None:
output_features += bias
@@ -107,34 +120,37 @@ def conv3d(inputs,
output_tensor.kernel_maps = inputs.kernel_maps
output_tensor.check()
elif not transpose:
kernel_map = inputs.kernel_maps.get(
'k%s_os%d_s%d_d%d' % (kernel_size, cur_stride, stride, dilation),
None)
kernel_map_key = KernelMapKey(kernel_size, cur_stride, stride,
dilation)
kernel_map = inputs.kernel_maps.get(kernel_map_key, None)

if stride > 1:
if any(x > 1 for x in stride):
# do downsample
kRegion = KernelRegion(kernel_size=kernel_size,
tensor_stride=cur_stride)
kOffset = kRegion.get_kernel_offset().to(features.device)
new_coords = spdownsample(coords, stride * cur_stride)
hash_query = sphash(new_coords, kOffset)
hash_target = sphash(coords)
idx_query = sphashquery(hash_query, hash_target)
idx_query = list(convert_neighbor_map_gpu(idx_query))
new_coords = spF.spdownsample(coords, stride, kernel_size,
cur_stride)
hash_query = spF.sphash(new_coords, kOffset)
hash_target = spF.sphash(coords)
idx_query = spF.sphashquery(hash_query, hash_target)
idx_query = list(spF.squeeze_nmap(idx_query))
idx_query[1] = idx_query[1].to('cpu')
sizes = (features.shape[0], new_coords.shape[0])
output_features = sparseconv_op(features, kernel, idx_query[0],
idx_query[1], sizes, transpose)
if bias is not None:
output_features += bias
output_tensor = SparseTensor(output_features, new_coords,
cur_stride * stride)
output_tensor = SparseTensor(
output_features, new_coords,
[a * b for a, b in zip(cur_stride, stride)])
output_tensor.coord_maps = copy.deepcopy(inputs.coord_maps)
output_tensor.check()

kernel_map_key = KernelMapKey(kernel_size, cur_stride, stride,
dilation)
output_tensor.kernel_maps = copy.deepcopy(inputs.kernel_maps)
output_tensor.kernel_maps['k%s_os%d_s%d_d%d' %
(kernel_size, cur_stride, stride,
dilation)] = idx_query + [sizes]
output_tensor.kernel_maps[kernel_map_key] = idx_query + [sizes]

else:
if kernel_map is None:
@@ -144,10 +160,10 @@ def conv3d(inputs,
kOffset = kRegion.get_kernel_offset().to(features.device)
except:
raise
hash_query = sphash(coords, kOffset)
hash_target = sphash(coords)
idx_query = sphashquery(hash_query, hash_target)
idx_query = list(convert_neighbor_map_gpu(idx_query))
hash_query = spF.sphash(coords, kOffset)
hash_target = spF.sphash(coords)
idx_query = spF.sphashquery(hash_query, hash_target)
idx_query = list(spF.squeeze_nmap(idx_query))
idx_query[1] = idx_query[1].to('cpu')
sizes = (features.shape[0], features.shape[0])
output_features = sparseconv_op(features, kernel, idx_query[0],
@@ -159,9 +175,9 @@ def conv3d(inputs,
output_tensor.coord_maps = inputs.coord_maps
output_tensor.check()
output_tensor.kernel_maps = copy.deepcopy(inputs.kernel_maps)
output_tensor.kernel_maps['k%s_os%d_s%d_d%d' %
(kernel_size, cur_stride, stride,
dilation)] = idx_query + [sizes]
kernel_map_key = KernelMapKey(kernel_size, cur_stride, stride,
dilation)
output_tensor.kernel_maps[kernel_map_key] = idx_query + [sizes]
else:
output_features = sparseconv_op(features, kernel,
kernel_map[0], kernel_map[1],
@@ -176,17 +192,24 @@ def conv3d(inputs,

else:
# do upsample
original_stride = int(cur_stride / stride)
kernel_map = inputs.kernel_maps.get(
'k%s_os%d_s%d_d%d' %
(kernel_size, original_stride, stride, dilation), None)

original_stride = tuple(
[int(a / b) for a, b in zip(cur_stride, stride)])

kernel_map_key = KernelMapKey(kernel_size, original_stride, stride,
dilation)
kernel_map = inputs.kernel_maps.get(kernel_map_key, None)
assert kernel_map is not None, f'{kernel_map_key} does not exist.'
output_features = sparseconv_op(features, kernel, kernel_map[0],
kernel_map[1], kernel_map[2],
transpose)
if bias is not None:
output_features += bias
output_tensor = SparseTensor(output_features,
inputs.coord_maps[original_stride],

cur_coords = inputs.coord_maps.get(original_stride, None)
assert cur_coords is not None, f'{original_stride} not in coord maps.'

output_tensor = SparseTensor(output_features, cur_coords,
original_stride)
output_tensor.coord_maps = inputs.coord_maps
output_tensor.check()
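
For reference, a minimal sketch of the tuple-based kernel-map caching that replaces the old 'k%s_os%d_s%d_d%d' string keys in conv3d. The real make_tuple and KernelMapKey are imported from torchsparse.utils and are not shown in this diff; the stand-in definitions below are assumptions for illustration only, not the library code.

from collections import namedtuple
from typing import List, Tuple, Union

# Assumed stand-ins for torchsparse.utils.helpers.make_tuple and
# torchsparse.utils.kernel.KernelMapKey; the actual definitions may differ.
KernelMapKey = namedtuple('KernelMapKey',
                          ['kernel_size', 'cur_stride', 'stride', 'dilation'])

def make_tuple(x: Union[int, List[int], Tuple[int, int, int]]) -> Tuple[int, int, int]:
    # Broadcast a scalar to a per-axis 3-tuple so that every key component is
    # hashable and per-axis checks such as `any(x > 1 for x in stride)` work.
    return (x, x, x) if isinstance(x, int) else tuple(x)

kernel_maps = {}
key = KernelMapKey(make_tuple(3), make_tuple(1), make_tuple(2), make_tuple(1))
kernel_maps[key] = ['neighbor map', 'neighbor offset', 'sizes']  # placeholder payload
assert key in kernel_maps  # namedtuples hash by value, like the old string keys
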
27 changes: 0 additions & 27 deletions torchsparse/nn/functional/convert_neighbor_map.py

This file was deleted.

97 changes: 59 additions & 38 deletions torchsparse/nn/functional/downsample.py
@@ -2,48 +2,69 @@
import torchsparse_backend
from torch.autograd import Function
from torchsparse.nn.functional.hash import *
from torchsparse.nn.functional.voxelize import spvoxelize
from torchsparse.utils.kernel import KernelRegion
from typing import Tuple, List, Union

__all__ = ['spdownsample']


class DownsampleGPU(Function):
@staticmethod
def forward(ctx, coords, ratio):
coords_float = coords[:, :3].float()
# following Minkowski engine
coords_new = torch.floor(torch.floor(coords_float / ratio) *
ratio).int()
coords_new = torch.cat([coords_new, coords[:, 3].view(-1, 1)], 1)
coords_new_hash = sphash(coords_new)
uq, inv, cnt = torch.unique(coords_new_hash,
return_inverse=True,
return_counts=True)
inv = inv.int()
cnt = cnt.int()
# rounding is necessary
# gpu
if 'cuda' in str(coords.device):
uq_coords = torch.round(spvoxelize(coords_new.float(), inv,
cnt))
elif 'cpu' in str(coords.device):
uq_coords = torch.round(
torchsparse_backend.cpu_insertion_forward(
coords_new.float(), inv, cnt))
else:
device = coords.device
uq_coords = torch.round(
torchsparse_backend.cpu_insertion_forward(
coords_new.float().cpu(), inv.cpu(), cnt.cpu()))
uq_coords = uq_coords.to(device)
uq_coords = uq_coords.int()

# Notice: coords_new_hash cannot be directly used
return uq_coords #, coords_new_hash

def spdownsample(
coords: torch.Tensor,
ratio: Union[int, List[int], Tuple[int, int, int]] = 2,
kernel_size: Union[int, List[int], Tuple[int, int, int]] = 2,
tensor_stride: Union[int, List[int], Tuple[int, int, int]] = 1
) -> torch.Tensor:

downsample_gpu = DownsampleGPU.apply
if not isinstance(ratio, int):
ratio = torch.IntTensor(ratio).to(coords.device).unsqueeze(0)
if not isinstance(tensor_stride, int):
tensor_stride = torch.IntTensor(tensor_stride).to(
coords.device).unsqueeze(0)

if isinstance(kernel_size, int) and isinstance(ratio, int):
direct_downsample = kernel_size == ratio
else:
if isinstance(kernel_size, int):
# ratio is a permutation of [1, 1, kernel_size]
direct_downsample = (kernel_size == ratio.prod().item()) & \
(torch.sum(ratio == kernel_size) == 1).item()
else:
direct_downsample = False

def spdownsample(coords, ratio):
return downsample_gpu(coords, ratio)
if direct_downsample:
_ratio = ratio * tensor_stride
new_coords = torch.cat(
[coords[:, :3] // _ratio * _ratio, coords[:, 3:]], 1)
return torch.unique(new_coords, dim=0)
else:
kernel_region = KernelRegion(kernel_size, tensor_stride, dilation=1)
# kernel volume x 3
kernel_offset = kernel_region.get_kernel_offset().to(coords.device)
new_coords = coords[:, :3].unsqueeze(1).repeat(
1, kernel_offset.size(0), 1) + kernel_offset
# (N x kernel volume) x 4
new_coords = torch.cat([
coords[:, 3:].repeat(1, kernel_offset.size(0)).view(-1, 1),
new_coords.view(-1, 3)
],
dim=1)
new_ts = tensor_stride * ratio
# only keep coordinates that are multiples of new_ts.
if isinstance(new_ts, torch.Tensor):
new_ts = new_ts[0]
new_coords = new_coords[
(new_coords[:, 1] % new_ts[0].item() == 0) & (new_coords[:, 2] % new_ts[1].item() == 0) & \
(new_coords[:, 3] % new_ts[2].item() == 0)
]
else:
new_coords = new_coords[
(new_coords[:, 1] % new_ts == 0) & (new_coords[:, 2] % new_ts == 0) & \
(new_coords[:, 3] % new_ts == 0)
]
new_coords = new_coords[(new_coords[:, 1] >= 0)
& (new_coords[:, 2] >= 0) &
(new_coords[:, 3] >= 0)]
# filter out duplicates
new_coords = torch.unique(new_coords, dim=0)
new_coords = new_coords[:, [1, 2, 3, 0]]
return new_coords
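
A small usage sketch for the generalized spdownsample above, assuming torchsparse (with its backend) is installed and that coordinates are integer (x, y, z, batch) rows as used throughout this diff; the sample points are illustrative only.

import torch
from torchsparse.nn import functional as spF

coords = torch.tensor([[0, 0, 0, 0],
                       [1, 0, 1, 0],
                       [2, 3, 1, 0],
                       [4, 4, 4, 1]], dtype=torch.int)

# Isotropic case (kernel_size == stride): the direct branch floors x, y, z to
# multiples of ratio * tensor_stride and drops duplicates.
out_iso = spF.spdownsample(coords, ratio=2, kernel_size=2, tensor_stride=1)

# Anisotropic case (downsample along z only): the kernel-offset branch
# enumerates candidate sites and keeps those divisible by the new stride.
out_aniso = spF.spdownsample(coords, ratio=(1, 1, 2), kernel_size=(1, 1, 2))
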
12 changes: 12 additions & 0 deletions torchsparse/nn/functional/squeeze_nmap.py
@@ -0,0 +1,12 @@
import torch

__all__ = ['squeeze_nmap']


def squeeze_nmap(neighbor_map: torch.Tensor) -> torch.Tensor:
idx_batch, idx_point = torch.where(neighbor_map != -1)
map_converted = neighbor_map.view(-1)[idx_batch * neighbor_map.size(1) +
idx_point]
map_converted = torch.stack([map_converted, idx_point], dim=1)
nmap_offset = torch.sum(neighbor_map != -1, 1)
return map_converted.int().contiguous(), nmap_offset.int().contiguous()
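
To see what squeeze_nmap computes, here is a self-contained walk-through of the same operations on a toy neighbor map, assuming (as its use in conv3d suggests) that rows index kernel offsets, columns index query points, and -1 marks a missing neighbor; the numbers are illustrative only.

import torch

# Toy dense neighbor map: 2 kernel offsets x 3 query points, -1 = no neighbor.
neighbor_map = torch.tensor([[ 0,  2, -1],
                             [-1,  1,  3]])

idx_batch, idx_point = torch.where(neighbor_map != -1)
map_converted = neighbor_map.view(-1)[idx_batch * neighbor_map.size(1) + idx_point]
map_converted = torch.stack([map_converted, idx_point], dim=1)  # (input idx, query idx) pairs
nmap_offset = torch.sum(neighbor_map != -1, 1)  # valid pairs per kernel offset

print(map_converted.tolist())  # [[0, 0], [2, 1], [1, 1], [3, 2]]
print(nmap_offset.tolist())    # [2, 2]
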