Faster sparse_dense on GPUs (apache#6580)
* Faster sparse_dense on GPUs.

This new sparse_dense requires a padded matrix, so a new op
`sparse_dense_padded` has been added. AlterOpLayout should transform
`sparse_dense` to `sparse_dense_padded` when possible on the GPU.

* formatting

* more formatting

* Check that AlterOpLayout is defined before using it

* check if FTVMAlterOpLayout exists before using it

* formatting

* restore message passing

* Fix sparse_dense and sparse_dense_padded docs

* Fix old sparse_dense; autotvm and sparse_dense don't play well together

* Remove unused imports

* clarify warp count in cuda_transpose

* Document multidimensional access

* Warn users not to use sparse_dense_padded

* rename nn.sparse_dense_padded to nn.internal.sparse_dense_padded
tkonolige authored and trevor-m committed Oct 19, 2020
1 parent 8b0cfb8 commit 839a9b7
Showing 16 changed files with 537 additions and 60 deletions.
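
A minimal sketch of how the new kernel is reached from user code, assuming a CUDA-enabled build of TVM; the shapes, density, and block size are arbitrary example values, and nn.internal.sparse_dense_padded is never constructed directly by the user:

import scipy.sparse as sp
import tvm
from tvm import relay

# Arbitrary example shapes: data is (M, K), the sparse weight is (N, K) in BSR form.
M, K, N = 8, 256, 512
data = relay.var("data", shape=(M, K), dtype="float32")
w_sp = sp.random(N, K, density=0.1, format="csr", dtype="float32").tobsr(blocksize=(16, 1))

# sparse_dense accepts the weight as a (data, indices, indptr) triple (see nn.py below).
weight = (relay.const(w_sp.data), relay.const(w_sp.indices), relay.const(w_sp.indptr))
mod = tvm.IRModule.from_expr(relay.Function([data], relay.nn.sparse_dense(data, weight)))

# When building for CUDA, AlterOpLayout may pad the weight and rewrite the call into
# nn.internal.sparse_dense_padded; on other targets the stock sparse_dense is kept.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="cuda")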
16 changes: 16 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -75,6 +75,22 @@ def compute_sparse_dense(attrs, inputs, out_type):
reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_alter_op_layout("nn.sparse_dense")
def alter_op_layout_sparse_dense(attrs, inputs, tinfos, out_type):
"""Alternate the layout of sparse_dense"""
return topi.nn.sparse_dense_alter_layout(attrs, inputs, tinfos, out_type)


@reg.register_compute("nn.internal.sparse_dense_padded")
def compute_sparse_dense_padded(attrs, inputs, out_type):
"""Compute definition of sparse_dense_padded"""
raise NotImplementedError("nn.internal.sparse_dense_padded is only available on cuda")


reg.register_strategy("nn.internal.sparse_dense_padded", strategy.sparse_dense_padded_strategy)
reg.register_pattern("nn.internal.sparse_dense_padded", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# sparse_transpose
@reg.register_compute("nn.sparse_transpose")
def compute_sparse_transpose(attrs, inputs, out_type):
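The register_alter_op_layout hook above dispatches to topi.nn.sparse_dense_alter_layout, a target-generic function that backends override during the AlterOpLayout pass. The sketch below shows the shape of such an override; it is not TVM's actual CUDA implementation: pad_bsr_rows is a hypothetical helper, the alignment of 32 blocks per row is an assumption, and the _make.sparse_dense_padded constructor call is an assumption about how the internal op is built.

import numpy as np
from tvm import relay, topi
from tvm.relay.op.nn import _make  # internal constructors used by relay/op/nn/nn.py


def pad_bsr_rows(w_data, w_indices, w_indptr, align=32):
    # Hypothetical helper: pad every BSR row with explicit zero blocks so each row
    # holds a multiple of `align` blocks; zero blocks do not change the product.
    data, indices, indptr = [c.data.asnumpy() for c in (w_data, w_indices, w_indptr)]
    row_blocks = np.diff(indptr)
    padded = ((row_blocks + align - 1) // align) * align
    new_indptr = np.concatenate([[0], np.cumsum(padded)]).astype(indptr.dtype)
    new_data = np.zeros((new_indptr[-1],) + data.shape[1:], dtype=data.dtype)
    new_indices = np.zeros(new_indptr[-1], dtype=indices.dtype)
    for r in range(len(indptr) - 1):
        n = indptr[r + 1] - indptr[r]
        new_data[new_indptr[r] : new_indptr[r] + n] = data[indptr[r] : indptr[r + 1]]
        new_indices[new_indptr[r] : new_indptr[r] + n] = indices[indptr[r] : indptr[r + 1]]
    return [relay.const(x) for x in (new_data, new_indices, new_indptr)]


@topi.nn.sparse_dense_alter_layout.register(["cuda", "gpu"], override=True)
def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type):
    data, w_data, w_indices, w_indptr = inputs
    # Padding needs the weight values at compile time; returning None keeps the
    # original nn.sparse_dense call unchanged.
    if not all(isinstance(w, relay.Constant) for w in (w_data, w_indices, w_indptr)):
        return None
    return _make.sparse_dense_padded(data, *pad_bsr_rows(w_data, w_indices, w_indptr))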
6 changes: 4 additions & 2 deletions python/tvm/relay/op/nn/nn.py
@@ -2016,15 +2016,17 @@ def sparse_dense(data, weight):
data : tvm.relay.Expr
The input data for the matrix multiplication
weight : namedtuple.
weight : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]].
The sparse weight matrix for the matrix multiplication.
Returns
-------
result: tvm.relay.Expr
The computed result.
"""
return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)
if hasattr(weight, "indices"):
return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)
return _make.sparse_dense(data, weight[0], weight[1], weight[2])


def sparse_transpose(x):
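With the fallback added above, relay.nn.sparse_dense accepts the weight either as an object exposing .data/.indices/.indptr (for example a namedtuple of relay expressions) or as a plain 3-tuple of those expressions. A small sketch of both call forms, using arbitrary example shapes:

import collections
import scipy.sparse as sp
from tvm import relay

w = sp.random(512, 256, density=0.05, format="csr", dtype="float32")
data = relay.var("data", shape=(8, 256), dtype="float32")

# 1) namedtuple-style weight: the hasattr(weight, "indices") branch is taken.
SparseWeight = collections.namedtuple("SparseWeight", ["data", "indices", "indptr"])
wt = SparseWeight(relay.const(w.data), relay.const(w.indices), relay.const(w.indptr))
y1 = relay.nn.sparse_dense(data, wt)

# 2) plain 3-tuple (data, indices, indptr): the new fallback branch is taken.
y2 = relay.nn.sparse_dense(data, (relay.const(w.data), relay.const(w.indices), relay.const(w.indptr)))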
13 changes: 13 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -633,6 +633,19 @@ def sparse_dense_strategy_cuda(attrs, inputs, out_type, target):
return strategy


@sparse_dense_padded_strategy.register(["cuda", "gpu"])
def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target):
"""sparse dense cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_sparse_dense(topi.cuda.sparse_dense_padded),
wrap_topi_schedule(topi.cuda.schedule_sparse_dense_padded),
name="sparse_dense_padded.cuda",
plevel=10,
)
return strategy


@argsort_strategy.register(["cuda", "gpu"])
def argsort_strategy_cuda(attrs, inputs, out_type, target):
"""argsort cuda strategy"""
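For context on the helpers used above: wrap_compute_sparse_dense and wrap_topi_schedule adapt TOPI functions to the signatures that OpStrategy expects, and plevel sets the implementation's priority when several are registered for the same op. A rough sketch of what the compute wrapper does (it mirrors, but is not copied from, the helper in generic.py):

def wrap_compute_sparse_dense(topi_compute):
    # Adapt a TOPI sparse_dense compute, which takes the four input tensors directly,
    # to the (attrs, inputs, out_type) form that OpStrategy implementations use.
    def _compute_sparse_dense(attrs, inputs, out_type):
        # inputs = [data, weight_data, weight_indices, weight_indptr]
        return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3])]

    return _compute_sparse_dense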
6 changes: 6 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -724,6 +724,12 @@ def sparse_dense_strategy(attrs, inputs, out_type, target):
return strategy


@override_native_generic_func("sparse_dense_padded_strategy")
def sparse_dense_padded_strategy(attrs, inputs, out_type, target):
"""sparse dense padded generic strategy"""
raise NotImplementedError("sparse_dense_padded is only implemented for cuda")


# sparse_transpose
@generic_func
def schedule_sparse_transpose(attrs, outs, target):
39 changes: 34 additions & 5 deletions python/tvm/tir/ir_builder.py
@@ -42,6 +42,9 @@ class BufferVar(ObjectGeneric):
Do not create it directly; create it via IRBuilder.
BufferVars support array access either via a linear index, or, if given a
shape, via a multidimensional index.
Examples
--------
In the following example, x is a BufferVar.
@@ -55,16 +58,23 @@ class BufferVar(ObjectGeneric):
x = ib.pointer("float32")
x[0] = x[10] + 1
y = ib.allocate("float32", (32, 32))
# Array access using a linear index
y[(2*32) + 31] = 0.
# The same array access using a multidimensional index
y[2, 31] = 0.
See Also
--------
IRBuilder.pointer
IRBuilder.buffer_ptr
IRBuilder.allocate
"""

def __init__(self, builder, buffer_var, content_type):
def __init__(self, builder, buffer_var, shape, content_type):
self._builder = builder
self._buffer_var = buffer_var
self._shape = shape
self._content_type = content_type

def asobject(self):
@@ -74,8 +84,23 @@ def asobject(self):
def dtype(self):
return self._content_type

def _linear_index(self, index):
if not isinstance(index, tuple) or self._shape is None:
return index
assert len(index) == len(self._shape), "Index size (%s) does not match shape size (%s)" % (
len(index),
len(self._shape),
)
dim_size = 1
lidx = 0
for dim, idx in zip(reversed(self._shape), reversed(index)):
lidx += idx * dim_size
dim_size *= dim
return lidx

def __getitem__(self, index):
t = DataType(self._content_type)
index = self._linear_index(index)
if t.lanes > 1:
base = index * t.lanes
index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
@@ -87,6 +112,7 @@ def __setitem__(self, index, value):
raise ValueError(
"data type does not match content type %s vs %s" % (value.dtype, self._content_type)
)
index = self._linear_index(index)
t = DataType(self._content_type)
if t.lanes > 1:
base = index * t.lanes
@@ -341,7 +367,7 @@ def allocate(self, dtype, shape, name="buf", scope=None):
if scope:
self.scope_attr(buffer_var, "storage_scope", scope)
self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x))
return BufferVar(self, buffer_var, dtype)
return BufferVar(self, buffer_var, shape, dtype)

def pointer(self, content_type, name="ptr"):
"""Create pointer variable with content type.
@@ -360,22 +386,25 @@ def pointer(self, content_type, name="ptr"):
The buffer var representing the buffer.
"""
buffer_var = _expr.Var(name, dtype="handle")
return BufferVar(self, buffer_var, content_type)
return BufferVar(self, buffer_var, None, content_type)

def buffer_ptr(self, buf):
def buffer_ptr(self, buf, shape=None):
"""Create pointer variable corresponds to buffer ptr.
Parameters
----------
buf : Buffer
The buffer to be extracted.
shape : Tuple
Optional shape of the buffer. Overrides existing buffer shape.
Returns
-------
ptr : BufferVar
The buffer var representing the buffer.
"""
return BufferVar(self, buf.data, buf.dtype)
return BufferVar(self, buf.data, buf.shape if shape is None else shape, buf.dtype)

def likely(self, expr):
"""Add likely tag for expression.
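The row-major linearization introduced by _linear_index above can be checked in isolation: on a (32, 32) buffer, the docstring's y[2, 31] maps to 2 * 32 + 31 = 95, the same element as y[(2*32) + 31]. A standalone sketch of the same arithmetic (plain Python, not tied to the IRBuilder):

def linear_index(index, shape):
    # Row-major linearization, mirroring BufferVar._linear_index.
    assert len(index) == len(shape)
    lidx, dim_size = 0, 1
    for dim, idx in zip(reversed(shape), reversed(index)):
        lidx += idx * dim_size
        dim_size *= dim
    return lidx

assert linear_index((2, 31), (32, 32)) == 2 * 32 + 31             # == 95
assert linear_index((1, 2, 3), (4, 5, 6)) == (1 * 5 + 2) * 6 + 3  # == 45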
(Diffs for the remaining 11 changed files are not shown here.)