Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalized Sparse Convolution #77

Merged
merged 19 commits into from
Jun 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@
'torchsparse/src/interpolation/devox_deterministic.cpp',
'torchsparse/src/interpolation/devox_deterministic_gpu.cu',
'torchsparse/src/interpolation/devox_cpu.cpp',
'torchsparse/src/others/convert_neighbor_map.cpp',
'torchsparse/src/others/convert_neighbor_map_gpu.cu',
'torchsparse/src/others/convert_neighbor_map_cpu.cpp',
'torchsparse/src/others/count.cpp',
'torchsparse/src/others/count_gpu.cu',
'torchsparse/src/others/count_cpu.cpp',
Expand All @@ -44,7 +41,6 @@
'torchsparse/src/hash/hash_cpu.cpp',
'torchsparse/src/hashmap/hashmap_cpu.cpp',
'torchsparse/src/interpolation/devox_cpu.cpp',
'torchsparse/src/others/convert_neighbor_map_cpu.cpp',
'torchsparse/src/others/insertion_cpu.cpp',
'torchsparse/src/others/query_cpu.cpp',
'torchsparse/src/others/count_cpu.cpp'
Expand Down
2 changes: 1 addition & 1 deletion torchsparse/nn/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .activation import *
from .conv import *
from .convert_neighbor_map import *
from .squeeze_nmap import *
from .count import *
from .crop import *
from .devox import *
Expand Down
2 changes: 1 addition & 1 deletion torchsparse/nn/functional/activation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools

from torch.nn import functional as F
from torchsparse.sparse_tensor import *
from torchsparse.sparse_tensor import SparseTensor

__all__ = ['spact', 'sprelu', 'spleaky_relu']

Expand Down
109 changes: 66 additions & 43 deletions torchsparse/nn/functional/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import torchsparse_backend
from torch.autograd import Function
from torch.cuda.amp import custom_fwd, custom_bwd
from torchsparse import *
from torchsparse.nn.functional.convert_neighbor_map import *
from torchsparse.nn.functional.downsample import *
from torchsparse.nn.functional.hash import *
from torchsparse.nn.functional.query import *
from torchsparse.utils.kernel_region import *
from torchsparse import SparseTensor
from torchsparse.nn import functional as spF
from torchsparse.utils.helpers import make_tuple
from torchsparse.utils.kernel import KernelRegion, KernelMapKey

from typing import Union, List, Tuple, Optional

__all__ = ['conv3d']

Expand Down Expand Up @@ -70,8 +70,15 @@ def backward(ctx, grad_out):
features, kernel, neighbor_map, neighbor_offset, transpose = ctx.for_backwards
K, c_in, c_out = kernel.size()
N_in = features.size(0)
grad_features = torch.zeros(N_in, c_in, device=features.device, dtype=features.dtype)
grad_kernel = torch.zeros(K, c_in, c_out, device=kernel.device, dtype=features.dtype)
grad_features = torch.zeros(N_in,
c_in,
device=features.device,
dtype=features.dtype)
grad_kernel = torch.zeros(K,
c_in,
c_out,
device=kernel.device,
dtype=features.dtype)

if 'cuda' in str(features.device):
torchsparse_backend.sparseconv_backward(features, grad_features,
Expand All @@ -87,18 +94,24 @@ def backward(ctx, grad_out):
sparseconv_op = SpConvolution.apply


def conv3d(inputs,
kernel,
kernel_size,
bias=None,
stride=1,
dilation=1,
transpose=False):
def conv3d(inputs: SparseTensor,
kernel: torch.Tensor,
kernel_size: Union[int, List[int], Tuple[int, int, int]],
bias: Optional[torch.Tensor] = None,
stride: Union[int, List[int], Tuple[int, int, int]] = 1,
dilation: Union[int, List[int], Tuple[int, int, int]] = 1,
transpose: bool = False) -> SparseTensor:
features = inputs.F
coords = inputs.C
cur_stride = inputs.s

if kernel_size == 1 and stride == 1 and dilation == 1:
# convert to hashable types
kernel_size = make_tuple(kernel_size)
stride = make_tuple(stride)
dilation = make_tuple(dilation)

if kernel_size == (1, 1, 1) and stride == (1, 1, 1) and dilation == (1, 1,
1):
output_features = features.matmul(kernel)
if bias is not None:
output_features += bias
Expand All @@ -107,34 +120,37 @@ def conv3d(inputs,
output_tensor.kernel_maps = inputs.kernel_maps
output_tensor.check()
elif not transpose:
kernel_map = inputs.kernel_maps.get(
'k%s_os%d_s%d_d%d' % (kernel_size, cur_stride, stride, dilation),
None)
kernel_map_key = KernelMapKey(kernel_size, cur_stride, stride,
dilation)
kernel_map = inputs.kernel_maps.get(kernel_map_key, None)

if stride > 1:
if any(x > 1 for x in stride):
# do downsample
kRegion = KernelRegion(kernel_size=kernel_size,
tensor_stride=cur_stride)
kOffset = kRegion.get_kernel_offset().to(features.device)
new_coords = spdownsample(coords, stride * cur_stride)
hash_query = sphash(new_coords, kOffset)
hash_target = sphash(coords)
idx_query = sphashquery(hash_query, hash_target)
idx_query = list(convert_neighbor_map_gpu(idx_query))
new_coords = spF.spdownsample(coords, stride, kernel_size,
cur_stride)
hash_query = spF.sphash(new_coords, kOffset)
hash_target = spF.sphash(coords)
idx_query = spF.sphashquery(hash_query, hash_target)
idx_query = list(spF.squeeze_nmap(idx_query))
idx_query[1] = idx_query[1].to('cpu')
sizes = (features.shape[0], new_coords.shape[0])
output_features = sparseconv_op(features, kernel, idx_query[0],
idx_query[1], sizes, transpose)
if bias is not None:
output_features += bias
output_tensor = SparseTensor(output_features, new_coords,
cur_stride * stride)
output_tensor = SparseTensor(
output_features, new_coords,
[a * b for a, b in zip(cur_stride, stride)])
output_tensor.coord_maps = copy.deepcopy(inputs.coord_maps)
output_tensor.check()

kernel_map_key = KernelMapKey(kernel_size, cur_stride, stride,
dilation)
output_tensor.kernel_maps = copy.deepcopy(inputs.kernel_maps)
output_tensor.kernel_maps['k%s_os%d_s%d_d%d' %
(kernel_size, cur_stride, stride,
dilation)] = idx_query + [sizes]
output_tensor.kernel_maps[kernel_map_key] = idx_query + [sizes]

else:
if kernel_map is None:
Expand All @@ -144,10 +160,10 @@ def conv3d(inputs,
kOffset = kRegion.get_kernel_offset().to(features.device)
except:
raise
hash_query = sphash(coords, kOffset)
hash_target = sphash(coords)
idx_query = sphashquery(hash_query, hash_target)
idx_query = list(convert_neighbor_map_gpu(idx_query))
hash_query = spF.sphash(coords, kOffset)
hash_target = spF.sphash(coords)
idx_query = spF.sphashquery(hash_query, hash_target)
idx_query = list(spF.squeeze_nmap(idx_query))
idx_query[1] = idx_query[1].to('cpu')
sizes = (features.shape[0], features.shape[0])
output_features = sparseconv_op(features, kernel, idx_query[0],
Expand All @@ -159,9 +175,9 @@ def conv3d(inputs,
output_tensor.coord_maps = inputs.coord_maps
output_tensor.check()
output_tensor.kernel_maps = copy.deepcopy(inputs.kernel_maps)
output_tensor.kernel_maps['k%s_os%d_s%d_d%d' %
(kernel_size, cur_stride, stride,
dilation)] = idx_query + [sizes]
kernel_map_key = KernelMapKey(kernel_size, cur_stride, stride,
dilation)
output_tensor.kernel_maps[kernel_map_key] = idx_query + [sizes]
else:
output_features = sparseconv_op(features, kernel,
kernel_map[0], kernel_map[1],
Expand All @@ -176,17 +192,24 @@ def conv3d(inputs,

else:
# do upsample
original_stride = int(cur_stride / stride)
kernel_map = inputs.kernel_maps.get(
'k%s_os%d_s%d_d%d' %
(kernel_size, original_stride, stride, dilation), None)

original_stride = tuple(
[int(a / b) for a, b in zip(cur_stride, stride)])

kernel_map_key = KernelMapKey(kernel_size, original_stride, stride,
dilation)
kernel_map = inputs.kernel_maps.get(kernel_map_key, None)
assert kernel_map is not None, f'{kernel_map_key} does not exist.'
output_features = sparseconv_op(features, kernel, kernel_map[0],
kernel_map[1], kernel_map[2],
transpose)
if bias is not None:
output_features += bias
output_tensor = SparseTensor(output_features,
inputs.coord_maps[original_stride],

cur_coords = inputs.coord_maps.get(original_stride, None)
assert cur_coords is not None, f'{original_stride} not in coord maps.'

output_tensor = SparseTensor(output_features, cur_coords,
original_stride)
output_tensor.coord_maps = inputs.coord_maps
output_tensor.check()
Expand Down
27 changes: 0 additions & 27 deletions torchsparse/nn/functional/convert_neighbor_map.py

This file was deleted.

97 changes: 59 additions & 38 deletions torchsparse/nn/functional/downsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,69 @@
import torchsparse_backend
from torch.autograd import Function
from torchsparse.nn.functional.hash import *
from torchsparse.nn.functional.voxelize import spvoxelize
from torchsparse.utils.kernel import KernelRegion
from typing import Tuple, List, Union

__all__ = ['spdownsample']


class DownsampleGPU(Function):
@staticmethod
def forward(ctx, coords, ratio):
coords_float = coords[:, :3].float()
# following Minkowski engine
coords_new = torch.floor(torch.floor(coords_float / ratio) *
ratio).int()
coords_new = torch.cat([coords_new, coords[:, 3].view(-1, 1)], 1)
coords_new_hash = sphash(coords_new)
uq, inv, cnt = torch.unique(coords_new_hash,
return_inverse=True,
return_counts=True)
inv = inv.int()
cnt = cnt.int()
# rounding is necessary
# gpu
if 'cuda' in str(coords.device):
uq_coords = torch.round(spvoxelize(coords_new.float(), inv,
cnt))
elif 'cpu' in str(coords.device):
uq_coords = torch.round(
torchsparse_backend.cpu_insertion_forward(
coords_new.float(), inv, cnt))
else:
device = coords.device
uq_coords = torch.round(
torchsparse_backend.cpu_insertion_forward(
coords_new.float().cpu(), inv.cpu(), cnt.cpu()))
uq_coords = uq_coords.to(device)
uq_coords = uq_coords.int()

# Notice: coords_new_hash cannot be directly used
return uq_coords #, coords_new_hash

def spdownsample(
coords: torch.Tensor,
ratio: Union[int, List[int], Tuple[int, int, int]] = 2,
kernel_size: Union[int, List[int], Tuple[int, int, int]] = 2,
tensor_stride: Union[int, List[int], Tuple[int, int, int]] = 1
) -> torch.Tensor:

downsample_gpu = DownsampleGPU.apply
if not isinstance(ratio, int):
ratio = torch.IntTensor(ratio).to(coords.device).unsqueeze(0)
if not isinstance(tensor_stride, int):
tensor_stride = torch.IntTensor(tensor_stride).to(
coords.device).unsqueeze(0)

if isinstance(kernel_size, int) and isinstance(ratio, int):
direct_downsample = kernel_size == ratio
else:
if isinstance(kernel_size, int):
# ratio is a permutation of [1, 1, kernel_size]
direct_downsample = (kernel_size == ratio.prod().item()) & \
(torch.sum(ratio == kernel_size) == 1).item()
else:
direct_downsample = False

def spdownsample(coords, ratio):
return downsample_gpu(coords, ratio)
if direct_downsample:
_ratio = ratio * tensor_stride
new_coords = torch.cat(
[coords[:, :3] // _ratio * _ratio, coords[:, 3:]], 1)
return torch.unique(new_coords, dim=0)
else:
kernel_region = KernelRegion(kernel_size, tensor_stride, dilation=1)
# kernel volume x 3
kernel_offset = kernel_region.get_kernel_offset().to(coords.device)
new_coords = coords[:, :3].unsqueeze(1).repeat(
1, kernel_offset.size(0), 1) + kernel_offset
# (N x kernel volume) x 4
new_coords = torch.cat([
coords[:, 3:].repeat(1, kernel_offset.size(0)).view(-1, 1),
new_coords.view(-1, 3)
],
dim=1)
new_ts = tensor_stride * ratio
# only keep the coordinates that are multiples of new_ts.
if isinstance(new_ts, torch.Tensor):
new_ts = new_ts[0]
new_coords = new_coords[
(new_coords[:, 1] % new_ts[0].item() == 0) & (new_coords[:, 2] % new_ts[1].item() == 0) & \
(new_coords[:, 3] % new_ts[2].item() == 0)
]
else:
new_coords = new_coords[
(new_coords[:, 1] % new_ts == 0) & (new_coords[:, 2] % new_ts == 0) & \
(new_coords[:, 3] % new_ts == 0)
]
new_coords = new_coords[(new_coords[:, 1] >= 0)
& (new_coords[:, 2] >= 0) &
(new_coords[:, 3] >= 0)]
# filter out duplicates
new_coords = torch.unique(new_coords, dim=0)
new_coords = new_coords[:, [1, 2, 3, 0]]
return new_coords
12 changes: 12 additions & 0 deletions torchsparse/nn/functional/squeeze_nmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Tuple

import torch

__all__ = ['squeeze_nmap']


def squeeze_nmap(
        neighbor_map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compact a dense neighbor map by dropping its ``-1`` entries.

    Args:
        neighbor_map: 2-D tensor in which ``-1`` marks an empty slot. It may
            be non-contiguous (e.g. a transposed view).

    Returns:
        A pair ``(map_converted, nmap_offset)`` of contiguous int32 tensors:

        * ``map_converted`` has one row per non-empty slot, visited in
          row-major order. Column 0 is the value stored in the slot and
          column 1 is the slot's column index in ``neighbor_map``.
        * ``nmap_offset`` holds the number of non-empty slots in each row of
          ``neighbor_map``.
    """
    # Build the validity mask once; it drives both outputs.
    valid = neighbor_map != -1
    idx_point = torch.where(valid)[1]
    # Boolean-mask indexing yields elements in the same row-major order that
    # torch.where reports, and unlike flattening with view(-1) it does not
    # require the input to be contiguous.
    values = neighbor_map[valid]
    map_converted = torch.stack([values, idx_point], dim=1)
    nmap_offset = valid.sum(dim=1)
    return map_converted.int().contiguous(), nmap_offset.int().contiguous()
Loading