From 839a9b7226822c4a9d66f43c3052c6064160889b Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Fri, 9 Oct 2020 16:15:26 -0700 Subject: [PATCH] Faster sparse_dense on GPUs (#6580) * Faster sparse_dense on GPUs. This new sparse_dense requires a padded matrix, so a new op `sparse_dense_padded` has been added. AlterOpLayout should transform `sparse_dense` to `sparse_dense_padded` when possible on the gpu. * formatting * more formatting * Check that alteroplayout is definedbefore using it * check if FTVMAlterOpLayout exists before using it * formatting * restore message passing * Fix sparse_dense and sparse_dense_padded docs * Fix old sparse_dense, autotvm and sparse_dense dont play well together * Remove unused imports * clarify warp count in cuda_transpose * Document multidimensional access * Warn users not to use sparse_dense_padded * rename nn.sparse_dense_padded to nn.internal.sparse_dense_padded --- python/tvm/relay/op/nn/_nn.py | 16 + python/tvm/relay/op/nn/nn.py | 6 +- python/tvm/relay/op/strategy/cuda.py | 13 + python/tvm/relay/op/strategy/generic.py | 6 + python/tvm/tir/ir_builder.py | 39 ++- python/tvm/topi/cuda/sparse.py | 310 ++++++++++++++++++- python/tvm/topi/nn/sparse.py | 25 ++ src/relay/op/nn/sparse.cc | 36 ++- src/relay/transforms/transform_layout.h | 12 + src/target/source/codegen_cuda.cc | 3 +- src/te/operation/compute_op.cc | 4 +- src/te/operation/op_util.cc | 3 +- src/te/schedule/schedule_lang.cc | 2 +- src/tir/transforms/lower_warp_memory.cc | 6 +- src/tir/transforms/storage_access.cc | 2 +- tests/python/topi/python/test_topi_sparse.py | 114 +++++-- 16 files changed, 537 insertions(+), 60 deletions(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index c83f6a943a31..9e47dc0a17f1 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -75,6 +75,22 @@ def compute_sparse_dense(attrs, inputs, out_type): reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) +@reg.register_alter_op_layout("nn.sparse_dense") +def alter_op_layout_sparse_dense(attrs, inputs, tinfos, out_type): + """Alternate the layout of sparse_dense""" + return topi.nn.sparse_dense_alter_layout(attrs, inputs, tinfos, out_type) + + +@reg.register_compute("nn.internal.sparse_dense_padded") +def compute_sparse_dense_padded(attrs, inputs, out_type): + """Compute definition of sparse_dense_padded""" + raise NotImplementedError("nn.internal.sparse_dense_padded is only available on cuda") + + +reg.register_strategy("nn.internal.sparse_dense_padded", strategy.sparse_dense_padded_strategy) +reg.register_pattern("nn.internal.sparse_dense_padded", reg.OpPattern.OUT_ELEMWISE_FUSABLE) + + # sparse_transpose @reg.register_compute("nn.sparse_transpose") def compute_sparse_transpose(attrs, inputs, out_type): diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 86a76ff28fa5..1aad4e7125fd 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -2016,7 +2016,7 @@ def sparse_dense(data, weight): data : tvm.relay.Expr The input data for the matrix multiplication - weight : namedtuple. + weight : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]]. The sparse weight matrix for the matrix multiplication. Returns @@ -2024,7 +2024,9 @@ def sparse_dense(data, weight): result: tvm.relay.Expr The computed result. 
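Illustration (not part of this patch): with this change `relay.nn.sparse_dense` accepts the weight either as an object exposing `.data`/`.indices`/`.indptr` or as a plain 3-tuple. A minimal sketch of the tuple form, assuming a dense array `x_np` and a scipy BSR matrix `w_sp`:

    import numpy as np
    import scipy.sparse as sp
    import tvm
    from tvm import relay

    x_np = np.random.randn(128, 128).astype("float32")
    w_sp = sp.random(64, 128, density=0.1, format="bsr", dtype="float32")
    # weight given as a plain (data, indices, indptr) tuple of relay Constants
    out = relay.nn.sparse_dense(
        relay.Constant(tvm.nd.array(x_np)),
        (
            relay.Constant(tvm.nd.array(w_sp.data)),
            relay.Constant(tvm.nd.array(w_sp.indices)),
            relay.Constant(tvm.nd.array(w_sp.indptr)),
        ),
    )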
""" - return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr) + if hasattr(weight, "indices"): + return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr) + return _make.sparse_dense(data, weight[0], weight[1], weight[2]) def sparse_transpose(x): diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index baa03f49924d..7031365251aa 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -633,6 +633,19 @@ def sparse_dense_strategy_cuda(attrs, inputs, out_type, target): return strategy +@sparse_dense_padded_strategy.register(["cuda", "gpu"]) +def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target): + """sparse dense cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_sparse_dense(topi.cuda.sparse_dense_padded), + wrap_topi_schedule(topi.cuda.schedule_sparse_dense_padded), + name="sparse_dense_padded.cuda", + plevel=10, + ) + return strategy + + @argsort_strategy.register(["cuda", "gpu"]) def argsort_strategy_cuda(attrs, inputs, out_type, target): """argsort cuda strategy""" diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 56ae97652b79..0f9971012f3c 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -724,6 +724,12 @@ def sparse_dense_strategy(attrs, inputs, out_type, target): return strategy +@override_native_generic_func("sparse_dense_padded_strategy") +def sparse_dense_padded_strategy(attrs, inputs, out_type, target): + """sparse dense padded generic strategy""" + raise NotImplementedError("sparse_dense_padded is only implemented for cuda") + + # sparse_transpose @generic_func def schedule_sparse_transpose(attrs, outs, target): diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 8b999bf2bf52..77fe79b327b6 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -42,6 +42,9 @@ class BufferVar(ObjectGeneric): Do not create it directly, create use IRBuilder. + BufferVars support array access either via a linear index, or, if given a + shape, via a multidimensional index. + Examples -------- In the follow example, x is BufferVar. @@ -55,6 +58,12 @@ class BufferVar(ObjectGeneric): x = ib.pointer("float32") x[0] = x[10] + 1 + y = ib.allocate("float32", (32, 32)) + # Array access using a linear index + y[(2*32) + 31] = 0. + # The same array access using a multidimensional index + y[2, 31] = 0. 
+ See Also -------- IRBuilder.pointer @@ -62,9 +71,10 @@ class BufferVar(ObjectGeneric): IRBuilder.allocate """ - def __init__(self, builder, buffer_var, content_type): + def __init__(self, builder, buffer_var, shape, content_type): self._builder = builder self._buffer_var = buffer_var + self._shape = shape self._content_type = content_type def asobject(self): @@ -74,8 +84,23 @@ def asobject(self): def dtype(self): return self._content_type + def _linear_index(self, index): + if not isinstance(index, tuple) or self._shape is None: + return index + assert len(index) == len(self._shape), "Index size (%s) does not match shape size (%s)" % ( + len(index), + len(self._shape), + ) + dim_size = 1 + lidx = 0 + for dim, idx in zip(reversed(self._shape), reversed(index)): + lidx += idx * dim_size + dim_size *= dim + return lidx + def __getitem__(self, index): t = DataType(self._content_type) + index = self._linear_index(index) if t.lanes > 1: base = index * t.lanes index = _expr.Ramp(base, const(1, base.dtype), t.lanes) @@ -87,6 +112,7 @@ def __setitem__(self, index, value): raise ValueError( "data type does not match content type %s vs %s" % (value.dtype, self._content_type) ) + index = self._linear_index(index) t = DataType(self._content_type) if t.lanes > 1: base = index * t.lanes @@ -341,7 +367,7 @@ def allocate(self, dtype, shape, name="buf", scope=None): if scope: self.scope_attr(buffer_var, "storage_scope", scope) self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x)) - return BufferVar(self, buffer_var, dtype) + return BufferVar(self, buffer_var, shape, dtype) def pointer(self, content_type, name="ptr"): """Create pointer variable with content type. @@ -360,9 +386,9 @@ def pointer(self, content_type, name="ptr"): The buffer var representing the buffer. """ buffer_var = _expr.Var(name, dtype="handle") - return BufferVar(self, buffer_var, content_type) + return BufferVar(self, buffer_var, None, content_type) - def buffer_ptr(self, buf): + def buffer_ptr(self, buf, shape=None): """Create pointer variable corresponds to buffer ptr. Parameters @@ -370,12 +396,15 @@ def buffer_ptr(self, buf): buf : Buffer The buffer to be extracted. + shape : Tuple + Optional shape of the buffer. Overrides existing buffer shape. + Returns ------- ptr : BufferVar The buffer var representing the buffer. """ - return BufferVar(self, buf.data, buf.dtype) + return BufferVar(self, buf.data, buf.shape if shape is None else shape, buf.dtype) def likely(self, expr): """Add likely tag for expression. diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py index d1d31a6d004d..3fd6fbebc62f 100644 --- a/python/tvm/topi/cuda/sparse.py +++ b/python/tvm/topi/cuda/sparse.py @@ -16,15 +16,17 @@ # under the License. """Sparse operators""" -from tvm import te -from tvm import autotvm -from tvm.autotvm.task.space import SplitEntity -from ..util import traverse_inline +import numpy as np +import scipy.sparse as sp + +import tvm +from tvm import relay, te + from .. 
import nn +from ..util import traverse_inline -@autotvm.register_topi_compute("sparse_dense.cuda") -def sparse_dense(cfg, data, weight_data, weight_indices, weight_indptr): +def sparse_dense(data, weight_data, weight_indices, weight_indptr): """ Computes sparse-dense matrix multiplication of `data` and `(weight_data, weight_indices, weight_indptr).T` @@ -58,8 +60,7 @@ def sparse_dense(cfg, data, weight_data, weight_indices, weight_indptr): return nn.sparse_dense(data, weight_data, weight_indices, weight_indptr) -@autotvm.register_topi_schedule("sparse_dense.cuda") -def schedule_sparse_dense(cfg, outs): +def schedule_sparse_dense(outs): """Create schedule for sparse dense""" # pylint:disable=invalid-name s = te.create_schedule([x.op for x in outs]) @@ -83,12 +84,7 @@ def _callback(op): thread_x = te.thread_axis("threadIdx.x") - cfg.define_split("tile_c", c, num_outputs=2) - if cfg.is_fallback: - cfg["tile_c"] = SplitEntity([-1, 8]) - _, ci = cfg["tile_c"].apply(s, y_bsrmm, c) - - y_bsrmm_factored = s.rfactor(y_bsrmm, ci) + y_bsrmm_factored = s.rfactor(y_bsrmm, c) tx = s[y_bsrmm].op.reduce_axis[0] s[y_bsrmm].bind(tx, thread_x) s[y_bsrmm_factored].compute_at(s[y_bsrmm], tx) @@ -97,3 +93,289 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s + + +def schedule_cuda_transpose(s, out): + """Schedule for transpose on the gpu. + + Roughly follows this: + https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/, but + without the padding for shared memory. For better performance, we could + rewrite it in tir to add the padding. + """ + + def _callback(op): + # pylint: disable=invalid-name + m, n = s[op].op.axis + warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size) + no, ni = s[op].split(n, factor=warp_size) + mo, mi = s[op].split(m, factor=warp_size) + s[op].reorder(mo, no, mi, ni) + s[op].bind(mo, te.thread_axis("blockIdx.x")) + s[op].bind(no, te.thread_axis("blockIdx.y")) + c = s.cache_read(op.input_tensors[0], "shared", op) + s[c].compute_at(s[op], no) + thread_x = te.thread_axis("threadIdx.x") + thread_y = te.thread_axis("threadIdx.y") + s[op].bind(ni, thread_x) + # This is a hack to make the scheduling language realize that this axis + # can be scheduled. + a, _ = s[c].split(s[c].op.axis[1], factor=1) + s[c].bind(a, thread_x) + # Use 4 warps per block. Slightly faster than 1 warp per block + ao, _ = s[op].split(mi, nparts=4) + s[op].bind(ao, thread_y) + ao, _ = s[c].split(s[c].op.axis[0], nparts=4) + s[c].bind(ao, thread_y) + + traverse_inline(s, out.op, _callback) + + +def sparse_dense_tir(data, w_data, w_indices, w_indptr): + """Compute data * w^T. + + Actually computes (w * data^T) ^ T as data needs to be in column-major + format for performance reasons. + + Good resources: + Yang, Carl, Aydın Buluç, and John D. Owens. "Design principles for sparse + matrix multiplication on the GPU." European Conference on Parallel + Processing. Springer, Cham, 2018. <- This code is basically row-split from here. + Gale, Trevor, et al. "Sparse GPU Kernels for Deep Learning." arXiv preprint + arXiv:2006.10901 (2020). + + + Profile with + `/opt/nvidia/nsight-compute/2020.1.2/ncu -k default_function_kernel1 + --section '.*' -s 1 -c 1 venv/bin/python3 test_topi_sparse.py manual` + with either default_function_kernel0 for the transpose or + default_function_kernel1 for the multiply. 
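Illustrative note (not part of this patch): in dense terms the kernel produces the same result as this NumPy/SciPy sketch, which is also how the unit tests below construct their reference output:

    import numpy as np

    def reference_sparse_dense(data, w_bsr):
        # data: (M, K) dense ndarray; w_bsr: (N, K) scipy.sparse.bsr_matrix
        # the kernel computes (w_bsr * data^T)^T, i.e. data * w_bsr^T
        return np.asarray((w_bsr @ data.T).T)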
+ """ + + def ceil_div(a, b): + return (a + (b - 1)) // b + + def gen_ir(data, w_data, w_indices, w_indptr, out): + # pylint: disable=invalid-name + # TODO(tkonolige): use tensorcores for block multiply + # TODO(tkonolige): use vectorize on loads + # TODO(tkonolige): seperate implementation if M is small + # TODO(tkonolige): seperate implementation for large block sizes + ib = tvm.tir.ir_builder.create() + + warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size) + m = data.shape[1] + nb = w_indptr.shape[0] - 1 + nnzb = w_data.shape[0] + # treat csr like block size 1 bsr + if len(w_data.shape) == 1: + bs_n = 1 + bs_k = 1 + else: + bs_n = w_data.shape[1] + bs_k = w_data.shape[2] + bs_m = bs_n + mb = m // bs_m + mi = warp_size + assert ( + mb >= mi + ), "Number of block rows in dense matrix must be larger than warp size: {} vs {}.".format( + warp_size, m + ) + mo = ceil_div(mb, mi) + ni = 1 # TODO(tkonolige): how do I compute the number of warps per block? + no = ceil_div(nb, ni) + rowlength_bi = warp_size + + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", mo) + by = te.thread_axis("blockIdx.y") + ib.scope_attr(by, "thread_extent", no) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", warp_size) + warp = te.thread_axis("threadIdx.y") + ib.scope_attr(warp, "thread_extent", ni) + + out_ptr = ib.buffer_ptr(out) + data_ptr = ib.buffer_ptr(data) + w_data_ptr = ib.buffer_ptr(w_data, shape=(nnzb, bs_n, bs_k)) + w_indices_ptr = ib.buffer_ptr(w_indices) + w_indptr_ptr = ib.buffer_ptr(w_indptr) + + n_index = by * ni + warp + m_index = bx * mi + tx + row_start = w_indptr_ptr[n_index] + + # Guaranteed to be evenly divisible + rowlength_bo = ceil_div(w_indptr_ptr[n_index + 1] - row_start, rowlength_bi) + + # thread local storage for bs_m x bs_n block + block = ib.allocate(data.dtype, (bs_m, bs_n), name="block", scope="local") + indices = ib.allocate(w_indices.dtype, (rowlength_bi,), name="indices", scope="warp") + data_cache = ib.allocate(data.dtype, (mi, bs_m, bs_k), name="data_cache", scope="local") + w_data_cache = ib.allocate( + w_data.dtype, (rowlength_bi, bs_n, bs_k), name="w_data_cache", scope="warp" + ) + + # zero block + with ib.for_range(0, bs_m, name="x", for_type="unroll") as x: + with ib.for_range(0, bs_n, name="y", for_type="unroll") as y: + block[x, y] = 0.0 + # compute into thread local storage using warp_size chunks + with ib.for_range(0, rowlength_bo, name="bb") as bb: + elem_idx = bb * rowlength_bi + tx + # Cache indices. Guaranteed to be multiple of warp_size. + indices[elem_idx] = w_indices_ptr[row_start + elem_idx] + # cache dense matrix + # each thread has a row + # TODO: ideally we could vectorize this + with ib.for_range(0, rowlength_bi, name="bi") as bi: + with ib.for_range(0, bs_m, name="x", for_type="unroll") as x: + with ib.for_range(0, bs_k, name="z", for_type="unroll") as z: + # This memory acces should be out of bounds when + # m_index >= mb (which occurs when the dense matrix + # rows % 32 != 0), but it seems to work just fine... 
+ data_cache[bi, x, z] = data_ptr[indices[bi] * bs_k + z, m_index * bs_m + x] + # cache w_data + elem_idx = bb * rowlength_bi + tx + with ib.for_range(0, bs_n, name="y", for_type="unroll") as y: + with ib.for_range(0, bs_k, name="z", for_type="unroll") as z: + w_data_cache[tx, y, z] = w_data_ptr[row_start + elem_idx, y, z] + with ib.for_range(0, mi, name="i") as i: + # thread local block matmul + with ib.for_range(0, bs_m, name="x", for_type="unroll") as x: + with ib.for_range(0, bs_n, name="y", for_type="unroll") as y: + with ib.for_range(0, bs_k, name="z", for_type="unroll") as z: + block[x, y] += data_cache[i, x, z] * w_data_cache[i, y, z] + # store results + with ib.for_range(0, bs_m, name="x", for_type="unroll") as x: + with ib.for_range(0, bs_n, name="y", for_type="unroll") as y: + with ib.if_scope(m_index < mb): + with ib.if_scope(n_index < nb): + # It doesn't seem like we would be getting coelesced + # writes here, but it doesn't seem to matter + out_ptr[m_index * bs_m + x, n_index * bs_n + y] = block[x, y] + + return ib.get() + + data_t = tvm.topi.transpose(data) + # handle csr + if len(w_data.shape) == 1: + blocksize = 1 + else: + blocksize = w_data.shape[1] + out_shape = (data_t.shape[1], (w_indptr.shape[0] - 1) * blocksize) + out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf") + out = te.extern( + [out_shape], + [data_t, w_data, w_indices, w_indptr, data], + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], ins[3], outs[0]), + dtype=data.dtype, + out_buffers=[out_buf], + name="sparse_dense_gpu", + tag="sparse_dense_gpu", + ) + return out + + +def sparse_dense_padded(data, weight_data, weight_indices, weight_indptr): + """ + Computes sparse-dense matrix multiplication of `data` and + `(weight_data, weight_indices, weight_indptr).T` + + This variation uses a padded matrix where all row lengths are a multiple of the warp size. + + Parameters + ---------- + cfg: ConfigEntity + The config for this template + + data : tvm.te.Tensor + 2-D with shape [M, K], float32 + + weight_data : tvm.te.Tensor + 1-D with shape [nnz] (CSR) or + 3-D with shape [num_blocks, bs_r, bs_c] (BSR) + + weight_indices : tvm.te.Tensor + 1-D with shape [nnz] (CSR) or + 1-D with shape [num_blocks] (BSR) + + weight_indptr : tvm.te.Tensor + 1-D with shape [N + 1] (CSR) or + 1-D with shape [(N + 1) // bs_r] (BSR) + + Returns + ------- + output : tvm.te.Tensor + 2-D with shape [M, N] + """ + return sparse_dense_tir(data, weight_data, weight_indices, weight_indptr) + + +def schedule_sparse_dense_padded(outs): + """Create schedule for sparse dense""" + # XXX: this will fail if we don't include the data_t Tensor in the schedule + # ops. 
Maybe create_schedule should do some analysis so this isn't + # necessary + data_t = outs[0].op.input_tensors[0] + s = te.create_schedule([outs[0].op, data_t.op]) + schedule_cuda_transpose(s, outs[0].op.input_tensors[0]) + return s + + +def pad_sparse_matrix(matrix, blocksize): + """Pad rows of sparse matrix matrix so that they are a multiple of blocksize.""" + assert isinstance(matrix, sp.bsr_matrix) + new_entries = np.zeros(matrix.shape[0], dtype=matrix.indptr.dtype) + bsr = matrix.blocksize[0] + for i in range(matrix.shape[0] // bsr): + row_length = matrix.indptr[i + 1] - matrix.indptr[i] + if row_length % blocksize != 0: + new_entries[i] = blocksize - (row_length % blocksize) + additional = np.sum(new_entries) + indices = np.zeros(matrix.indices.shape[0] + additional, dtype=matrix.indices.dtype) + data = np.zeros( + (matrix.data.shape[0] + additional, matrix.data.shape[1], matrix.data.shape[2]), + dtype=matrix.data.dtype, + ) + + n = matrix.shape[0] // bsr + indptr = np.zeros(n + 1, dtype=matrix.indptr.dtype) + indptr[: matrix.indptr.shape[0]] = matrix.indptr + + for i in range(matrix.shape[0] // bsr): + indptr[i + 1] = indptr[i] + new_entries[i] + (matrix.indptr[i + 1] - matrix.indptr[i]) + indices[indptr[i] : indptr[i + 1] - new_entries[i]] = matrix.indices[ + matrix.indptr[i] : matrix.indptr[i + 1] + ] + data[indptr[i] : indptr[i + 1] - new_entries[i], :, :] = matrix.data[ + matrix.indptr[i] : matrix.indptr[i + 1], :, : + ] + + return sp.bsr_matrix((data, indices, indptr), matrix.shape) + + +@nn.sparse_dense_alter_layout.register(["cuda", "gpu"]) +def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type): + """With cuda, we modify use alter_op_layout to swap the default + sparse_dense implementation for one that operates on a padded matrix. We + also padd the matrix. + """ + if ( + isinstance(inputs[1], relay.Constant) + and isinstance(inputs[2], relay.Constant) + and isinstance(inputs[3], relay.Constant) + ): + sparse_matrix = sp.bsr_matrix( + (inputs[1].data.asnumpy(), inputs[2].data.asnumpy(), inputs[3].data.asnumpy()) + ) + warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size) + sparse_matrix = pad_sparse_matrix(sparse_matrix, warp_size) + return relay.nn._make.sparse_dense_padded( + inputs[0], + relay.Constant(tvm.nd.array(sparse_matrix.data)), + relay.Constant(tvm.nd.array(sparse_matrix.indices)), + relay.Constant(tvm.nd.array(sparse_matrix.indptr)), + ) + return None diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index e3c144adb768..74a9ad5fd650 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -207,3 +207,28 @@ def _csr_transpose_ir(data, indices, indptr, out_data, out_indices, out_indptr): last[0] = temp2[0] return irb.get() + + +@tvm.target.generic_func +def sparse_dense_alter_layout(_attrs, _inputs, _tinfos, _out_type): + """Change Sparse Dense layout. + + This is used for modifying the inputs weights so they are more amenable for + the target. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current convolution + inputs : tvm.relay.Expr + Grouped input symbols + tinfos : list + Input shape and dtype + out_type: type + The output type + + Note + ---- + Unlike other TOPI functions, this function operates on both graph level and operator level. 
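Illustrative note (not part of this patch): backends override this generic function via registration, as the CUDA implementation above does with `@nn.sparse_dense_alter_layout.register(["cuda", "gpu"])`. A minimal sketch for a hypothetical target key (shown only as a pattern, not a registration added by this patch):

    from tvm import topi

    @topi.nn.sparse_dense_alter_layout.register(["rocm"])
    def _alter_sparse_dense_layout_rocm(attrs, inputs, tinfos, out_type):
        # returning None keeps the original nn.sparse_dense call unchanged
        return None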
+ """ + return None diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index 0aca00ce80a4..f12afe2a7f1f 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -76,7 +76,41 @@ TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense") }); RELAY_REGISTER_OP("nn.sparse_dense") - .describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with X sparse. + .describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with W sparse. + +- **data**: `(x1, x2, ..., xn, input_dim)` +- **weight**: `(units, input_dim)` +- **out**: `(x1, x2, ..., xn, units)`. + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(4) + .add_argument("data", "nD Tensor", "Input data.") + .add_argument("weight_data", "1D Tensor", "Weight data matrix.") + .add_argument("weight_indices", "1D Tensor", "Weight indices matrix.") + .add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.") + .set_support_level(1) + .add_type_rel("SparseDense", SparseDenseRel); + +Expr MakeSparseDensePadded(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr) { + auto attrs = make_object(); + static const Op& op = Op::Get("nn.internal.sparse_dense_padded"); + return Call(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense_padded") + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeSparseDensePadded, args, rv); + }); + +RELAY_REGISTER_OP("nn.internal.sparse_dense_padded") + .describe( + R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with W +sparse. This variation uses a matrix with row lengths padded to a +multiple of 32 for better GPU performance. + +This op should not be directly used by a user. Instead, use `sparse_dense` +which will be converted to this op when running on the GPU. - **data**: `(x1, x2, ..., xn, input_dim)` - **weight**: `(units, input_dim)` diff --git a/src/relay/transforms/transform_layout.h b/src/relay/transforms/transform_layout.h index bf9bcb9a569e..61a74404afd1 100644 --- a/src/relay/transforms/transform_layout.h +++ b/src/relay/transforms/transform_layout.h @@ -267,6 +267,18 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Obj } } + // If there is no FInferCorrectLayout for the type, then we just assume the layout is correct. + static auto finfer_layout = Op::GetAttrMap("FInferCorrectLayout"); + if (Op::HasAttrMap("FTVMAlterOpLayout")) { + static auto falter_layout = Op::GetAttrMap("FTVMAlterOpLayout"); + if (ref_call->op.as()) { + Op op = Downcast(ref_call->op); + if (falter_layout.count(op) && !finfer_layout.count(op)) { + return memorizer.CallWithNewLayouts(ref_call, normal_new_args); + } + } + } + // old_in, new_in = state[inputs] Array old_in, old_out, new_in, new_out, new_in2; for (auto inp : inputs) { diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 7dc63d4ac949..d57efa007272 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -394,7 +394,8 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) { } void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) - CHECK_NE(scope, "global"); + CHECK_NE(scope, "global") << "Cannot allocate global memory when targeting CUDA. 
You must pass " + "all global arrays as input instead"; if (scope == "shared") { os << "__shared__ "; } diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index c3b2a0b44eb1..527b251867ad 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -46,7 +46,9 @@ using namespace tir; TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); - p->stream << "compute(" << op->name << ", " << op << ")"; + p->stream << "compute(" << op->name << ", body=" << op->body << ", axis=" << op->axis + << ", reduce_axis=" << op->reduce_axis << ", tag=" << op->tag + << ", attrs=" << op->attrs << ")"; }); TVM_REGISTER_NODE_TYPE(ComputeOpNode); diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc index db649e541a65..2abf68a71d54 100644 --- a/src/te/operation/op_util.cc +++ b/src/te/operation/op_util.cc @@ -150,7 +150,8 @@ std::vector > MakeLoopNest(const Stage& stage, value_map[iv] = dom->min; } else { // Always restrict threaded IterVar to starts from 0. - CHECK(is_zero(dom->min)); + CHECK(is_zero(dom->min)) << "Itervar " << iv << " must start at zero, but it starts at " + << dom->min; // annotate the extent of the IterVar nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent, dom->extent, no_op)); if (!debug_keep_trivial_loop && is_one(dom->extent)) { diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index d6327ffe0f08..a8257c07a473 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -761,7 +761,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); if (op->op.defined()) { - p->stream << "stage(" << op->origin_op->name << ", " << op << ")"; + p->stream << "stage(" << op->origin_op->name << ", " << op->op << ")"; } else { p->stream << "group-stage(" << op << ")"; } diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 8892c322acd8..cb6c609ef657 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -129,7 +129,11 @@ class WarpStoreCoeffFinder : private StmtVisitor { void UpdatePattern(const PrimExpr& index) { Array m = arith::DetectLinearEquation(index, {warp_index_}); - CHECK_EQ(m.size(), 2U) << "LowerWarpMemory failed due to store index=" << index; + CHECK_EQ(m.size(), 2U) + << "LowerWarpMemory failed. Could not simplify the store index `" << index + << "` into the form ax + by + cz + ... Warp memory is approximated by storing values in " + "thread local registers and shuffling values between these registers. 
Currently only " + "linear equation indices are supported."; PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]); const auto* mcoeff_as_int = mcoeff.as(); CHECK(mcoeff_as_int && mcoeff_as_int->value > 0) diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 1914609b348d..f9adfb82a33f 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -37,7 +37,7 @@ void StorageAccessVisitor::VisitExpr_(const LoadNode* op) { const VarNode* buf = op->buffer_var.as(); StorageScope scope = GetScope(buf); if (Enabled(buf, scope)) { - CHECK(allow_append_); + CHECK(allow_append_) << op << " " << scope.to_string(); AccessEntry e; e.threads = env_threads(); e.buffer = op->buffer_var; diff --git a/tests/python/topi/python/test_topi_sparse.py b/tests/python/topi/python/test_topi_sparse.py index b50110ab768e..07af478a5087 100644 --- a/tests/python/topi/python/test_topi_sparse.py +++ b/tests/python/topi/python/test_topi_sparse.py @@ -19,6 +19,7 @@ import tvm from tvm import te from tvm import topi +from tvm import relay import tvm.topi.testing from tvm.topi.util import get_const_tuple import tvm.contrib.sparse as tvmsp @@ -329,11 +330,11 @@ def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype): return s -def verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu): +def verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu, ctx, target): X_np = np.random.randn(M, K).astype("float32") W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32") W_np = W_sp_np.todense() - Y_np = X_np.dot(W_np.T) + Y_np = X_np @ W_np.T if use_relu: Y_np = np.maximum(Y_np, 0.0) @@ -342,38 +343,29 @@ def verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu): W_indptr = te.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype)) X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype)) - def check_device(device): - ctx = tvm.context(device, 0) - if not tvm.testing.device_enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - fcompute, fschedule = tvm.topi.testing.dispatch(device, _sparse_dense_implement) - with tvm.target.Target(device): - Y = fcompute(X, W_data, W_indices, W_indptr) - if use_relu: - Y = topi.nn.relu(Y) - s = fschedule([Y]) - func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) - Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx) - func( - tvm.nd.array(X_np, ctx=ctx), - tvm.nd.array(W_sp_np.data, ctx=ctx), - tvm.nd.array(W_sp_np.indices, ctx=ctx), - tvm.nd.array(W_sp_np.indptr, ctx=ctx), - Y_tvm, - ) - tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4) - - for device in ["llvm", "cuda"]: - check_device(device) + fcompute, fschedule = tvm.topi.testing.dispatch(target, _sparse_dense_implement) + with tvm.target.Target(target): + Y = fcompute(X, W_data, W_indices, W_indptr) + if use_relu: + Y = topi.nn.relu(Y) + s = fschedule([Y]) + func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) + Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx) + func( + tvm.nd.array(X_np, ctx=ctx), + tvm.nd.array(W_sp_np.data, ctx=ctx), + tvm.nd.array(W_sp_np.indices, ctx=ctx), + tvm.nd.array(W_sp_np.indptr, ctx=ctx), + Y_tvm, + ) + tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4) -@tvm.testing.uses_gpu -def test_sparse_dense_bsr(): +@tvm.testing.parametrize_targets("llvm", "cuda") +def test_sparse_dense_bsr_relu(ctx, target): M, N, K, BS_R, 
BS_C, density = 1, 64, 128, 8, 16, 0.9 - verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu=True) - verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu=False) + verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, True, ctx, target) + verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, False, ctx, target) @tvm.testing.uses_gpu @@ -421,11 +413,69 @@ def check_device(device): check_device(device) +@tvm.testing.requires_cuda +def test_sparse_dense_padded_cuda(): + M = 128 + N = 1280 + K = 128 + X_np = np.random.randn(M, K).astype("float32") + W_sp_np = random_bsr_matrix(N, K, 1, 1, density=0.01, dtype="float32") + W_sp_np_padded = tvm.topi.cuda.pad_sparse_matrix(W_sp_np, 32) + + W_np = W_sp_np.todense() + Y_np = X_np @ W_sp_np.T + + W_data = te.placeholder(shape=W_sp_np_padded.data.shape, dtype=str(W_sp_np_padded.data.dtype)) + W_indices = te.placeholder( + shape=W_sp_np_padded.indices.shape, dtype=str(W_sp_np_padded.indices.dtype) + ) + W_indptr = te.placeholder( + shape=W_sp_np_padded.indptr.shape, dtype=str(W_sp_np_padded.indptr.dtype) + ) + X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype)) + with tvm.target.Target("cuda"): + ctx = tvm.context("gpu") + Y = topi.cuda.sparse_dense_padded(X, W_data, W_indices, W_indptr) + s = topi.cuda.schedule_sparse_dense_padded([Y]) + func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) + Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx) + func( + tvm.nd.array(X_np, ctx=ctx), + tvm.nd.array(W_sp_np_padded.data, ctx=ctx), + tvm.nd.array(W_sp_np_padded.indices, ctx=ctx), + tvm.nd.array(W_sp_np_padded.indptr, ctx=ctx), + Y_tvm, + ) + tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5) + + +@tvm.testing.requires_cuda +def test_sparse_dense_padded_alter_op(): + with tvm.target.Target("cuda"): + M = 128 + N = 16 + K = 128 + X_np = np.random.randn(M, K).astype("float32") + W_sp_np = random_bsr_matrix(N, K, 2, 2, density=0.01, dtype="float32") + mult = relay.op.nn.sparse_dense( + relay.Constant(tvm.nd.array(X_np)), + ( + relay.Constant(tvm.nd.array(W_sp_np.data)), + relay.Constant(tvm.nd.array(W_sp_np.indices)), + relay.Constant(tvm.nd.array(W_sp_np.indptr)), + ), + ) + f = relay.Function([], mult) + f_ = relay.transform.AlterOpLayout()(tvm.IRModule.from_expr(f)) + assert f_["main"].body.op.name == "nn.internal.sparse_dense_padded" + + if __name__ == "__main__": test_csrmv() test_csrmm() test_dense() test_sparse_dense_csr() - test_sparse_dense_bsr() test_sparse_dense_bsr_randomized() test_sparse_transpose_csr() + test_sparse_dense_padded_cuda() + test_sparse_dense_padded_alter_op()
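
Usage sketch (not part of this patch): end users keep calling `nn.sparse_dense`; when the sparse weights are relay Constants and the target is CUDA, AlterOpLayout (run by `relay.build` at opt_level 3) swaps in `nn.internal.sparse_dense_padded`, as exercised by `test_sparse_dense_padded_alter_op` above. Assuming `f` is the relay Function built in that test:

    import tvm
    from tvm import relay

    def build_for_cuda(func):
        mod = tvm.IRModule.from_expr(func)
        with tvm.transform.PassContext(opt_level=3):
            # AlterOpLayout rewrites nn.sparse_dense -> nn.internal.sparse_dense_padded
            return relay.build(mod, target="cuda")

    # lib = build_for_cuda(f)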