From 64f5d50d804a0b6f451c5202a97cf371f2869450 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 3 Apr 2021 15:27:21 +0900 Subject: [PATCH 01/36] initial import --- python/tvm/topi/cuda/__init__.py | 2 +- python/tvm/topi/cuda/nms.py | 578 ++++++++++++++----- python/tvm/topi/cuda/sort.py | 27 +- python/tvm/topi/cuda/vision.py | 18 + tests/python/topi/python/test_topi_vision.py | 113 +++- 5 files changed, 581 insertions(+), 157 deletions(-) diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index c2f55668d2e2..4d838db8bfba 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -43,7 +43,7 @@ from .batch_matmul_tensorcore import * from .vision import * from .ssd import * -from .nms import get_valid_counts, non_max_suppression +from .nms import get_valid_counts, non_max_suppression, all_class_non_max_suppression from .rcnn import * from .scatter import * from .sort import * diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index c83dae0d3b96..6f1ee7c0e72d 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -25,6 +25,9 @@ from .sort import argsort, argsort_thrust from .scan import exclusive_scan from ..utils import ceil_div +from ..math import cast +from ..transform import reshape, expand_dims +from ..broadcast import greater def cuda_atomic_add_rule(op): @@ -265,6 +268,140 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): return [valid_count, out, out_indices] +def get_boundaries(output, box_idx): + l = tvm.te.min( + output[box_idx], + output[box_idx + 2], + ) + t = tvm.te.min( + output[box_idx + 1], + output[box_idx + 3], + ) + r = tvm.te.max( + output[box_idx], + output[box_idx + 2], + ) + b = tvm.te.max( + output[box_idx + 1], + output[box_idx + 3], + ) + return l, t, r, b + + +def calculate_overlap(out_tensor, box_a_idx, box_b_idx): + """Calculate overlap of two boxes.""" + a_l, a_t, a_r, a_b = get_boundaries(out_tensor, box_a_idx) + b_l, b_t, b_r, b_b = get_boundaries(out_tensor, box_b_idx) + + # Overlapping width and height + w = tvm.te.max(0.0, tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l)) + h = tvm.te.max(0.0, tvm.te.min(a_b, b_b) - tvm.te.max(a_t, b_t)) + + # Overlapping area + area = h * w + + # total area of the figure formed by box a and box b + # except for overlapping area + u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area + return tvm.tir.Select(u <= 0.0, 0.0, area / u) + + +def _nms_loop( + ib, + batch_size, + num_anchors, + top_k, + iou_threshold, + valid_count, + get_max_output_size_func, + on_new_valid_box_func, + on_new_invalidated_box_func, + needs_bbox_check_func, + calc_overlap_func, + out_scores, + num_valid_boxes, +): + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + + with ib.new_scope(): + nthread_by = batch_size + nthread_tx = max_threads + + # Some cuda architectures have smaller limit of 32K for cudaDevAttrMaxRegistersPerBlock + # vs 64K for most GPUs. Since this kernel uses many registers (around 35), the limit will + # be exceeded with 1024 threads. 
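# For reference, the IoU that calculate_overlap above encodes in TIR, written as
# plain Python over two boxes given as (x1, y1, x2, y2) with corners in either
# order. A minimal illustrative sketch only, not the kernel code itself.
def iou_reference(box_a, box_b):
    a_l, a_t, a_r, a_b = (min(box_a[0], box_a[2]), min(box_a[1], box_a[3]),
                          max(box_a[0], box_a[2]), max(box_a[1], box_a[3]))
    b_l, b_t, b_r, b_b = (min(box_b[0], box_b[2]), min(box_b[1], box_b[3]),
                          max(box_b[0], box_b[2]), max(box_b[1], box_b[3]))
    # overlapping width, height and area
    w = max(0.0, min(a_r, b_r) - max(a_l, b_l))
    h = max(0.0, min(a_b, b_b) - max(a_t, b_t))
    area = w * h
    # union area of the two boxes
    union = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area
    return 0.0 if union <= 0.0 else area / union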
+ target = tvm.target.Target.current(allow_none=False) + if target.kind.name == "cuda": + if nvcc.get_target_compute_version(target) in ["3.2", "5.3", "6.2"]: + nthread_tx = 512 + + by = te.thread_axis("blockIdx.y") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(tx, "thread_extent", nthread_tx) + + num_valid_boxes_local = ib.allocate( + "int32", (1,), name="num_valid_boxes_local", scope="local" + ) + num_valid_boxes_local[0] = 0 + + def nms_inner_loop(ib, i, j, nkeep): + # The box j is valid, invalidate other boxes that overlap with j above iou_threshold + on_new_valid_box_func(ib, tx, num_valid_boxes_local[0], i, j) + num_valid_boxes_local[0] += 1 + + num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx) + + with ib.for_range(0, num_iter_per_thread, name="_k") as _k: + k = j + 1 + _k * nthread_tx + tx + + with ib.if_scope( + tvm.tir.all( + k < nkeep, + out_scores[i, k] > 0, # is the box k still valid? + needs_bbox_check_func(i, j, k), + ) + ): + iou = calc_overlap_func(i, j, k) + + with ib.if_scope(iou >= iou_threshold): + # invalidate the box k + out_scores[i, k] = -1.0 + on_new_invalidated_box_func(i, k) + + ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) + + i = by + + max_output_size = get_max_output_size_func(i) + nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]) + + with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): + # Apply nms + with ib.if_scope(max_output_size > 0): + # No need to do more iteration if we have already reached max_output_size boxes + box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local") + box_idx[0] = 0 + with ib.while_loop( + tvm.tir.all(box_idx[0] < nkeep, num_valid_boxes_local[0] < max_output_size) + ): + # Proceed to the inner loop if the box with id box_idx is still valid + with ib.if_scope(out_scores[i, box_idx[0]] > -1.0): + nms_inner_loop(ib, i, box_idx[0], nkeep) + box_idx[0] += 1 + + with ib.else_scope(): + with ib.for_range(0, nkeep, name="j") as j: + # Proceed to the inner loop if the box j is still valid + with ib.if_scope(out_scores[i, j] > -1.0): + nms_inner_loop(ib, i, j, nkeep) + + with ib.if_scope(tx + 0 == 0): + num_valid_boxes[i] = num_valid_boxes_local[0] + + return ib.get() + + def nms_ir( data, sorted_index, @@ -352,43 +489,6 @@ def nms_ir( stmt : Stmt The result IR statement. 
""" - - def get_boundaries(output, box_idx): - l = tvm.te.min( - output[box_idx], - output[box_idx + 2], - ) - t = tvm.te.min( - output[box_idx + 1], - output[box_idx + 3], - ) - r = tvm.te.max( - output[box_idx], - output[box_idx + 2], - ) - b = tvm.te.max( - output[box_idx + 1], - output[box_idx + 3], - ) - return l, t, r, b - - def calculate_overlap(out_tensor, box_a_idx, box_b_idx): - """Calculate overlap of two boxes.""" - a_l, a_t, a_r, a_b = get_boundaries(out_tensor, box_a_idx) - b_l, b_t, b_r, b_b = get_boundaries(out_tensor, box_b_idx) - - # Overlapping width and height - w = tvm.te.max(0.0, tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l)) - h = tvm.te.max(0.0, tvm.te.min(a_b, b_b) - tvm.te.max(a_t, b_t)) - - # Overlapping area - area = h * w - - # total area of the figure formed by box a and box b - # except for overlapping area - u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area - return tvm.tir.Select(u <= 0.0, 0.0, area / u) - batch_size = data.shape[0] num_anchors = data.shape[1] box_data_length = data.shape[2] @@ -492,105 +592,52 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): box_indices[i * num_anchors + j] = j - with ib.new_scope(): - nthread_by = batch_size - nthread_tx = max_threads - - # Some cuda architectures have smaller limit of 32K for cudaDevAttrMaxRegistersPerBlock - # vs 64K for most GPUs. Since this kernel uses many registers (around 35), the limit will - # be exceeded with 1024 threads. - target = tvm.target.Target.current(allow_none=False) - if target.kind.name == "cuda": - if nvcc.get_target_compute_version(target) in ["3.2", "5.3", "6.2"]: - nthread_tx = 512 - - by = te.thread_axis("blockIdx.y") - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(by, "thread_extent", nthread_by) - ib.scope_attr(tx, "thread_extent", nthread_tx) - - i = by + if isinstance(max_output_size, int): + max_output_size = tvm.tir.const(max_output_size) + def calc_overlap(i, j, k): + offset_j = j * 4 + offset_k = k * 4 base_bbox_idx = i * num_anchors * 4 - num_valid_boxes_local = ib.allocate( - "int32", (1,), name="num_valid_boxes_local", scope="local" + return calculate_overlap( + out_bboxes, + base_bbox_idx + offset_j, + base_bbox_idx + offset_k, ) - num_valid_boxes_local[0] = 0 - nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]) - - def nms_inner_loop(ib, j): - # The box j is valid, invalidate other boxes that overlap with j above iou_threshold - - # When return_indices is False, no need to populate box_indices - if return_indices: - with ib.if_scope(tx + 0 == 0): - orig_idx = sorted_index[i * num_anchors + j] - box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx] - - num_valid_boxes_local[0] += 1 - - offset_j = j * 4 - num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx) - - with ib.for_range(0, num_iter_per_thread, name="_k") as _k: - k = j + 1 + _k * nthread_tx + tx - offset_k = k * 4 - - with ib.if_scope( - tvm.tir.all( - k < nkeep, - out_scores[i, k] > 0, # is the box k still valid? 
- tvm.tir.any( - force_suppress > 0, - id_index < 0, - out_class_ids[i, k] == out_class_ids[i, j], - ), - ) - ): - iou = calculate_overlap( - out_bboxes, - base_bbox_idx + offset_j, - base_bbox_idx + offset_k, - ) - with ib.if_scope(iou >= iou_threshold): - # invalidate the box k - out_scores[i, k] = -1.0 - if return_indices is False and id_index >= 0: - out_class_ids[i, k] = -1.0 - - ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) - - if isinstance(max_output_size, int): - max_output_size = tvm.tir.const(max_output_size) - - with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): - # Apply nms - with ib.if_scope(max_output_size > 0): - # No need to do more iteration if we have already reached max_output_size boxes - box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local") - box_idx[0] = 0 - with ib.while_loop( - tvm.tir.all(box_idx[0] < nkeep, num_valid_boxes_local[0] < max_output_size) - ): - # Proceed to the inner loop if the box with id box_idx is still valid - with ib.if_scope(out_scores[i, box_idx[0]] > -1.0): - nms_inner_loop(ib, box_idx[0]) - box_idx[0] += 1 - - with ib.else_scope(): - with ib.for_range(0, nkeep, name="j") as j: - # Proceed to the inner loop if the box j is still valid - with ib.if_scope(out_scores[i, j] > -1.0): - nms_inner_loop(ib, j) - - with ib.if_scope(tx + 0 == 0): - num_valid_boxes[i] = num_valid_boxes_local[0] - - with ib.else_scope(): - num_valid_boxes[i] = 0 + def on_new_valid_box(ib, tid, num_current_valid_box, i, j): + # When return_indices is False, no need to populate box_indices + if return_indices: + with ib.if_scope(tid + 0 == 0): + orig_idx = sorted_index[i * num_anchors + j] + box_indices[i, num_current_valid_box] = indices[i, orig_idx] + + def on_new_invalidated_box(i, k): + if return_indices is False and id_index >= 0: + out_class_ids[i, k] = -1.0 + + def needs_bbox_check(i, j, k): + return tvm.tir.any( + force_suppress > 0, + id_index < 0, + out_class_ids[i, k] == out_class_ids[i, j], + ) - return ib.get() + return _nms_loop( + ib, + batch_size, + num_anchors, + top_k, + iou_threshold, + valid_count, + lambda _: max_output_size, + on_new_valid_box, + on_new_invalidated_box, + needs_bbox_check, + calc_overlap, + out_scores, + num_valid_boxes, + ) def _fetch_score_ir(data, score, axis): @@ -622,6 +669,17 @@ def _fetch_score_ir(data, score, axis): return ib.get() +def _dispatch_sort(scores, ret_type="indices"): + target = tvm.target.Target.current() + if target and ( + can_use_thrust(target, "tvm.contrib.thrust.sort") + or can_use_rocthrust(target, "tvm.contrib.thrust.sort") + ): + return argsort_thrust(scores, axis=1, is_ascend=False, dtype="int32", ret_type=ret_type) + else: + return argsort(scores, axis=1, is_ascend=False, dtype="int32", ret_type=ret_type) + + def _get_sorted_indices(data, data_buf, score_index, score_shape): """Extract a 1D score tensor from the packed input and do argsort on it.""" score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8) @@ -639,17 +697,7 @@ def _get_sorted_indices(data, data_buf, score_index, score_shape): name="fetch_score", tag="fetch_score", ) - - target = tvm.target.Target.current() - if target and ( - can_use_thrust(target, "tvm.contrib.thrust.sort") - or can_use_rocthrust(target, "tvm.contrib.thrust.sort") - ): - sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype="int32") - else: - sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32") - - return sort_tensor + return 
_dispatch_sort(score_tensor) def _run_nms( @@ -910,3 +958,249 @@ def non_max_suppression( score_index, id_index, ) + + +def _get_valid_box_count(scores, score_threshold): + batch_classes, num_boxes = scores.shape + + def binary_search(ib, y, out): + lo = ib.allocate("int32", (1,), name="lo", scope="local") + hi = ib.allocate("int32", (1,), name="hi", scope="local") + + lo[0] = 0 + hi[0] = num_boxes + + with ib.while_loop(lo[0] < hi[0]): + mid = (hi[0] + lo[0]) >> 1 + with ib.if_scope(scores[y, mid] > score_threshold): + lo[0] = mid + 1 + with ib.else_scope(): + hi[0] = mid + + out[y] = lo[0] + + def searchsorted_ir(scores, valid_count): + ib = tvm.tir.ir_builder.create() + scores = ib.buffer_ptr(scores) + valid_count = ib.buffer_ptr(valid_count) + + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + + with ib.new_scope(): + ib.scope_attr(bx, "thread_extent", ceil_div(batch_classes, max_threads)) + ib.scope_attr(tx, "thread_extent", max_threads) + tid = bx * max_threads + tx + + with ib.if_scope(tid < batch_classes): + binary_search(ib, tid, valid_count) + + return ib.get() + + scores_buf = tvm.tir.decl_buffer(scores.shape, scores.dtype, "scores_buf", data_alignment=8) + + return te.extern( + [(batch_classes,)], + [scores], + lambda ins, outs: searchsorted_ir(ins[0], outs[0]), + dtype=["int32"], + in_buffers=[scores_buf], + name="searchsorted", + tag="searchsorted", + ) + + +def _all_class_nms_ir( + boxes, + sorted_scores, + sorted_indices, + valid_count, + batch_class, + num_class, + num_anchors, + iou_threshold, + max_output_size_per_class, + box_indices, + num_valid_boxes, +): + ib = tvm.tir.ir_builder.create() + boxes = ib.buffer_ptr(boxes) + sorted_scores = ib.buffer_ptr(sorted_scores) + sorted_indices = ib.buffer_ptr(sorted_indices) + valid_count = ib.buffer_ptr(valid_count) + box_indices = ib.buffer_ptr(box_indices) + num_valid_boxes = ib.buffer_ptr(num_valid_boxes) + + if isinstance(iou_threshold, float): + iou_threshold = tvm.tir.FloatImm("float32", iou_threshold) + + if isinstance(max_output_size_per_class, int): + max_output_size_per_class = tvm.tir.const(max_output_size_per_class) + + def calc_overlap(i, j, k): + offset_j = sorted_indices[i, j] * 4 + offset_k = sorted_indices[i, k] * 4 + batch_id = i // num_class + base_bbox_idx = batch_id * num_anchors * 4 + return calculate_overlap( + boxes, + base_bbox_idx + offset_j, + base_bbox_idx + offset_k, + ) + + def on_new_valid_box(ib, tid, num_current_valid_box, i, j): + with ib.if_scope(tid + 0 == 0): + box_indices[i, num_current_valid_box] = sorted_indices[i, j] + + def max_output_size(batch_class_index): + return max_output_size_per_class + + def on_new_invalidated_box(i, k): + pass + + def needs_bbox_check(i, j, k): + return tvm.tir.const(True) + + return _nms_loop( + ib, + batch_class, + num_anchors, + tvm.tir.IntImm("int32", -1), # top_k + iou_threshold, + valid_count, + max_output_size, + on_new_valid_box, + on_new_invalidated_box, + needs_bbox_check, + calc_overlap, + sorted_scores, + num_valid_boxes, + ) + + +def _run_all_class_nms( + boxes, sorted_scores, sorted_indices, valid_count, max_output_size_per_class, iou_threshold +): + batch, num_boxes, _ = boxes.shape + batch_class = sorted_scores.shape[0] + num_class = batch_class // batch + + boxes_buf = tvm.tir.decl_buffer(boxes.shape, boxes.dtype, "boxes_buf", data_alignment=8) + sorted_scores_buf = tvm.tir.decl_buffer( + sorted_scores.shape, sorted_scores.dtype, 
"sorted_scores_buf", data_alignment=8 + ) + sorted_indices_buf = tvm.tir.decl_buffer( + sorted_indices.shape, sorted_indices.dtype, "sorted_indices_buf", data_alignment=8 + ) + valid_count_buf = tvm.tir.decl_buffer( + valid_count.shape, "int32", "valid_count_buf", data_alignment=4 + ) + + return te.extern( + [(batch_class, num_boxes), (batch_class,)], + [boxes, sorted_scores, sorted_indices, valid_count], + lambda ins, outs: _all_class_nms_ir( + ins[0], # boxes + ins[1], # sorted_scores + ins[2], # sorted_indices + ins[3], # valid_count + batch_class, + num_class, + num_boxes, + iou_threshold, + max_output_size_per_class, + outs[0], # box_indices + outs[1], # num_valid_boxes + ), + dtype=["int32", "int32"], + in_buffers=[ + boxes_buf, + sorted_scores_buf, + sorted_indices_buf, + valid_count_buf, + ], + name="all_class_nms", + tag="all_class_nms", + ) + + +def _collect_selected_indices_ir(num_class, selected_indices, num_detections, row_offsets, out): + batch_classes, num_boxes = selected_indices.shape + + ib = tvm.tir.ir_builder.create() + + selected_indices = ib.buffer_ptr(selected_indices) + num_detections = ib.buffer_ptr(num_detections) + row_offsets = ib.buffer_ptr(row_offsets) + out = ib.buffer_ptr(out) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = ceil_div(num_boxes, nthread_tx) + nthread_by = batch_classes + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + ib.scope_attr(by, "thread_extent", nthread_by) + + with ib.new_scope(): + idx = bx * nthread_tx + tx + idy = by + batch_id = idy // num_class + class_id = idy % num_class + with ib.if_scope(idx < num_detections[idy]): + out[row_offsets[idy] + idx, 0] = batch_id + out[row_offsets[idy] + idx, 1] = class_id + out[row_offsets[idy] + idx, 2] = selected_indices[idy, idx] + + return ib.get() + + +def _collect_selected_indices(num_class, selected_indices, num_detections, row_offsets): + batch_class, num_boxes = selected_indices.shape + + selected_indices_buf = tvm.tir.decl_buffer( + selected_indices.shape, selected_indices.dtype, "selected_indices_buf", data_alignment=8 + ) + num_detections_buf = tvm.tir.decl_buffer( + num_detections.shape, num_detections.dtype, "num_detections_buf", data_alignment=8 + ) + row_offsets_buf = tvm.tir.decl_buffer( + row_offsets.shape, row_offsets.dtype, "row_offsets_buf", data_alignment=8 + ) + + return te.extern( + [(batch_class * num_boxes, 3)], + [selected_indices, num_detections, row_offsets], + lambda ins, outs: _collect_selected_indices_ir(num_class, ins[0], ins[1], ins[2], outs[0]), + dtype=["int32"], + in_buffers=[selected_indices_buf, num_detections_buf, row_offsets_buf], + name="collect_indices", + tag="collect_indices", + ) + + +def all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold +): + batch, num_class, num_boxes = scores.shape + + scores = reshape(scores, (batch * num_class, num_boxes)) + sorted_scores, sorted_indices = _dispatch_sort(scores, ret_type="both") + valid_count = _get_valid_box_count(sorted_scores, score_threshold) + + selected_indices, num_detections = _run_all_class_nms( + boxes, sorted_scores, sorted_indices, valid_count, max_output_boxes_per_class, iou_threshold + ) + num_detections = expand_dims(num_detections, axis=0) + + row_offsets, num_total_detections = exclusive_scan(num_detections, 
return_reduction=True) + + selected_indices = _collect_selected_indices( + num_class, selected_indices, num_detections, row_offsets + ) + + return selected_indices, num_total_detections diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 5ebd3060a6bb..93e4d3feccc7 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -739,7 +739,7 @@ def sort_thrust(data, axis=-1, is_ascend=1): return out -def argsort(data, axis=-1, is_ascend=1, dtype="float32"): +def argsort(data, axis=-1, is_ascend=1, dtype="float32", ret_type="indices"): """Performs sorting along the given axis and returns an array of indicies having same shape as an input array that index data in sorted order. @@ -757,6 +757,11 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"): dtype : string, optional DType of the output indices. + ret_type : string, optional + The return type [both, indices]. + "both": return both sorted data and indices. + "indices": return sorted indices only. + Returns ------- out : tvm.te.Tensor @@ -774,7 +779,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"): indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf", data_alignment=8) - out = te.extern( + outs = te.extern( [data.shape, data.shape, data.shape, data.shape], [data], lambda ins, outs: sort_ir( @@ -789,16 +794,19 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"): out_buffers=[value_buf, indices_buf, value_swap_buf, indices_swap_buf], name="argsort_gpu", tag="argsort_gpu", - )[1] + ) if axis != ndim - 1: axes = swap(list(range(ndim)), axis) - out = transpose(out, axes) + outs = [transpose(out, axes) for out in outs] - return out + if ret_type == "indices": + return outs[1] + + return outs[0], outs[1] -def argsort_thrust(data, axis=-1, is_ascend=1, dtype="float32"): +def argsort_thrust(data, axis=-1, is_ascend=1, dtype="float32", ret_type="indices"): """Performs sorting along the given axis and returns an array of indicies having same shape as an input array that index data in sorted order. @@ -816,12 +824,17 @@ def argsort_thrust(data, axis=-1, is_ascend=1, dtype="float32"): dtype : string, optional DType of the output indices. + ret_type : string, optional + The return type [both, indices]. + "both": return both sorted data and indices. + "indices": return sorted indices only. + Returns ------- out : tvm.te.Tensor The output of this function. """ - return topk_thrust(data, 0, axis, "indices", is_ascend, dtype) + return topk_thrust(data, 0, axis, ret_type, is_ascend, dtype) def schedule_sort(outs): diff --git a/python/tvm/topi/cuda/vision.py b/python/tvm/topi/cuda/vision.py index 73b24deb35ae..481c6bdbb62d 100644 --- a/python/tvm/topi/cuda/vision.py +++ b/python/tvm/topi/cuda/vision.py @@ -78,6 +78,24 @@ def schedule_nms(outs): return _default_schedule(outs) +def schedule_all_class_non_max_suppression(outs): + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + scheduled_ops = [] + + def traverse(op): + if tag.is_injective(op.tag): + schedule_injective_from_existing(s, op.output(0)) + for tensor in op.input_tensors: + if tensor.op.input_tensors and tensor.op not in scheduled_ops: + traverse(tensor.op) + scheduled_ops.append(op) + + for out in outs: + traverse(out.op) + return s + + def schedule_multibox_prior(outs): """Schedule for multibox_prior operator. 
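# A rough NumPy reference for the semantics the new all_class_non_max_suppression
# compute above targets (ONNX-style per-class NMS): for every (batch, class) pair,
# keep the highest-scoring boxes greedily, suppressing boxes whose IoU with an
# already-kept box reaches iou_threshold, and return (batch_id, class_id, box_id)
# triples plus the total detection count. Illustrative sketch only; the helper
# names below are local to this example and not part of the TOPI code.
import numpy as np

def ref_all_class_nms(boxes, scores, max_out_per_class, iou_threshold, score_threshold):
    def iou(a, b):
        # intersection corners for boxes whose two corners may come in any order
        tl = np.maximum(np.minimum(a[:2], a[2:]), np.minimum(b[:2], b[2:]))
        br = np.minimum(np.maximum(a[:2], a[2:]), np.maximum(b[:2], b[2:]))
        wh = np.maximum(br - tl, 0.0)
        inter = wh[0] * wh[1]

        def area(x):
            return abs((x[2] - x[0]) * (x[3] - x[1]))

        union = area(a) + area(b) - inter
        return 0.0 if union <= 0.0 else inter / union

    selected = []
    batch, num_class, _ = scores.shape
    for b in range(batch):
        for c in range(num_class):
            order = np.argsort(-scores[b, c])
            kept = []
            for i in order:
                # boxes are visited in descending score order, so stop at the
                # first score at or below the threshold or once the per-class
                # output budget is exhausted
                if scores[b, c, i] <= score_threshold or len(kept) >= max_out_per_class:
                    break
                if all(iou(boxes[b, i], boxes[b, j]) < iou_threshold for j in kept):
                    kept.append(i)
                    selected.append((b, c, i))
    return np.array(selected, dtype="int64").reshape(-1, 3), len(selected)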
diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index 7f8712c55fd1..f289de7586e4 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -65,6 +65,10 @@ "gpu": (topi.cuda.proposal, topi.cuda.schedule_proposal), } +_all_class_nms_implement = { + "gpu": (topi.cuda.all_class_non_max_suppression, topi.cuda.schedule_all_class_non_max_suppression), +} + def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): dtype = "float32" @@ -623,11 +627,106 @@ def test_proposal(): verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs) +def verify_all_class_non_max_suppression( + boxes_np, scores_np, max_output_boxes_per_class, iou_threshold, score_threshold +): + dshape = boxes_np.shape + batch, num_boxes, _ = dshape + _, num_class, _ = scores_np.shape + boxes = te.placeholder(dshape, name="boxes") + scores = te.placeholder(scores_np.shape, dtype="float32", name="scores") + + def check_device(target): + dev = tvm.device(target, 0) + if not tvm.testing.device_enabled(target): + print("Skip because %s is not enabled" % target) + return + print("Running on target: %s" % target) + with tvm.target.Target(target): + fcompute, fschedule = tvm.topi.testing.dispatch(target, _all_class_nms_implement) + out = fcompute( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + ) + s = fschedule(out) + + tvm_boxes = tvm.nd.array(boxes_np, dev) + tvm_scores = tvm.nd.array(scores_np, dev) + selected_indices = tvm.nd.array(np.zeros((batch * num_class * num_boxes, 3), "int32"), dev) + num_detections = tvm.nd.array(np.zeros((1,), "int32"), dev) + + f = tvm.build(s, [boxes, scores, out[0], out[1]], target) + f(tvm_boxes, tvm_scores, selected_indices, num_detections) + print(selected_indices.asnumpy()[:num_detections.asnumpy()[0]]) + # tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4) + + for target in ["vulkan"]: + check_device(target) + + +@tvm.testing.uses_gpu +def test_all_class_non_max_suppression(): + boxes = np.array( + [ + [ + [0.0, 0.0, 0.3, 0.3], + [0.0, 0.0, 0.4, 0.4], + [0.0, 0.0, 0.5, 0.5], + [0.5, 0.5, 0.9, 0.9], + [0.5, 0.5, 1.0, 1.0], + ], + [ + [0.0, 0.0, 0.3, 0.3], + [0.0, 0.0, 0.4, 0.4], + [0.5, 0.5, 0.95, 0.95], + [0.5, 0.5, 0.96, 0.96], + [0.5, 0.5, 1.0, 1.0], + ], + ] + ).astype("float32") + + scores = np.array( + [ + [[0.1, 0.2, 0.6, 0.3, 0.9], [0.1, 0.2, 0.6, 0.3, 0.9]], + [[0.1, 0.2, 0.6, 0.3, 0.9], [0.1, 0.2, 0.6, 0.3, 0.9]], + ] + ).astype("float32") + + max_output_boxes_per_class = 2 + iou_threshold = 0.8 + score_threshold = 0.0 + + verify_all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + ) + + boxes = np.array( + [ + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.1, 1.0, 1.1], + [0.0, -0.1, 1.0, 0.9], + [0.0, 10.0, 1.0, 11.0], + [0.0, 10.1, 1.0, 11.1], + [0.0, 100.0, 1.0, 101.0], + ] + ] + ).astype(np.float32) + scores = np.array([[[0.9, 0.75, 0.6, 0.95, 0.5, 0.3]]]).astype(np.float32) + max_output_boxes_per_class = 3 + iou_threshold = 0.5 + score_threshold = 0.4 + + verify_all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + ) + + if __name__ == "__main__": - test_get_valid_counts() - test_multibox_prior() - test_multibox_detection() - test_roi_align() - test_roi_pool() - test_proposal() - test_non_max_suppression() + # test_get_valid_counts() + # test_multibox_prior() + # 
test_multibox_detection() + # test_roi_align() + # test_roi_pool() + # test_proposal() + # test_non_max_suppression() + test_all_class_non_max_suppression() From d980761584dea84df4b2981f976217247501b54b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 3 Apr 2021 16:24:00 +0900 Subject: [PATCH 02/36] add c++ boilarplate --- include/tvm/relay/attrs/vision.h | 10 +++--- src/relay/op/vision/nms.cc | 61 ++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 4 deletions(-) diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 4a96d391430e..9d3db27ff85f 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -86,8 +86,6 @@ struct GetValidCountsAttrs : public tvm::AttrsNode { /*! \brief Attributes used in non_maximum_suppression operator */ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { - Optional max_output_size; - Optional iou_threshold; bool force_suppress; int top_k; int coord_start; @@ -97,8 +95,6 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { + TVM_DECLARE_ATTRS(AllClassNonMaximumSuppressionAttrs, "relay.attrs.AllClassNonMaximumSuppressionAttrs") { + } +}; + /*! \brief Attributes used in roi_align operators */ struct ROIAlignAttrs : public tvm::AttrsNode { Array pooled_size; diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 9316fecddca7..cd36b1983e64 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -23,6 +23,7 @@ */ #include #include +#include namespace tvm { namespace relay { @@ -132,5 +133,65 @@ ignore class_id axis. .set_support_level(5) .add_type_rel("NMS", NMSRel); +TVM_REGISTER_NODE_TYPE(AllClassNonMaximumSuppressionAttrs); + +bool AllClassNMSRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 6); + const auto* boxes = types[0].as(); + if (boxes == nullptr) return false; + const auto* scores = types[1].as(); + if (scores == nullptr) return false; + + const auto& boxes_shape = boxes->shape; + const auto& scores_shape = scores->shape; + ICHECK_EQ(boxes_shape.size(), 3) << "Input boxes should be 3-D."; + ICHECK_EQ(scores_shape.size(), 3) << "Input scores count should be 3-D."; + + IndexExpr batch = boxes_shape[0]; + IndexExpr num_classes = scores_shape[1]; + IndexExpr num_boxes = boxes_shape[2]; + + IndexExpr num_total_boxes = Any(); + if (!batch.as() && !num_boxes.as()) { + num_total_boxes = batch * num_classes * num_boxes; + } + + // assign output type + std::vector fields; + std::vector oshape{num_total_boxes, 3}; + fields.push_back(TensorType(oshape, DataType::Int(32))); + std::vector countshape{1}; + fields.push_back(TensorType(countshape, DataType::Int(32))); + reporter->Assign(types[5], TupleType(Array(fields))); + return true; +} + +Expr MakeAllClassNMS(Expr boxes, Expr scores, Expr max_output_boxes_per_class, Expr iou_threshold, + Expr score_threshold) { + auto attrs = make_object(); + static const Op& op = Op::Get("vision.all_class_non_max_suppression"); + return Call(op, {boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold}, + Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.vision._make.all_class_non_max_suppression") + .set_body_typed(MakeAllClassNMS); + +RELAY_REGISTER_OP("vision.all_class_non_max_suppression") + .describe(R"doc(Non-maximum suppression. The input boxes should +be in the format of [class_id, score, left, top, right, bottom] +or [score, left, top, right, bottom]. 
Set id_index to be -1 to +ignore class_id axis. +)doc" TVM_ADD_FILELINE) + .set_num_inputs(5) + .add_argument("data", "Tensor", "Input data.") + .add_argument("valid_count", "Tensor", "Number of valid anchor boxes.") + .add_argument("indices", "Tensor", "Corresponding indices in original input tensor.") + .add_argument("max_output_size", "Tensor", "Max number of output valid boxes.") + .add_argument("iou_threshold", "Tensor", "Threshold for box overlap.") + .set_support_level(5) + .add_type_rel("AllClassNMS", AllClassNMSRel); + } // namespace relay } // namespace tvm From a1c3bf6903aaf06cd1499b4eaeae7eab5c19d46b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 3 Apr 2021 16:39:46 +0900 Subject: [PATCH 03/36] add python boilarpolate --- python/tvm/relay/op/op_attrs.py | 5 +++++ python/tvm/relay/op/strategy/cuda.py | 12 ++++++++++++ python/tvm/relay/op/strategy/generic.py | 14 ++++++++++++++ python/tvm/relay/op/vision/_vision.py | 19 +++++++++++++++++++ python/tvm/relay/op/vision/nms.py | 12 ++++++++++++ 5 files changed, 62 insertions(+) diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 41076817b374..4cc6e0f26b91 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -304,6 +304,11 @@ class NonMaximumSuppressionAttrs(Attrs): """Attributes for vision.non_maximum_suppression""" +@tvm._ffi.register_object("relay.attrs.AllClassNonMaximumSuppressionAttrs") +class AllClassNonMaximumSuppressionAttrs(Attrs): + """Attributes for vision.all_classnon_maximum_suppression""" + + @tvm._ffi.register_object("relay.attrs.ROIAlignAttrs") class ROIAlignAttrs(Attrs): """Attributes for vision.roi_align""" diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 1a6742526607..b057b9cc6954 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -946,6 +946,18 @@ def nms_strategy_cuda(attrs, inputs, out_type, target): return strategy +@nms_strategy.register(["cuda", "gpu"]) +def all_class_nms_strategy_cuda(attrs, inputs, out_type, target): + """nms cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_nms(topi.cuda.all_class_non_max_suppression), + wrap_topi_schedule(topi.cuda.schedule_all_class_non_max_suppression), + name="all_class_nms.cuda", + ) + return strategy + + @roi_align_strategy.register(["cuda", "gpu"]) def roi_align_strategy_cuda(attrs, inputs, out_type, target): """roi_align cuda strategy""" diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 322a3607904f..e1c7ce0b2756 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1058,6 +1058,20 @@ def nms_strategy(attrs, inputs, out_type, target): return strategy +@override_native_generic_func("all_class_non_max_suppression_strategy") +def all_class_nms_strategy(attrs, inputs, out_type, target): + """all class nms generic strategy""" + # TODO + assert False + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_nms(topi.vision.non_max_suppression), + wrap_topi_schedule(topi.generic.schedule_nms), + name="nms.generic", + ) + return strategy + + # roi_align def wrap_compute_roi_align(topi_compute): """wrap roi_align topi compute""" diff --git a/python/tvm/relay/op/vision/_vision.py b/python/tvm/relay/op/vision/_vision.py index 9c8c853fa3d2..13e716a4b500 100644 --- a/python/tvm/relay/op/vision/_vision.py +++ b/python/tvm/relay/op/vision/_vision.py @@ 
-45,6 +45,9 @@ reg.register_strategy("vision.non_max_suppression", strategy.nms_strategy) reg.register_pattern("vision.non_max_suppression", OpPattern.OPAQUE) +reg.register_strategy("vision.all_non_max_suppression", strategy.all_class_nms_strategy) +reg.register_pattern("vision.all_non_max_suppression", OpPattern.OPAQUE) + @script def _get_valid_counts_shape_func(data_shape): @@ -85,6 +88,22 @@ def nms_shape_func(attrs, inputs, _): return [topi.math.identity(inputs[0])] +@script +def _all_class_nms_shape_func(boxes_shape, scores_shape): + out_shape = output_tensor((2,), "int64") + count_shape = output_tensor((1,), "int64") + + out_shape[0] = boxes_shape[0] * scores_shape[1] * boxes_shape[2] + out_shape[1] = 3 + count_shape[0] = int64(1) + return out_shape, count_shape + + +@reg.register_shape_func("vision.all_class_non_max_suppression", False) +def all_class_nms_shape_func(attrs, inputs, _): + return _all_class_nms_shape_func(inputs[0], inputs[1]) + + @script def _roi_align_shape_func_nchw(data_shape, rois_shape, pooled_size): out = output_tensor((4,), "int64") diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index 0a3df40b99df..dec892deaeba 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -149,3 +149,15 @@ def non_max_suppression( if return_indices: return expr.TupleWrapper(out, 2) return out + + +def all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class=-1, iou_threshold=-1.0, score_threshold=-1.0 +): + if not isinstance(max_output_boxes_per_class, expr.Expr): + max_output_boxes_per_class = expr.const(max_output_boxes_per_class, "int32") + if not isinstance(iou_threshold, expr.Expr): + iou_threshold = expr.const(iou_threshold, "float32") + return _make.all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + ) From 8af8079d95d7e079174b16db04c656bf354a863e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 3 Apr 2021 16:39:54 +0900 Subject: [PATCH 04/36] update onnx frontend --- python/tvm/relay/frontend/onnx.py | 222 +----------------------------- 1 file changed, 3 insertions(+), 219 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 09525a64ac05..b3aff6bbbb9a 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2498,226 +2498,10 @@ def conditionally_squeeze_scalar(x): iou_threshold = conditionally_squeeze_scalar(iou_threshold) score_threshold = conditionally_squeeze_scalar(score_threshold) - ## prepare utility constants - zero = _op.const(np.array([0]), dtype="int64") - one = _op.const(np.array([1]), dtype="int64") - two = _op.const(np.array([2]), dtype="int64") - three = _op.const(np.array([3]), dtype="int64") - three_ones = _op.const(np.array([1, 1, 1]), dtype="int64") - four_ones = _op.const(np.array([1, 1, 1, 1]), dtype="int64") - - ## First loop: split by class and perform NMS - # Create Loop Vars - i = _expr.var("i", shape=(1,), dtype="int64") - scores_var = _expr.var("scores_var", shape=(_ty.Any(), _ty.Any(), _ty.Any()), dtype=dtype) - boxes_var = _expr.var("boxes_var", shape=(_ty.Any(), _ty.Any(), 4), dtype=dtype) - max_output_boxes_per_class_var = _expr.var( - "max_output_boxes_per_class_var", shape=(), dtype="int64" + nms_out = _op.vision.all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold ) - iou_threshold_var = _expr.var("iou_threshold_var", shape=(), dtype="float32") - 
score_threshold_var = _expr.var("score_threshold_var", shape=(), dtype="float32") - B = _expr.var("B", shape=(1,), dtype="int64") - C = _expr.var("C", shape=(1,), dtype="int64") - S = _expr.var("S", shape=(1,), dtype="int64") - # Outputs of first loop should be padded nms values shape (B, C, S, 3) - onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64") - # and sizes of valid outputs, shape (B, C, 1) - nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64") - - def _first_cond( - i, - scores, - boxes, - B, - C, - S, - max_output_boxes_per_class, - iou_threshold, - score_threshold, - onnx_out, - nms_size_out, - ): - # Loop over classes, end when i == C - return _op.take(_op.less(i, C), _expr.const(0)) - - def _first_body( - i, - scores, - boxes, - B, - C, - S, - max_output_boxes_per_class, - iou_threshold, - score_threshold, - onnx_out, - nms_size_out, - ): - # slice to get current class - begin = _op.concatenate([zero, i, zero], axis=0) - end = _op.concatenate([B, i + one, S], axis=0) - class_scores = _op.strided_slice(scores, begin, end, three_ones) - class_scores = _op.expand_dims(_op.squeeze(class_scores, [1]), -1, 1) - # combine scores and boxes - data = _op.concatenate([class_scores, boxes], axis=-1) - - # get valid counts - ct, data, indices = _op.vision.get_valid_counts( - data, score_threshold=score_threshold, id_index=-1, score_index=0 - ) - # reason why using get_valid_counts is for inference performance - # ONNX NMS doesn't have parameter top_k - top_k = -1 - # ONNX doesn't have class id for nms input - score_index = 0 - # perform nms on current class - nms_ret = _op.vision.non_max_suppression( - data=data, - valid_count=ct, - indices=indices, - max_output_size=max_output_boxes_per_class, - iou_threshold=iou_threshold, - force_suppress=True, - top_k=top_k, - coord_start=1, - score_index=score_index, - id_index=-1, - return_indices=True, - invalid_to_bottom=False, - ) - # partially prepare ONNX output format by labeling batch_num, class_id - nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1) - batch_num = _op.expand_dims(_op.arange(_op.squeeze(B, [0]), dtype="int64"), -1, 1) - batch_num = _op.broadcast_to(batch_num, shape_of(nms_ret[0], dtype="int64")) - batch_num = _op.expand_dims(batch_num, -1, 1) - class_num = _op.broadcast_to(i, shape_of(nms_padded_out, dtype="int64")) - new_onnx_out = _op.concatenate( - [batch_num, class_num, _op.cast(nms_padded_out, "int64")], -1 - ) - new_onnx_out = _op.expand_dims(new_onnx_out, 1, 1) - # store valid nms outputs for this class - nms_size = _op.cast(nms_ret[1], "int64") - nms_size = _op.expand_dims(nms_size, 1, 1) - return [ - i + one, - scores, - boxes, - B, - C, - S, - max_output_boxes_per_class, - iou_threshold, - score_threshold, - _op.concatenate([onnx_out, new_onnx_out], axis=1), - _op.concatenate([nms_size_out, nms_size], axis=1), - ] - - # create the first loop - first_loop = _loops.while_loop( - _first_cond, - [ - i, - scores_var, - boxes_var, - B, - C, - S, - max_output_boxes_per_class_var, - iou_threshold_var, - score_threshold_var, - onnx_out, - nms_size_out, - ], - _first_body, - ) - - ## Second loop slices outputs of the first loop for valid boxes and - ## concats in the order ONNX wants - # Second inner Loop Vars - i = _expr.var("i", shape=(1,), dtype="int64") - j = _expr.var("j", shape=(1,), dtype="int64") - B = _expr.var("B", shape=(1,), dtype="int64") - C = _expr.var("C", shape=(1,), dtype="int64") - # Outputs of first loop should be padded nms values 
shape (B, C, 3) - onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64") - # and sizes of valid outputs, shape (B, C, 1) - nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64") - out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64") - - def _inner_cond(i, j, C, onnx_out, nms_size, out): - # inner loop over number of classes - return _op.take(_op.less(j, C), _expr.const(0)) - - def _inner_body(i, j, C, onnx_out, nms_size, out): - # slice to get current batch and class for valid box indicator - start = _op.concatenate([i, j + one, zero], axis=0) - end = _op.concatenate([i + one, j + two, one], axis=0) - num_valid_boxes = _op.reshape(_op.strided_slice(nms_size, start, end, three_ones), [1]) - # slice to get current batch, class, and valid outputs - start = _op.concatenate([i, j + one, zero, zero], axis=0) - end = _op.concatenate([i + one, j + two, num_valid_boxes, three], axis=0) - new_out = _op.squeeze(_op.strided_slice(onnx_out, start, end, four_ones), [0, 1]) - return i, j + one, C, onnx_out, nms_size, _op.concatenate([out, new_out], axis=0) - - inner_loop = _loops.while_loop( - _inner_cond, [i, j, C, onnx_out, nms_size_out, out], _inner_body - ) - - # Second Outer Loop Vars - i = _expr.var("i", shape=(1,), dtype="int64") - j = _expr.var("j", shape=(1,), dtype="int64") - B = _expr.var("B", shape=(1,), dtype="int64") - C = _expr.var("C", shape=(1,), dtype="int64") - # Outputs of first loop should be padded nms values shape (B, C, 3) - onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64") - # and sizes of valid outputs, shape (B, C, 1) - nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64") - out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64") - - def _outer_cond(i, B, C, onnx_out, nms_size_out, out): - # Outer loop is over batch size - return _op.take(_op.less(i, B), _expr.const(0)) - - def _outer_body(i, B, C, onnx_out, nms_size_out, out): - # Outer loop just calls inner loop - init_count = _op.const(np.array([0]), dtype="int64") - inner_loop_vals = inner_loop(i, init_count, C, onnx_out, nms_size_out, out) - return i + one, B, C, onnx_out, nms_size_out, _expr.TupleGetItem(inner_loop_vals, 5) - - # Create the second loop - outer_loop = _loops.while_loop( - _outer_cond, [i, B, C, onnx_out, nms_size_out, out], _outer_body - ) - - # Call the first loop, perform NMS - B, C, S = _op.split(shape_of(scores, dtype="int64"), 3) - init_count = _op.const(np.array([0]), dtype="int64") - init_onnx_out = _op.const([1], dtype="int64") - init_onnx_out = _op.broadcast_to(init_onnx_out, _op.concatenate([B, one, S, three], 0)) - init_nms_size_out = _op.const([1], dtype="int64") - init_nms_size_out = _op.broadcast_to(init_nms_size_out, _op.concatenate([B, one, one], 0)) - loop_vals = first_loop( - init_count, - scores, - boxes, - B, - C, - S, - max_output_boxes_per_class, - iou_threshold, - score_threshold, - init_onnx_out, - init_nms_size_out, - ) - onnx_output = _expr.TupleGetItem(loop_vals, 9) - nms_size_output = _expr.TupleGetItem(loop_vals, 10) - - # Call the second loop, rework outputs into correct form - init_count = _op.const(np.array([0]).astype("int64"), dtype="int64") - init_out = _op.const(np.array([1, 1, 1]).reshape([1, 3]).astype("int64"), dtype="int64") - loop_vals = outer_loop(init_count, B, C, onnx_output, nms_size_output, init_out) - loop_out = _expr.TupleGetItem(loop_vals, 5) - return _op.strided_slice(loop_out, [1, 0], 
shape_of(loop_out), [1, 1]) + return _op.strided_slice(nms_out[0], [0], [nms_out[1]]) class ATen(OnnxOpConverter): From d26d5b986f18a73e2a6a9baed56ae00438d70bfa Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 3 Apr 2021 16:58:29 +0900 Subject: [PATCH 05/36] fixing --- python/tvm/relay/frontend/onnx.py | 4 +++- python/tvm/relay/op/strategy/cuda.py | 2 +- python/tvm/relay/op/strategy/generic.py | 28 ++++++++++++++++++------- python/tvm/relay/op/vision/_vision.py | 4 ++-- python/tvm/relay/op/vision/nms.py | 3 ++- python/tvm/topi/cuda/nms.py | 2 +- 6 files changed, 29 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index b3aff6bbbb9a..cf02a8259bfc 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2501,7 +2501,9 @@ def conditionally_squeeze_scalar(x): nms_out = _op.vision.all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold ) - return _op.strided_slice(nms_out[0], [0], [nms_out[1]]) + num_detections = _op.squeeze(nms_out[1], axis=[0]) + return nms_out[0] + #return _op.strided_slice(nms_out[0], _expr.const([0]), num_detections) class ATen(OnnxOpConverter): diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index b057b9cc6954..4c203143f54b 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -951,7 +951,7 @@ def all_class_nms_strategy_cuda(attrs, inputs, out_type, target): """nms cuda strategy""" strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_nms(topi.cuda.all_class_non_max_suppression), + wrap_compute_all_class_nms(topi.cuda.all_class_non_max_suppression), wrap_topi_schedule(topi.cuda.schedule_all_class_non_max_suppression), name="all_class_nms.cuda", ) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index e1c7ce0b2756..7eae8d1398ce 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1000,10 +1000,6 @@ def wrap_compute_nms(topi_compute): def _compute_nms(attrs, inputs, out_type): max_output_size = inputs[3] iou_threshold = inputs[4] - if attrs.max_output_size is not None: - max_output_size = attrs.max_output_size - if attrs.iou_threshold is not None: - iou_threshold = get_const_float(attrs.iou_threshold) return_indices = bool(get_const_int(attrs.return_indices)) force_suppress = bool(get_const_int(attrs.force_suppress)) top_k = get_const_int(attrs.top_k) @@ -1058,16 +1054,32 @@ def nms_strategy(attrs, inputs, out_type, target): return strategy +def wrap_compute_all_class_nms(topi_compute): + """wrap nms topi compute""" + def _compute_nms(attrs, inputs, out_type): + max_output_size = inputs[2] + iou_threshold = inputs[3] + score_threshold = inputs[4] + return topi_compute( + inputs[0], + inputs[1], + max_output_size, + iou_threshold, + score_threshold + ) + + return _compute_nms + + @override_native_generic_func("all_class_non_max_suppression_strategy") def all_class_nms_strategy(attrs, inputs, out_type, target): """all class nms generic strategy""" # TODO - assert False strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_nms(topi.vision.non_max_suppression), - wrap_topi_schedule(topi.generic.schedule_nms), - name="nms.generic", + wrap_compute_all_class_nms(topi.cuda.all_class_non_max_suppression), + wrap_topi_schedule(topi.cuda.schedule_nms), + name="all_class_nms.generic", ) return strategy diff --git 
a/python/tvm/relay/op/vision/_vision.py b/python/tvm/relay/op/vision/_vision.py index 13e716a4b500..a13c82b37d21 100644 --- a/python/tvm/relay/op/vision/_vision.py +++ b/python/tvm/relay/op/vision/_vision.py @@ -45,8 +45,8 @@ reg.register_strategy("vision.non_max_suppression", strategy.nms_strategy) reg.register_pattern("vision.non_max_suppression", OpPattern.OPAQUE) -reg.register_strategy("vision.all_non_max_suppression", strategy.all_class_nms_strategy) -reg.register_pattern("vision.all_non_max_suppression", OpPattern.OPAQUE) +reg.register_strategy("vision.all_class_non_max_suppression", strategy.all_class_nms_strategy) +reg.register_pattern("vision.all_class_non_max_suppression", OpPattern.OPAQUE) @script diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index dec892deaeba..0a61ca962b02 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -158,6 +158,7 @@ def all_class_non_max_suppression( max_output_boxes_per_class = expr.const(max_output_boxes_per_class, "int32") if not isinstance(iou_threshold, expr.Expr): iou_threshold = expr.const(iou_threshold, "float32") - return _make.all_class_non_max_suppression( + out = _make.all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold ) + return expr.TupleWrapper(out, 2) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 6f1ee7c0e72d..568bc0687fe8 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -1203,4 +1203,4 @@ def all_class_non_max_suppression( num_class, selected_indices, num_detections, row_offsets ) - return selected_indices, num_total_detections + return [selected_indices, num_total_detections] From 0c71339b5a1a1c15780fcd2721459b553fb8f2fb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 3 Apr 2021 17:13:03 +0900 Subject: [PATCH 06/36] update onnx frontend --- python/tvm/relay/frontend/onnx.py | 10 +++++++--- python/tvm/relay/op/strategy/generic.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index cf02a8259bfc..0ade02aa0ea8 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2501,9 +2501,13 @@ def conditionally_squeeze_scalar(x): nms_out = _op.vision.all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold ) - num_detections = _op.squeeze(nms_out[1], axis=[0]) - return nms_out[0] - #return _op.strided_slice(nms_out[0], _expr.const([0]), num_detections) + + zero = _op.const(np.array([0]), dtype="int32") + three = _op.const(np.array([3]), dtype="int32") + begin = _op.concatenate([zero, zero], axis=0) + end = _op.concatenate([nms_out[1], three], axis=0) + strides = _op.const(np.array([1, 1]), dtype="int32") + return _op.strided_slice(nms_out[0], begin, end, strides) class ATen(OnnxOpConverter): diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 7eae8d1398ce..6924196b0585 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1078,7 +1078,7 @@ def all_class_nms_strategy(attrs, inputs, out_type, target): strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_all_class_nms(topi.cuda.all_class_non_max_suppression), - wrap_topi_schedule(topi.cuda.schedule_nms), + wrap_topi_schedule(topi.cuda.schedule_all_class_non_max_suppression), name="all_class_nms.generic", ) return 
strategy From a40337af57e3e68f4b483718b627d8380fa39682 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 3 Apr 2021 17:26:03 +0900 Subject: [PATCH 07/36] fix shape --- src/relay/op/vision/nms.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index cd36b1983e64..bca359f5a632 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -150,7 +150,7 @@ bool AllClassNMSRel(const Array& types, int num_inputs, const Attrs& attrs IndexExpr batch = boxes_shape[0]; IndexExpr num_classes = scores_shape[1]; - IndexExpr num_boxes = boxes_shape[2]; + IndexExpr num_boxes = boxes_shape[1]; IndexExpr num_total_boxes = Any(); if (!batch.as() && !num_boxes.as()) { From 71370cf981c24d481d230e407f79a8a42e7ac6ad Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 3 Apr 2021 17:27:59 +0900 Subject: [PATCH 08/36] minor update --- python/tvm/relay/frontend/onnx.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 0ade02aa0ea8..ad7708e39c65 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2456,17 +2456,6 @@ class NonMaxSuppression(OnnxOpConverter): @classmethod def _impl_v10(cls, inputs, attr, params): - """ - High level note: ONNX implements what TF calls combined_non_max_suppression - It passes in scores for each box for every class in the output and expects boxes to be - analyzed for each class independently - - It also asks for the data to be returned in a particular format. - - To support these, we implement a series of lops: - The first loop splits over class number, performs NMS, and collects the outputs. - The second (nested) loop takes the outputs and transforms them into the format ONNX wants - """ # Get parameter values boxes = inputs[0] scores = inputs[1] @@ -2474,8 +2463,6 @@ def _impl_v10(cls, inputs, attr, params): iou_threshold = inputs[3] score_threshold = inputs[4] - dtype = infer_type(boxes).checked_type.dtype - if "center_point_box" in attr: if attr["center_point_box"] != 0: raise NotImplementedError( From 15d3bd09e7cd4b2728e12049f7d16e9089120103 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 3 Apr 2021 18:06:43 +0900 Subject: [PATCH 09/36] fix --- python/tvm/relay/frontend/onnx.py | 3 ++- python/tvm/relay/op/strategy/cuda.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index ad7708e39c65..e667446c1758 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2494,7 +2494,8 @@ def conditionally_squeeze_scalar(x): begin = _op.concatenate([zero, zero], axis=0) end = _op.concatenate([nms_out[1], three], axis=0) strides = _op.const(np.array([1, 1]), dtype="int32") - return _op.strided_slice(nms_out[0], begin, end, strides) + # TODO: fix cast + return _op.cast(_op.strided_slice(nms_out[0], begin, end, strides), "int64") class ATen(OnnxOpConverter): diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 4c203143f54b..f944aaf720c2 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -946,7 +946,7 @@ def nms_strategy_cuda(attrs, inputs, out_type, target): return strategy -@nms_strategy.register(["cuda", "gpu"]) +@all_class_nms_strategy.register(["cuda", "gpu"]) def all_class_nms_strategy_cuda(attrs, inputs, out_type, target): """nms cuda strategy""" 
strategy = _op.OpStrategy() From 837ce76902c0a5e2945df41eafd5a448f6024e77 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 4 Apr 2021 05:07:40 +0900 Subject: [PATCH 10/36] fix shape func --- python/tvm/relay/op/vision/_vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/vision/_vision.py b/python/tvm/relay/op/vision/_vision.py index a13c82b37d21..7a31bce5ad49 100644 --- a/python/tvm/relay/op/vision/_vision.py +++ b/python/tvm/relay/op/vision/_vision.py @@ -93,7 +93,7 @@ def _all_class_nms_shape_func(boxes_shape, scores_shape): out_shape = output_tensor((2,), "int64") count_shape = output_tensor((1,), "int64") - out_shape[0] = boxes_shape[0] * scores_shape[1] * boxes_shape[2] + out_shape[0] = boxes_shape[0] * scores_shape[1] * boxes_shape[1] out_shape[1] = 3 count_shape[0] = int64(1) return out_shape, count_shape From e26bb4d6f869bbe8048a561bf3074f29b99f6bfa Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 4 Apr 2021 05:34:09 +0900 Subject: [PATCH 11/36] fix for no box --- python/tvm/topi/cuda/nms.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 568bc0687fe8..1f68bd27a449 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -399,6 +399,9 @@ def nms_inner_loop(ib, i, j, nkeep): with ib.if_scope(tx + 0 == 0): num_valid_boxes[i] = num_valid_boxes_local[0] + with ib.else_scope(): + num_valid_boxes[i] = 0 + return ib.get() From 65b5bba301d428d5f2f095a86e86cf95ae098f1d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 4 Apr 2021 05:38:36 +0900 Subject: [PATCH 12/36] more fix --- python/tvm/relay/frontend/onnx.py | 6 +++--- python/tvm/topi/cuda/nms.py | 2 +- src/relay/op/vision/nms.cc | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index e667446c1758..6db237c39e25 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2489,11 +2489,11 @@ def conditionally_squeeze_scalar(x): boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold ) - zero = _op.const(np.array([0]), dtype="int32") - three = _op.const(np.array([3]), dtype="int32") + zero = _op.const(np.array([0]), dtype="int64") + three = _op.const(np.array([3]), dtype="int64") begin = _op.concatenate([zero, zero], axis=0) end = _op.concatenate([nms_out[1], three], axis=0) - strides = _op.const(np.array([1, 1]), dtype="int32") + strides = _op.const(np.array([1, 1]), dtype="int64") # TODO: fix cast return _op.cast(_op.strided_slice(nms_out[0], begin, end, strides), "int64") diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 1f68bd27a449..37df7d7b3e89 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -1198,7 +1198,7 @@ def all_class_non_max_suppression( selected_indices, num_detections = _run_all_class_nms( boxes, sorted_scores, sorted_indices, valid_count, max_output_boxes_per_class, iou_threshold ) - num_detections = expand_dims(num_detections, axis=0) + num_detections = cast(expand_dims(num_detections, axis=0), "int64") row_offsets, num_total_detections = exclusive_scan(num_detections, return_reduction=True) diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index bca359f5a632..8cc895905253 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -162,7 +162,7 @@ bool AllClassNMSRel(const Array& types, int num_inputs, const Attrs& attrs std::vector 
oshape{num_total_boxes, 3}; fields.push_back(TensorType(oshape, DataType::Int(32))); std::vector countshape{1}; - fields.push_back(TensorType(countshape, DataType::Int(32))); + fields.push_back(TensorType(countshape, DataType::Int(64))); reporter->Assign(types[5], TupleType(Array(fields))); return true; } From 253629ab5e6e4c937b50dc172c71c5c367c2cfef Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 4 Apr 2021 06:28:10 +0900 Subject: [PATCH 13/36] made things 64 bit --- python/tvm/relay/frontend/onnx.py | 3 +-- python/tvm/topi/cuda/nms.py | 16 ++++++++-------- python/tvm/topi/cuda/scan.py | 2 +- src/relay/op/vision/nms.cc | 2 +- 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 6db237c39e25..0aae615e249f 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2494,8 +2494,7 @@ def conditionally_squeeze_scalar(x): begin = _op.concatenate([zero, zero], axis=0) end = _op.concatenate([nms_out[1], three], axis=0) strides = _op.const(np.array([1, 1]), dtype="int64") - # TODO: fix cast - return _op.cast(_op.strided_slice(nms_out[0], begin, end, strides), "int64") + return _op.strided_slice(nms_out[0], begin, end, strides) class ATen(OnnxOpConverter): diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 37df7d7b3e89..ffda956f6952 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -341,9 +341,9 @@ def _nms_loop( ib.scope_attr(tx, "thread_extent", nthread_tx) num_valid_boxes_local = ib.allocate( - "int32", (1,), name="num_valid_boxes_local", scope="local" + "int64", (1,), name="num_valid_boxes_local", scope="local" ) - num_valid_boxes_local[0] = 0 + num_valid_boxes_local[0] = cast(0, "int64") def nms_inner_loop(ib, i, j, nkeep): # The box j is valid, invalidate other boxes that overlap with j above iou_threshold @@ -400,7 +400,7 @@ def nms_inner_loop(ib, i, j, nkeep): num_valid_boxes[i] = num_valid_boxes_local[0] with ib.else_scope(): - num_valid_boxes[i] = 0 + num_valid_boxes[i] = cast(0, "int64") return ib.get() @@ -1054,7 +1054,7 @@ def calc_overlap(i, j, k): def on_new_valid_box(ib, tid, num_current_valid_box, i, j): with ib.if_scope(tid + 0 == 0): - box_indices[i, num_current_valid_box] = sorted_indices[i, j] + box_indices[i, num_current_valid_box] = cast(sorted_indices[i, j], "int64") def max_output_size(batch_class_index): return max_output_size_per_class @@ -1116,7 +1116,7 @@ def _run_all_class_nms( outs[0], # box_indices outs[1], # num_valid_boxes ), - dtype=["int32", "int32"], + dtype=["int64", "int64"], in_buffers=[ boxes_buf, sorted_scores_buf, @@ -1151,7 +1151,7 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro with ib.new_scope(): idx = bx * nthread_tx + tx - idy = by + idy = cast(by, "int64") batch_id = idy // num_class class_id = idy % num_class with ib.if_scope(idx < num_detections[idy]): @@ -1179,7 +1179,7 @@ def _collect_selected_indices(num_class, selected_indices, num_detections, row_o [(batch_class * num_boxes, 3)], [selected_indices, num_detections, row_offsets], lambda ins, outs: _collect_selected_indices_ir(num_class, ins[0], ins[1], ins[2], outs[0]), - dtype=["int32"], + dtype=["int64"], in_buffers=[selected_indices_buf, num_detections_buf, row_offsets_buf], name="collect_indices", tag="collect_indices", @@ -1198,7 +1198,7 @@ def all_class_non_max_suppression( selected_indices, num_detections = _run_all_class_nms( boxes, sorted_scores, sorted_indices, 
valid_count, max_output_boxes_per_class, iou_threshold ) - num_detections = cast(expand_dims(num_detections, axis=0), "int64") + num_detections = expand_dims(num_detections, axis=0) row_offsets, num_total_detections = exclusive_scan(num_detections, return_reduction=True) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 3240ebcd515c..ba5204c84f04 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -81,7 +81,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i ib.scope_attr(bx, "thread_extent", batch_size) with ib.if_scope(bx < batch_size): if reduction is not None: - reduction[bx] = 0 + reduction[bx] = cast(identity_value, "int64") with ib.else_scope(): with ib.new_scope(): nthread_tx = max_threads diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 8cc895905253..50e1ada9dcd6 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -160,7 +160,7 @@ bool AllClassNMSRel(const Array& types, int num_inputs, const Attrs& attrs // assign output type std::vector fields; std::vector oshape{num_total_boxes, 3}; - fields.push_back(TensorType(oshape, DataType::Int(32))); + fields.push_back(TensorType(oshape, DataType::Int(64))); std::vector countshape{1}; fields.push_back(TensorType(countshape, DataType::Int(64))); reporter->Assign(types[5], TupleType(Array(fields))); From 9cb2505aa5d7c6e477ab74d3b300b0cd413bfff8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 4 Apr 2021 06:49:46 +0900 Subject: [PATCH 14/36] int64 tweak --- python/tvm/topi/cuda/nms.py | 16 +++++++++------- python/tvm/topi/cuda/scan.py | 4 ++-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index ffda956f6952..dc43ee6c1e57 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -341,9 +341,9 @@ def _nms_loop( ib.scope_attr(tx, "thread_extent", nthread_tx) num_valid_boxes_local = ib.allocate( - "int64", (1,), name="num_valid_boxes_local", scope="local" + "int32", (1,), name="num_valid_boxes_local", scope="local" ) - num_valid_boxes_local[0] = cast(0, "int64") + num_valid_boxes_local[0] = 0 def nms_inner_loop(ib, i, j, nkeep): # The box j is valid, invalidate other boxes that overlap with j above iou_threshold @@ -400,7 +400,7 @@ def nms_inner_loop(ib, i, j, nkeep): num_valid_boxes[i] = num_valid_boxes_local[0] with ib.else_scope(): - num_valid_boxes[i] = cast(0, "int64") + num_valid_boxes[i] = 0 return ib.get() @@ -1054,7 +1054,7 @@ def calc_overlap(i, j, k): def on_new_valid_box(ib, tid, num_current_valid_box, i, j): with ib.if_scope(tid + 0 == 0): - box_indices[i, num_current_valid_box] = cast(sorted_indices[i, j], "int64") + box_indices[i, num_current_valid_box] = sorted_indices[i, j] def max_output_size(batch_class_index): return max_output_size_per_class @@ -1116,7 +1116,7 @@ def _run_all_class_nms( outs[0], # box_indices outs[1], # num_valid_boxes ), - dtype=["int64", "int64"], + dtype=["int32", "int32"], in_buffers=[ boxes_buf, sorted_scores_buf, @@ -1157,7 +1157,7 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro with ib.if_scope(idx < num_detections[idy]): out[row_offsets[idy] + idx, 0] = batch_id out[row_offsets[idy] + idx, 1] = class_id - out[row_offsets[idy] + idx, 2] = selected_indices[idy, idx] + out[row_offsets[idy] + idx, 2] = cast(selected_indices[idy, idx], "int64") return ib.get() @@ -1200,7 +1200,9 @@ def all_class_non_max_suppression( ) 
num_detections = expand_dims(num_detections, axis=0) - row_offsets, num_total_detections = exclusive_scan(num_detections, return_reduction=True) + row_offsets, num_total_detections = exclusive_scan( + num_detections, return_reduction=True, output_dtype="int64" + ) selected_indices = _collect_selected_indices( num_class, selected_indices, num_detections, row_offsets diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index ba5204c84f04..5d3798e3d27b 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -81,7 +81,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i ib.scope_attr(bx, "thread_extent", batch_size) with ib.if_scope(bx < batch_size): if reduction is not None: - reduction[bx] = cast(identity_value, "int64") + reduction[bx] = cast(identity_value, out_dtype) with ib.else_scope(): with ib.new_scope(): nthread_tx = max_threads @@ -393,7 +393,7 @@ def do_scan(data, output_dtype): lambda ins, outs: exclusive_scan_ir( ins[0], outs[0], outs[1], binop=binop, identity_value=identity_value ), - dtype=[data.dtype, output_dtype], + dtype=[output_dtype, output_dtype], in_buffers=[data_buf], name="exclusive_scan", tag="exclusive_scan_gpu", From ac5d79b22095c1d49420eeab8460a69631c68298 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 4 Apr 2021 11:10:23 +0900 Subject: [PATCH 15/36] max_output_size doesn't need to be a callback --- python/tvm/topi/cuda/nms.py | 10 +++------- tests/python/topi/python/test_topi_vision.py | 6 +++--- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index dc43ee6c1e57..7a0618094383 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -312,8 +312,8 @@ def _nms_loop( num_anchors, top_k, iou_threshold, + max_output_size, valid_count, - get_max_output_size_func, on_new_valid_box_func, on_new_invalidated_box_func, needs_bbox_check_func, @@ -373,7 +373,6 @@ def nms_inner_loop(ib, i, j, nkeep): i = by - max_output_size = get_max_output_size_func(i) nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]) with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): @@ -632,8 +631,8 @@ def needs_bbox_check(i, j, k): num_anchors, top_k, iou_threshold, + max_output_size, valid_count, - lambda _: max_output_size, on_new_valid_box, on_new_invalidated_box, needs_bbox_check, @@ -1056,9 +1055,6 @@ def on_new_valid_box(ib, tid, num_current_valid_box, i, j): with ib.if_scope(tid + 0 == 0): box_indices[i, num_current_valid_box] = sorted_indices[i, j] - def max_output_size(batch_class_index): - return max_output_size_per_class - def on_new_invalidated_box(i, k): pass @@ -1071,8 +1067,8 @@ def needs_bbox_check(i, j, k): num_anchors, tvm.tir.IntImm("int32", -1), # top_k iou_threshold, + max_output_size_per_class, valid_count, - max_output_size, on_new_valid_box, on_new_invalidated_box, needs_bbox_check, diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index f289de7586e4..6a712ef605a6 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -651,15 +651,15 @@ def check_device(target): tvm_boxes = tvm.nd.array(boxes_np, dev) tvm_scores = tvm.nd.array(scores_np, dev) - selected_indices = tvm.nd.array(np.zeros((batch * num_class * num_boxes, 3), "int32"), dev) - num_detections = tvm.nd.array(np.zeros((1,), "int32"), dev) + selected_indices = 
tvm.nd.array(np.zeros((batch * num_class * num_boxes, 3), "int64"), dev) + num_detections = tvm.nd.array(np.zeros((1,), "int64"), dev) f = tvm.build(s, [boxes, scores, out[0], out[1]], target) f(tvm_boxes, tvm_scores, selected_indices, num_detections) print(selected_indices.asnumpy()[:num_detections.asnumpy()[0]]) # tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4) - for target in ["vulkan"]: + for target in ["cuda"]: check_device(target) From fd868a11978188b2cf5d9125faba40f8ca994df8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 4 Apr 2021 11:18:55 +0900 Subject: [PATCH 16/36] remove all_class_nms schedule --- python/tvm/relay/op/strategy/cuda.py | 2 +- python/tvm/relay/op/strategy/generic.py | 2 +- python/tvm/topi/cuda/vision.py | 20 +------------------- tests/python/topi/python/test_topi_vision.py | 4 ++-- 4 files changed, 5 insertions(+), 23 deletions(-) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index f944aaf720c2..198b98d7f039 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -952,7 +952,7 @@ def all_class_nms_strategy_cuda(attrs, inputs, out_type, target): strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_all_class_nms(topi.cuda.all_class_non_max_suppression), - wrap_topi_schedule(topi.cuda.schedule_all_class_non_max_suppression), + wrap_topi_schedule(topi.cuda.schedule_nms), name="all_class_nms.cuda", ) return strategy diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 6924196b0585..7eae8d1398ce 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1078,7 +1078,7 @@ def all_class_nms_strategy(attrs, inputs, out_type, target): strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_all_class_nms(topi.cuda.all_class_non_max_suppression), - wrap_topi_schedule(topi.cuda.schedule_all_class_non_max_suppression), + wrap_topi_schedule(topi.cuda.schedule_nms), name="all_class_nms.generic", ) return strategy diff --git a/python/tvm/topi/cuda/vision.py b/python/tvm/topi/cuda/vision.py index 481c6bdbb62d..88983ab89f76 100644 --- a/python/tvm/topi/cuda/vision.py +++ b/python/tvm/topi/cuda/vision.py @@ -32,7 +32,7 @@ def _default_schedule(outs): scheduled_ops = [] def traverse(op): - if tag.is_broadcast(op.tag) or op.tag in ["bbox_score", "sorted_bbox"]: + if tag.is_injective(op.tag) or op.tag in ["bbox_score", "sorted_bbox"]: schedule_injective_from_existing(s, op.output(0)) for tensor in op.input_tensors: if tensor.op.input_tensors and tensor.op not in scheduled_ops: @@ -78,24 +78,6 @@ def schedule_nms(outs): return _default_schedule(outs) -def schedule_all_class_non_max_suppression(outs): - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - scheduled_ops = [] - - def traverse(op): - if tag.is_injective(op.tag): - schedule_injective_from_existing(s, op.output(0)) - for tensor in op.input_tensors: - if tensor.op.input_tensors and tensor.op not in scheduled_ops: - traverse(tensor.op) - scheduled_ops.append(op) - - for out in outs: - traverse(out.op) - return s - - def schedule_multibox_prior(outs): """Schedule for multibox_prior operator. 
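With the dedicated schedule gone, all_class_non_max_suppression is driven by schedule_nms, whose default traversal now fuses everything tagged injective (the reshape/cast/strided_slice stages emitted by the compute) rather than only broadcast ops. The following is a minimal standalone sketch of that traversal, assuming the TVM helpers already imported by topi.cuda.vision; the function name _schedule_injective_outputs is illustrative and not part of this series:

from tvm import te
from tvm.topi import tag
from tvm.topi.cuda.injective import schedule_injective_from_existing


def _schedule_injective_outputs(outs):
    """Sketch of the fallback traversal behind schedule_nms (illustrative only)."""
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])
    scheduled_ops = []

    def traverse(op):
        # Injective stages (reshape, cast, strided_slice, ...) get fused and
        # bound here; extern stages such as the NMS IR carry their own bindings.
        if tag.is_injective(op.tag):
            schedule_injective_from_existing(s, op.output(0))
        for tensor in op.input_tensors:
            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                traverse(tensor.op)
        scheduled_ops.append(op)

    for out in outs:
        traverse(out.op)
    return s

In practice this runs inside a target scope (for example with tvm.target.Target("cuda")) so the injective scheduler can query the thread limits it binds to.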
diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index 6a712ef605a6..6a1674b91f5d 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -66,7 +66,7 @@ } _all_class_nms_implement = { - "gpu": (topi.cuda.all_class_non_max_suppression, topi.cuda.schedule_all_class_non_max_suppression), + "gpu": (topi.cuda.all_class_non_max_suppression, topi.cuda.schedule_nms), } @@ -474,7 +474,7 @@ def check_device(target): tvm_val = tvm_b.asnumpy() tvm.testing.assert_allclose(tvm_val, b_np, rtol=1e-3, atol=1e-4) - for target in ["llvm", "cuda", "opencl"]: + for target in ["cuda"]: check_device(target) From adaaf50703f903998ebcf5395bbb0f74116cbbe4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 4 Apr 2021 11:24:44 +0900 Subject: [PATCH 17/36] minor simplify --- python/tvm/relay/frontend/onnx.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 0aae615e249f..6816e27f3803 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2489,9 +2489,8 @@ def conditionally_squeeze_scalar(x): boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold ) - zero = _op.const(np.array([0]), dtype="int64") three = _op.const(np.array([3]), dtype="int64") - begin = _op.concatenate([zero, zero], axis=0) + begin = _op.const(np.array([0, 0]), dtype="int64") end = _op.concatenate([nms_out[1], three], axis=0) strides = _op.const(np.array([1, 1]), dtype="int64") return _op.strided_slice(nms_out[0], begin, end, strides) From 83aa4c2983966b890ae99f49451a741ff0f821a7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 4 Apr 2021 15:03:22 +0900 Subject: [PATCH 18/36] remove expand_dim --- python/tvm/topi/cuda/nms.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 7a0618094383..83c7399004b0 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -1097,7 +1097,7 @@ def _run_all_class_nms( ) return te.extern( - [(batch_class, num_boxes), (batch_class,)], + [(batch_class, num_boxes), (1, batch_class)], [boxes, sorted_scores, sorted_indices, valid_count], lambda ins, outs: _all_class_nms_ir( ins[0], # boxes @@ -1194,7 +1194,6 @@ def all_class_non_max_suppression( selected_indices, num_detections = _run_all_class_nms( boxes, sorted_scores, sorted_indices, valid_count, max_output_boxes_per_class, iou_threshold ) - num_detections = expand_dims(num_detections, axis=0) row_offsets, num_total_detections = exclusive_scan( num_detections, return_reduction=True, output_dtype="int64" From 0be35e671ac578cec38655d2d30f2a8809b0d6e6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 4 Apr 2021 15:30:41 +0900 Subject: [PATCH 19/36] refactoring --- python/tvm/topi/cuda/nms.py | 214 +++-------------------------- python/tvm/topi/vision/nms_util.py | 210 ++++++++++++++++++++++++++++ 2 files changed, 228 insertions(+), 196 deletions(-) create mode 100644 python/tvm/topi/vision/nms_util.py diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 83c7399004b0..0a958d2d3599 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -26,8 +26,13 @@ from .scan import exclusive_scan from ..utils import ceil_div from ..math import cast -from ..transform import reshape, expand_dims -from ..broadcast import greater +from ..transform import reshape +from 
..vision.nms_util import ( + calculate_overlap, + binary_search, + collect_selected_indices, + run_all_class_nms, +) def cuda_atomic_add_rule(op): @@ -268,44 +273,6 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): return [valid_count, out, out_indices] -def get_boundaries(output, box_idx): - l = tvm.te.min( - output[box_idx], - output[box_idx + 2], - ) - t = tvm.te.min( - output[box_idx + 1], - output[box_idx + 3], - ) - r = tvm.te.max( - output[box_idx], - output[box_idx + 2], - ) - b = tvm.te.max( - output[box_idx + 1], - output[box_idx + 3], - ) - return l, t, r, b - - -def calculate_overlap(out_tensor, box_a_idx, box_b_idx): - """Calculate overlap of two boxes.""" - a_l, a_t, a_r, a_b = get_boundaries(out_tensor, box_a_idx) - b_l, b_t, b_r, b_b = get_boundaries(out_tensor, box_b_idx) - - # Overlapping width and height - w = tvm.te.max(0.0, tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l)) - h = tvm.te.max(0.0, tvm.te.min(a_b, b_b) - tvm.te.max(a_t, b_t)) - - # Overlapping area - area = h * w - - # total area of the figure formed by box a and box b - # except for overlapping area - u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area - return tvm.tir.Select(u <= 0.0, 0.0, area / u) - - def _nms_loop( ib, batch_size, @@ -965,22 +932,6 @@ def non_max_suppression( def _get_valid_box_count(scores, score_threshold): batch_classes, num_boxes = scores.shape - def binary_search(ib, y, out): - lo = ib.allocate("int32", (1,), name="lo", scope="local") - hi = ib.allocate("int32", (1,), name="hi", scope="local") - - lo[0] = 0 - hi[0] = num_boxes - - with ib.while_loop(lo[0] < hi[0]): - mid = (hi[0] + lo[0]) >> 1 - with ib.if_scope(scores[y, mid] > score_threshold): - lo[0] = mid + 1 - with ib.else_scope(): - hi[0] = mid - - out[y] = lo[0] - def searchsorted_ir(scores, valid_count): ib = tvm.tir.ir_builder.create() scores = ib.buffer_ptr(scores) @@ -996,7 +947,7 @@ def searchsorted_ir(scores, valid_count): tid = bx * max_threads + tx with ib.if_scope(tid < batch_classes): - binary_search(ib, tid, valid_count) + binary_search(ib, tid, num_boxes, scores, score_threshold, valid_count) return ib.get() @@ -1013,117 +964,6 @@ def searchsorted_ir(scores, valid_count): ) -def _all_class_nms_ir( - boxes, - sorted_scores, - sorted_indices, - valid_count, - batch_class, - num_class, - num_anchors, - iou_threshold, - max_output_size_per_class, - box_indices, - num_valid_boxes, -): - ib = tvm.tir.ir_builder.create() - boxes = ib.buffer_ptr(boxes) - sorted_scores = ib.buffer_ptr(sorted_scores) - sorted_indices = ib.buffer_ptr(sorted_indices) - valid_count = ib.buffer_ptr(valid_count) - box_indices = ib.buffer_ptr(box_indices) - num_valid_boxes = ib.buffer_ptr(num_valid_boxes) - - if isinstance(iou_threshold, float): - iou_threshold = tvm.tir.FloatImm("float32", iou_threshold) - - if isinstance(max_output_size_per_class, int): - max_output_size_per_class = tvm.tir.const(max_output_size_per_class) - - def calc_overlap(i, j, k): - offset_j = sorted_indices[i, j] * 4 - offset_k = sorted_indices[i, k] * 4 - batch_id = i // num_class - base_bbox_idx = batch_id * num_anchors * 4 - return calculate_overlap( - boxes, - base_bbox_idx + offset_j, - base_bbox_idx + offset_k, - ) - - def on_new_valid_box(ib, tid, num_current_valid_box, i, j): - with ib.if_scope(tid + 0 == 0): - box_indices[i, num_current_valid_box] = sorted_indices[i, j] - - def on_new_invalidated_box(i, k): - pass - - def needs_bbox_check(i, j, k): - return tvm.tir.const(True) - - return _nms_loop( - ib, - batch_class, - 
num_anchors, - tvm.tir.IntImm("int32", -1), # top_k - iou_threshold, - max_output_size_per_class, - valid_count, - on_new_valid_box, - on_new_invalidated_box, - needs_bbox_check, - calc_overlap, - sorted_scores, - num_valid_boxes, - ) - - -def _run_all_class_nms( - boxes, sorted_scores, sorted_indices, valid_count, max_output_size_per_class, iou_threshold -): - batch, num_boxes, _ = boxes.shape - batch_class = sorted_scores.shape[0] - num_class = batch_class // batch - - boxes_buf = tvm.tir.decl_buffer(boxes.shape, boxes.dtype, "boxes_buf", data_alignment=8) - sorted_scores_buf = tvm.tir.decl_buffer( - sorted_scores.shape, sorted_scores.dtype, "sorted_scores_buf", data_alignment=8 - ) - sorted_indices_buf = tvm.tir.decl_buffer( - sorted_indices.shape, sorted_indices.dtype, "sorted_indices_buf", data_alignment=8 - ) - valid_count_buf = tvm.tir.decl_buffer( - valid_count.shape, "int32", "valid_count_buf", data_alignment=4 - ) - - return te.extern( - [(batch_class, num_boxes), (1, batch_class)], - [boxes, sorted_scores, sorted_indices, valid_count], - lambda ins, outs: _all_class_nms_ir( - ins[0], # boxes - ins[1], # sorted_scores - ins[2], # sorted_indices - ins[3], # valid_count - batch_class, - num_class, - num_boxes, - iou_threshold, - max_output_size_per_class, - outs[0], # box_indices - outs[1], # num_valid_boxes - ), - dtype=["int32", "int32"], - in_buffers=[ - boxes_buf, - sorted_scores_buf, - sorted_indices_buf, - valid_count_buf, - ], - name="all_class_nms", - tag="all_class_nms", - ) - - def _collect_selected_indices_ir(num_class, selected_indices, num_detections, row_offsets, out): batch_classes, num_boxes = selected_indices.shape @@ -1158,30 +998,6 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro return ib.get() -def _collect_selected_indices(num_class, selected_indices, num_detections, row_offsets): - batch_class, num_boxes = selected_indices.shape - - selected_indices_buf = tvm.tir.decl_buffer( - selected_indices.shape, selected_indices.dtype, "selected_indices_buf", data_alignment=8 - ) - num_detections_buf = tvm.tir.decl_buffer( - num_detections.shape, num_detections.dtype, "num_detections_buf", data_alignment=8 - ) - row_offsets_buf = tvm.tir.decl_buffer( - row_offsets.shape, row_offsets.dtype, "row_offsets_buf", data_alignment=8 - ) - - return te.extern( - [(batch_class * num_boxes, 3)], - [selected_indices, num_detections, row_offsets], - lambda ins, outs: _collect_selected_indices_ir(num_class, ins[0], ins[1], ins[2], outs[0]), - dtype=["int64"], - in_buffers=[selected_indices_buf, num_detections_buf, row_offsets_buf], - name="collect_indices", - tag="collect_indices", - ) - - def all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold ): @@ -1191,16 +1007,22 @@ def all_class_non_max_suppression( sorted_scores, sorted_indices = _dispatch_sort(scores, ret_type="both") valid_count = _get_valid_box_count(sorted_scores, score_threshold) - selected_indices, num_detections = _run_all_class_nms( - boxes, sorted_scores, sorted_indices, valid_count, max_output_boxes_per_class, iou_threshold + selected_indices, num_detections = run_all_class_nms( + boxes, + sorted_scores, + sorted_indices, + valid_count, + max_output_boxes_per_class, + iou_threshold, + _nms_loop, ) row_offsets, num_total_detections = exclusive_scan( num_detections, return_reduction=True, output_dtype="int64" ) - selected_indices = _collect_selected_indices( - num_class, selected_indices, num_detections, row_offsets + 
selected_indices = collect_selected_indices( + num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir ) return [selected_indices, num_total_detections] diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py new file mode 100644 index 000000000000..34b31786e635 --- /dev/null +++ b/python/tvm/topi/vision/nms_util.py @@ -0,0 +1,210 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import te + + +def get_boundaries(output, box_idx): + l = tvm.te.min( + output[box_idx], + output[box_idx + 2], + ) + t = tvm.te.min( + output[box_idx + 1], + output[box_idx + 3], + ) + r = tvm.te.max( + output[box_idx], + output[box_idx + 2], + ) + b = tvm.te.max( + output[box_idx + 1], + output[box_idx + 3], + ) + return l, t, r, b + + +def calculate_overlap(out_tensor, box_a_idx, box_b_idx): + """Calculate overlap of two boxes.""" + a_l, a_t, a_r, a_b = get_boundaries(out_tensor, box_a_idx) + b_l, b_t, b_r, b_b = get_boundaries(out_tensor, box_b_idx) + + # Overlapping width and height + w = tvm.te.max(0.0, tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l)) + h = tvm.te.max(0.0, tvm.te.min(a_b, b_b) - tvm.te.max(a_t, b_t)) + + # Overlapping area + area = h * w + + # total area of the figure formed by box a and box b + # except for overlapping area + u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area + return tvm.tir.Select(u <= 0.0, 0.0, area / u) + + +def binary_search(ib, y, num_boxes, scores, score_threshold, out): + lo = ib.allocate("int32", (1,), name="lo", scope="local") + hi = ib.allocate("int32", (1,), name="hi", scope="local") + + lo[0] = 0 + hi[0] = num_boxes + + with ib.while_loop(lo[0] < hi[0]): + mid = (hi[0] + lo[0]) >> 1 + with ib.if_scope(scores[y, mid] > score_threshold): + lo[0] = mid + 1 + with ib.else_scope(): + hi[0] = mid + + out[y] = lo[0] + + +def collect_selected_indices(num_class, selected_indices, num_detections, row_offsets, ir): + batch_class, num_boxes = selected_indices.shape + + selected_indices_buf = tvm.tir.decl_buffer( + selected_indices.shape, selected_indices.dtype, "selected_indices_buf", data_alignment=8 + ) + num_detections_buf = tvm.tir.decl_buffer( + num_detections.shape, num_detections.dtype, "num_detections_buf", data_alignment=8 + ) + row_offsets_buf = tvm.tir.decl_buffer( + row_offsets.shape, row_offsets.dtype, "row_offsets_buf", data_alignment=8 + ) + + return te.extern( + [(batch_class * num_boxes, 3)], + [selected_indices, num_detections, row_offsets], + lambda ins, outs: ir(num_class, ins[0], ins[1], ins[2], outs[0]), + dtype=["int64"], + in_buffers=[selected_indices_buf, num_detections_buf, row_offsets_buf], + name="collect_indices", + tag="collect_indices", + ) + + +def _all_class_nms_ir( + boxes, + sorted_scores, + sorted_indices, + 
valid_count, + batch_class, + num_class, + num_anchors, + iou_threshold, + max_output_size_per_class, + box_indices, + num_valid_boxes, + nms_loop +): + ib = tvm.tir.ir_builder.create() + boxes = ib.buffer_ptr(boxes) + sorted_scores = ib.buffer_ptr(sorted_scores) + sorted_indices = ib.buffer_ptr(sorted_indices) + valid_count = ib.buffer_ptr(valid_count) + box_indices = ib.buffer_ptr(box_indices) + num_valid_boxes = ib.buffer_ptr(num_valid_boxes) + + if isinstance(iou_threshold, float): + iou_threshold = tvm.tir.FloatImm("float32", iou_threshold) + + if isinstance(max_output_size_per_class, int): + max_output_size_per_class = tvm.tir.const(max_output_size_per_class) + + def calc_overlap(i, j, k): + offset_j = sorted_indices[i, j] * 4 + offset_k = sorted_indices[i, k] * 4 + batch_id = i // num_class + base_bbox_idx = batch_id * num_anchors * 4 + return calculate_overlap( + boxes, + base_bbox_idx + offset_j, + base_bbox_idx + offset_k, + ) + + def on_new_valid_box(ib, tid, num_current_valid_box, i, j): + with ib.if_scope(tid + 0 == 0): + box_indices[i, num_current_valid_box] = sorted_indices[i, j] + + def on_new_invalidated_box(i, k): + pass + + def needs_bbox_check(i, j, k): + return tvm.tir.const(True) + + return nms_loop( + ib, + batch_class, + num_anchors, + tvm.tir.IntImm("int32", -1), # top_k + iou_threshold, + max_output_size_per_class, + valid_count, + on_new_valid_box, + on_new_invalidated_box, + needs_bbox_check, + calc_overlap, + sorted_scores, + num_valid_boxes, + ) + + +def run_all_class_nms( + boxes, sorted_scores, sorted_indices, valid_count, max_output_size_per_class, iou_threshold, nms_loop +): + batch, num_boxes, _ = boxes.shape + batch_class = sorted_scores.shape[0] + num_class = batch_class // batch + + boxes_buf = tvm.tir.decl_buffer(boxes.shape, boxes.dtype, "boxes_buf", data_alignment=8) + sorted_scores_buf = tvm.tir.decl_buffer( + sorted_scores.shape, sorted_scores.dtype, "sorted_scores_buf", data_alignment=8 + ) + sorted_indices_buf = tvm.tir.decl_buffer( + sorted_indices.shape, sorted_indices.dtype, "sorted_indices_buf", data_alignment=8 + ) + valid_count_buf = tvm.tir.decl_buffer( + valid_count.shape, "int32", "valid_count_buf", data_alignment=4 + ) + + return te.extern( + [(batch_class, num_boxes), (1, batch_class)], + [boxes, sorted_scores, sorted_indices, valid_count], + lambda ins, outs: _all_class_nms_ir( + ins[0], # boxes + ins[1], # sorted_scores + ins[2], # sorted_indices + ins[3], # valid_count + batch_class, + num_class, + num_boxes, + iou_threshold, + max_output_size_per_class, + outs[0], # box_indices + outs[1], # num_valid_boxes + nms_loop + ), + dtype=["int32", "int32"], + in_buffers=[ + boxes_buf, + sorted_scores_buf, + sorted_indices_buf, + valid_count_buf, + ], + name="all_class_nms", + tag="all_class_nms", + ) From a46bd03e1bf5031c94e8c5c771934f07a01c8fab Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 4 Apr 2021 15:37:52 +0900 Subject: [PATCH 20/36] simplify nms loop --- python/tvm/topi/cuda/nms.py | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 0a958d2d3599..006a406458e2 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -341,26 +341,20 @@ def nms_inner_loop(ib, i, j, nkeep): i = by nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]) + max_output_size = if_then_else(max_output_size > 0, max_output_size, nkeep) with ib.if_scope(tvm.tir.all(iou_threshold > 0, 
valid_count[i] > 0)): # Apply nms - with ib.if_scope(max_output_size > 0): - # No need to do more iteration if we have already reached max_output_size boxes - box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local") - box_idx[0] = 0 - with ib.while_loop( - tvm.tir.all(box_idx[0] < nkeep, num_valid_boxes_local[0] < max_output_size) - ): - # Proceed to the inner loop if the box with id box_idx is still valid - with ib.if_scope(out_scores[i, box_idx[0]] > -1.0): - nms_inner_loop(ib, i, box_idx[0], nkeep) - box_idx[0] += 1 - - with ib.else_scope(): - with ib.for_range(0, nkeep, name="j") as j: - # Proceed to the inner loop if the box j is still valid - with ib.if_scope(out_scores[i, j] > -1.0): - nms_inner_loop(ib, i, j, nkeep) + # No need to do more iteration if we have already reached max_output_size boxes + box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local") + box_idx[0] = 0 + with ib.while_loop( + tvm.tir.all(box_idx[0] < nkeep, num_valid_boxes_local[0] < max_output_size) + ): + # Proceed to the inner loop if the box with id box_idx is still valid + with ib.if_scope(out_scores[i, box_idx[0]] > -1.0): + nms_inner_loop(ib, i, box_idx[0], nkeep) + box_idx[0] += 1 with ib.if_scope(tx + 0 == 0): num_valid_boxes[i] = num_valid_boxes_local[0] From 8699a9889a3e101462805f6b689413ba8c965f30 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 4 Apr 2021 15:54:50 +0900 Subject: [PATCH 21/36] cpu all_class_nms stub --- python/tvm/topi/vision/nms.py | 205 ++++++++++++++++++++++++++++++++++ 1 file changed, 205 insertions(+) diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 8be62a73c09e..c682a4601f3d 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -20,7 +20,21 @@ from tvm import te from tvm.te import hybrid +from tvm.contrib import nvcc +from tvm.tir import if_then_else + from ..sort import argsort +from ..math import cast +from ..utils import ceil_div +from ..transform import reshape +from ..reduction import sum +from ..sort import sort, argsort +from ..scan import cumsum +from .nms_util import ( + binary_search, + collect_selected_indices, + run_all_class_nms, +) @hybrid.script @@ -597,3 +611,194 @@ def non_max_suppression( num_anchors=num_anchors, ) return out + + +def _nms_loop( + ib, + batch_size, + num_anchors, + top_k, + iou_threshold, + max_output_size, + valid_count, + on_new_valid_box_func, + on_new_invalidated_box_func, + needs_bbox_check_func, + calc_overlap_func, + out_scores, + num_valid_boxes, +): + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + + with ib.new_scope(): + nthread_by = batch_size + nthread_tx = max_threads + + # Some cuda architectures have smaller limit of 32K for cudaDevAttrMaxRegistersPerBlock + # vs 64K for most GPUs. Since this kernel uses many registers (around 35), the limit will + # be exceeded with 1024 threads. 
+ target = tvm.target.Target.current(allow_none=False) + if target.kind.name == "cuda": + if nvcc.get_target_compute_version(target) in ["3.2", "5.3", "6.2"]: + nthread_tx = 512 + + by = te.thread_axis("blockIdx.y") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(by, "thread_extent", nthread_by) + ib.scope_attr(tx, "thread_extent", nthread_tx) + + num_valid_boxes_local = ib.allocate( + "int32", (1,), name="num_valid_boxes_local", scope="local" + ) + num_valid_boxes_local[0] = 0 + + def nms_inner_loop(ib, i, j, nkeep): + # The box j is valid, invalidate other boxes that overlap with j above iou_threshold + on_new_valid_box_func(ib, tx, num_valid_boxes_local[0], i, j) + num_valid_boxes_local[0] += 1 + + num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx) + + with ib.for_range(0, num_iter_per_thread, name="_k") as _k: + k = j + 1 + _k * nthread_tx + tx + + with ib.if_scope( + tvm.tir.all( + k < nkeep, + out_scores[i, k] > 0, # is the box k still valid? + needs_bbox_check_func(i, j, k), + ) + ): + iou = calc_overlap_func(i, j, k) + + with ib.if_scope(iou >= iou_threshold): + # invalidate the box k + out_scores[i, k] = -1.0 + on_new_invalidated_box_func(i, k) + + i = by + + nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]) + max_output_size = if_then_else(max_output_size > 0, max_output_size, nkeep) + + with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): + # Apply nms + # No need to do more iteration if we have already reached max_output_size boxes + box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local") + box_idx[0] = 0 + with ib.while_loop( + tvm.tir.all(box_idx[0] < nkeep, num_valid_boxes_local[0] < max_output_size) + ): + # Proceed to the inner loop if the box with id box_idx is still valid + with ib.if_scope(out_scores[i, box_idx[0]] > -1.0): + nms_inner_loop(ib, i, box_idx[0], nkeep) + box_idx[0] += 1 + + with ib.if_scope(tx + 0 == 0): + num_valid_boxes[i] = num_valid_boxes_local[0] + + with ib.else_scope(): + num_valid_boxes[i] = 0 + + return ib.get() + + +def _get_valid_box_count(scores, score_threshold): + batch_classes, num_boxes = scores.shape + + def searchsorted_ir(scores, valid_count): + ib = tvm.tir.ir_builder.create() + scores = ib.buffer_ptr(scores) + valid_count = ib.buffer_ptr(valid_count) + + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + + with ib.new_scope(): + ib.scope_attr(bx, "thread_extent", ceil_div(batch_classes, max_threads)) + ib.scope_attr(tx, "thread_extent", max_threads) + tid = bx * max_threads + tx + + with ib.if_scope(tid < batch_classes): + binary_search(ib, tid, num_boxes, scores, score_threshold, valid_count) + + return ib.get() + + scores_buf = tvm.tir.decl_buffer(scores.shape, scores.dtype, "scores_buf", data_alignment=8) + + return te.extern( + [(batch_classes,)], + [scores], + lambda ins, outs: searchsorted_ir(ins[0], outs[0]), + dtype=["int32"], + in_buffers=[scores_buf], + name="searchsorted", + tag="searchsorted", + ) + + +def _collect_selected_indices_ir(num_class, selected_indices, num_detections, row_offsets, out): + batch_classes, num_boxes = selected_indices.shape + + ib = tvm.tir.ir_builder.create() + + selected_indices = ib.buffer_ptr(selected_indices) + num_detections = ib.buffer_ptr(num_detections) + row_offsets = ib.buffer_ptr(row_offsets) + out = ib.buffer_ptr(out) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + 
nthread_tx = max_threads + nthread_bx = ceil_div(num_boxes, nthread_tx) + nthread_by = batch_classes + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + ib.scope_attr(by, "thread_extent", nthread_by) + + with ib.new_scope(): + idx = bx * nthread_tx + tx + idy = cast(by, "int64") + batch_id = idy // num_class + class_id = idy % num_class + with ib.if_scope(idx < num_detections[idy]): + out[row_offsets[idy] + idx, 0] = batch_id + out[row_offsets[idy] + idx, 1] = class_id + out[row_offsets[idy] + idx, 2] = cast(selected_indices[idy, idx], "int64") + + return ib.get() + + +def all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold +): + batch, num_class, num_boxes = scores.shape + + scores = reshape(scores, (batch * num_class, num_boxes)) + # TODO(masahi): CPU argsort should return both sorted values and indices + sorted_scores = sort(scores, axis=1, is_ascend=False, dtype="int32") + sorted_indices = argsort(scores, axis=1, is_ascend=False, dtype="int32") + valid_count = _get_valid_box_count(sorted_scores, score_threshold) + + selected_indices, num_detections = run_all_class_nms( + boxes, + sorted_scores, + sorted_indices, + valid_count, + max_output_boxes_per_class, + iou_threshold, + _nms_loop, + ) + + row_offsets = cumsum(num_detections, exclusive=True, dtype="int64") + + num_total_detections = sum(num_detections, axis=1) + + selected_indices = collect_selected_indices( + num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + ) + + return [selected_indices, num_total_detections] From ef8d3c98e6cb74ebc41a8b614a294ab8a10f3653 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 5 Apr 2021 06:02:23 +0900 Subject: [PATCH 22/36] updating ir for cpu --- python/tvm/topi/vision/nms.py | 42 +++++++++-------------------------- 1 file changed, 10 insertions(+), 32 deletions(-) diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index c682a4601f3d..f17a2fe31c4d 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -711,17 +711,8 @@ def searchsorted_ir(scores, valid_count): scores = ib.buffer_ptr(scores) valid_count = ib.buffer_ptr(valid_count) - bx = te.thread_axis("blockIdx.x") - tx = te.thread_axis("threadIdx.x") - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - - with ib.new_scope(): - ib.scope_attr(bx, "thread_extent", ceil_div(batch_classes, max_threads)) - ib.scope_attr(tx, "thread_extent", max_threads) - tid = bx * max_threads + tx - - with ib.if_scope(tid < batch_classes): - binary_search(ib, tid, num_boxes, scores, score_threshold, valid_count) + with ib.for_range(0, batch_classes, name="i") as i: + binary_search(ib, i, num_boxes, scores, score_threshold, valid_count) return ib.get() @@ -748,26 +739,14 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro row_offsets = ib.buffer_ptr(row_offsets) out = ib.buffer_ptr(out) - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - nthread_tx = max_threads - nthread_bx = ceil_div(num_boxes, nthread_tx) - nthread_by = batch_classes - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - by = te.thread_axis("blockIdx.y") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - ib.scope_attr(by, 
"thread_extent", nthread_by) + with ib.for_range(0, batch_classes, name="i") as i: + batch_id = i // num_class + class_id = i % num_class - with ib.new_scope(): - idx = bx * nthread_tx + tx - idy = cast(by, "int64") - batch_id = idy // num_class - class_id = idy % num_class - with ib.if_scope(idx < num_detections[idy]): - out[row_offsets[idy] + idx, 0] = batch_id - out[row_offsets[idy] + idx, 1] = class_id - out[row_offsets[idy] + idx, 2] = cast(selected_indices[idy, idx], "int64") + with ib.for_range(0, num_detections[i], name="j", kind="parallel") as j: + out[row_offsets[i] + j, 0] = batch_id + out[row_offsets[i] + j, 1] = class_id + out[row_offsets[i] + j, 2] = cast(selected_indices[i, j], "int64") return ib.get() @@ -776,9 +755,8 @@ def all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold ): batch, num_class, num_boxes = scores.shape - scores = reshape(scores, (batch * num_class, num_boxes)) - # TODO(masahi): CPU argsort should return both sorted values and indices + sorted_scores = sort(scores, axis=1, is_ascend=False, dtype="int32") sorted_indices = argsort(scores, axis=1, is_ascend=False, dtype="int32") valid_count = _get_valid_box_count(sorted_scores, score_threshold) From 8400bbf29344e66d1ca8fa91377086b88fb1d6a1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 5 Apr 2021 06:37:06 +0900 Subject: [PATCH 23/36] working with cpu --- python/tvm/topi/vision/nms.py | 90 ++++++++------------ tests/python/topi/python/test_topi_vision.py | 3 +- 2 files changed, 37 insertions(+), 56 deletions(-) diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index f17a2fe31c4d..a8ba8362c9d7 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -628,74 +628,53 @@ def _nms_loop( out_scores, num_valid_boxes, ): - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - - with ib.new_scope(): - nthread_by = batch_size - nthread_tx = max_threads - - # Some cuda architectures have smaller limit of 32K for cudaDevAttrMaxRegistersPerBlock - # vs 64K for most GPUs. Since this kernel uses many registers (around 35), the limit will - # be exceeded with 1024 threads. - target = tvm.target.Target.current(allow_none=False) - if target.kind.name == "cuda": - if nvcc.get_target_compute_version(target) in ["3.2", "5.3", "6.2"]: - nthread_tx = 512 - - by = te.thread_axis("blockIdx.y") - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(by, "thread_extent", nthread_by) - ib.scope_attr(tx, "thread_extent", nthread_tx) - - num_valid_boxes_local = ib.allocate( - "int32", (1,), name="num_valid_boxes_local", scope="local" - ) - num_valid_boxes_local[0] = 0 - - def nms_inner_loop(ib, i, j, nkeep): - # The box j is valid, invalidate other boxes that overlap with j above iou_threshold - on_new_valid_box_func(ib, tx, num_valid_boxes_local[0], i, j) - num_valid_boxes_local[0] += 1 - - num_iter_per_thread = ceil_div(nkeep - (j + 1), nthread_tx) - - with ib.for_range(0, num_iter_per_thread, name="_k") as _k: - k = j + 1 + _k * nthread_tx + tx - - with ib.if_scope( - tvm.tir.all( - k < nkeep, - out_scores[i, k] > 0, # is the box k still valid? 
- needs_bbox_check_func(i, j, k), - ) - ): - iou = calc_overlap_func(i, j, k) - - with ib.if_scope(iou >= iou_threshold): - # invalidate the box k - out_scores[i, k] = -1.0 - on_new_invalidated_box_func(i, k) + def nms_inner_loop(ib, i, j, nkeep, num_valid_boxes_local): + # The box j is valid, invalidate other boxes that overlap with j above iou_threshold + on_new_valid_box_func(ib, 0, num_valid_boxes_local[0], i, j) + num_valid_boxes_local[0] += 1 + + num_boxes_to_check = nkeep - (j + 1) + + with ib.for_range(0, num_boxes_to_check, name="k", kind="parallel") as _k: + k = j + 1 + _k + + with ib.if_scope( + tvm.tir.all( + k < nkeep, + out_scores[i, k] > 0, # is the box k still valid? + needs_bbox_check_func(i, j, k), + ) + ): + iou = calc_overlap_func(i, j, k) - i = by + with ib.if_scope(iou >= iou_threshold): + # invalidate the box k + out_scores[i, k] = -1.0 + on_new_invalidated_box_func(i, k) + with ib.for_range(0, batch_size, name="i") as i: nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]) max_output_size = if_then_else(max_output_size > 0, max_output_size, nkeep) with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)): - # Apply nms - # No need to do more iteration if we have already reached max_output_size boxes + num_valid_boxes_local = ib.allocate( + "int32", (1,), name="num_valid_boxes_local", scope="local" + ) box_idx = ib.allocate("int32", (1,), name="box_idx", scope="local") + num_valid_boxes_local[0] = 0 box_idx[0] = 0 + + # Apply nms + # No need to do more iteration if we have already reached max_output_size boxes with ib.while_loop( tvm.tir.all(box_idx[0] < nkeep, num_valid_boxes_local[0] < max_output_size) ): # Proceed to the inner loop if the box with id box_idx is still valid with ib.if_scope(out_scores[i, box_idx[0]] > -1.0): - nms_inner_loop(ib, i, box_idx[0], nkeep) + nms_inner_loop(ib, i, box_idx[0], nkeep, num_valid_boxes_local) box_idx[0] += 1 - with ib.if_scope(tx + 0 == 0): - num_valid_boxes[i] = num_valid_boxes_local[0] + num_valid_boxes[i] = num_valid_boxes_local[0] with ib.else_scope(): num_valid_boxes[i] = 0 @@ -740,6 +719,7 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro out = ib.buffer_ptr(out) with ib.for_range(0, batch_classes, name="i") as i: + i = cast(i, "int64") batch_id = i // num_class class_id = i % num_class @@ -757,7 +737,7 @@ def all_class_non_max_suppression( batch, num_class, num_boxes = scores.shape scores = reshape(scores, (batch * num_class, num_boxes)) - sorted_scores = sort(scores, axis=1, is_ascend=False, dtype="int32") + sorted_scores = sort(scores, axis=1, is_ascend=False) sorted_indices = argsort(scores, axis=1, is_ascend=False, dtype="int32") valid_count = _get_valid_box_count(sorted_scores, score_threshold) @@ -773,7 +753,7 @@ def all_class_non_max_suppression( row_offsets = cumsum(num_detections, exclusive=True, dtype="int64") - num_total_detections = sum(num_detections, axis=1) + num_total_detections = sum(cast(num_detections, "int64"), axis=1) selected_indices = collect_selected_indices( num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index 6a1674b91f5d..605d637ca0c9 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -66,6 +66,7 @@ } _all_class_nms_implement = { + "generic": (topi.vision.all_class_non_max_suppression, 
topi.generic.schedule_nms), "gpu": (topi.cuda.all_class_non_max_suppression, topi.cuda.schedule_nms), } @@ -659,7 +660,7 @@ def check_device(target): print(selected_indices.asnumpy()[:num_detections.asnumpy()[0]]) # tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4) - for target in ["cuda"]: + for target in ["llvm"]: check_device(target) From dc437ff0e2786b5194de68acf575ec32f6ae1f99 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 5 Apr 2021 06:40:43 +0900 Subject: [PATCH 24/36] update cpu strategy, relay op also working --- python/tvm/relay/op/strategy/generic.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 7eae8d1398ce..d28909b7840e 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1074,11 +1074,10 @@ def _compute_nms(attrs, inputs, out_type): @override_native_generic_func("all_class_non_max_suppression_strategy") def all_class_nms_strategy(attrs, inputs, out_type, target): """all class nms generic strategy""" - # TODO strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_all_class_nms(topi.cuda.all_class_non_max_suppression), - wrap_topi_schedule(topi.cuda.schedule_nms), + wrap_compute_all_class_nms(topi.vision.all_class_non_max_suppression), + wrap_topi_schedule(topi.generic.schedule_nms), name="all_class_nms.generic", ) return strategy From ee9c4d53515638e5599822c5cc82d73762a1e663 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 5 Apr 2021 06:47:49 +0900 Subject: [PATCH 25/36] fix cpplint --- include/tvm/relay/attrs/vision.h | 7 ++++--- tests/python/topi/python/test_topi_vision.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 9d3db27ff85f..005b900d5d44 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -115,9 +115,10 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { - TVM_DECLARE_ATTRS(AllClassNonMaximumSuppressionAttrs, "relay.attrs.AllClassNonMaximumSuppressionAttrs") { - } +struct AllClassNonMaximumSuppressionAttrs + : public tvm::AttrsNode { + TVM_DECLARE_ATTRS(AllClassNonMaximumSuppressionAttrs, + "relay.attrs.AllClassNonMaximumSuppressionAttrs") {} }; /*! 
\brief Attributes used in roi_align operators */
diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py
index 605d637ca0c9..84a7d6e9e8dd 100644
--- a/tests/python/topi/python/test_topi_vision.py
+++ b/tests/python/topi/python/test_topi_vision.py
@@ -660,7 +660,7 @@ def check_device(target):
        print(selected_indices.asnumpy()[:num_detections.asnumpy()[0]])
        # tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4)

-    for target in ["llvm"]:
+    for target in ["llvm", "cuda"]:
        check_device(target)

From e67eae7bc4abd2a4ef7189319622e77823a195c0 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Mon, 5 Apr 2021 06:51:47 +0900
Subject: [PATCH 26/36] fixing pylint

---
 python/tvm/relay/op/vision/nms.py | 3 +++
 python/tvm/topi/cuda/nms.py       | 3 +++
 python/tvm/topi/vision/nms.py     | 3 +++
 3 files changed, 9 insertions(+)

diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py
index 0a61ca962b02..5cbcf0748fa2 100644
--- a/python/tvm/relay/op/vision/nms.py
+++ b/python/tvm/relay/op/vision/nms.py
@@ -154,6 +154,9 @@ def non_max_suppression(
def all_class_non_max_suppression(
    boxes, scores, max_output_boxes_per_class=-1, iou_threshold=-1.0, score_threshold=-1.0
):
+    """
+    TODO
+    """
    if not isinstance(max_output_boxes_per_class, expr.Expr):
        max_output_boxes_per_class = expr.const(max_output_boxes_per_class, "int32")
    if not isinstance(iou_threshold, expr.Expr):
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 006a406458e2..03312079540a 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -995,6 +995,9 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro
def all_class_non_max_suppression(
    boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold
):
+    """
+    TODO
+    """
    batch, num_class, num_boxes = scores.shape
    scores = reshape(scores, (batch * num_class, num_boxes))
diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py
index a8ba8362c9d7..e840b8470e3e 100644
--- a/python/tvm/topi/vision/nms.py
+++ b/python/tvm/topi/vision/nms.py
@@ -734,6 +734,9 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro
def all_class_non_max_suppression(
    boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold
):
+    """
+    TODO
+    """
    batch, num_class, num_boxes = scores.shape
    scores = reshape(scores, (batch * num_class, num_boxes))

From b4bd99526c25618106091adc27eb18eac1bfdc4d Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Mon, 5 Apr 2021 08:15:34 +0900
Subject: [PATCH 27/36] enable gpu test for onnx nms

---
 tests/python/frontend/onnx/test_forward.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index a491ed130418..44016b045382 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -3579,7 +3579,7 @@ def verify_roi_align(


# ONNX implementation of roi_align with max mode is incorrect, so we don't compare outputs here.
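For orientation, the contract exercised by the test below is that of ONNX NonMaxSuppression: NMS runs independently for every (batch, class) pair and the op returns selected_indices rows of (batch_index, class_index, box_index), at most max_output_boxes_per_class per pair. A simplified NumPy reference of that behaviour follows; it is an editorial sketch rather than code from the test suite, the names ref_iou and ref_all_class_nms are made up, and tie-breaking plus the exact >= IoU comparison are glossed over.

import numpy as np


def ref_iou(a, b):
    # a, b: length-4 arrays of box corners; the order of the two corners does not matter.
    a_min, a_max = np.minimum(a[:2], a[2:]), np.maximum(a[:2], a[2:])
    b_min, b_max = np.minimum(b[:2], b[2:]), np.maximum(b[:2], b[2:])
    inter = np.prod(np.clip(np.minimum(a_max, b_max) - np.maximum(a_min, b_min), 0, None))
    union = np.prod(a_max - a_min) + np.prod(b_max - b_min) - inter
    return inter / union if union > 0 else 0.0


def ref_all_class_nms(boxes, scores, max_out, iou_thr, score_thr):
    # boxes: (batch, num_boxes, 4), scores: (batch, num_classes, num_boxes)
    selected = []
    batch, num_classes, _ = scores.shape
    for b in range(batch):
        for c in range(num_classes):
            kept = []
            for i in np.argsort(-scores[b, c], kind="stable"):
                if scores[b, c, i] <= score_thr or len(kept) >= max_out:
                    break  # scores are visited in descending order
                if all(ref_iou(boxes[b, i], boxes[b, k]) < iou_thr for k in kept):
                    kept.append(i)
                    selected.append((b, c, i))
    return np.array(selected, dtype="int64").reshape(-1, 3)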
-# @tvm.testing.uses_gpu +@tvm.testing.uses_gpu def test_non_max_suppression(): def verify_nms( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_dims From ed7f6aecc8470982a463f7037fa310f25c6dd1d1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 5 Apr 2021 09:47:23 +0900 Subject: [PATCH 28/36] tweak parallel --- python/tvm/topi/vision/nms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index e840b8470e3e..34b38d5bfc36 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -635,7 +635,7 @@ def nms_inner_loop(ib, i, j, nkeep, num_valid_boxes_local): num_boxes_to_check = nkeep - (j + 1) - with ib.for_range(0, num_boxes_to_check, name="k", kind="parallel") as _k: + with ib.for_range(0, num_boxes_to_check, name="_k", kind="parallel") as _k: k = j + 1 + _k with ib.if_scope( @@ -690,7 +690,7 @@ def searchsorted_ir(scores, valid_count): scores = ib.buffer_ptr(scores) valid_count = ib.buffer_ptr(valid_count) - with ib.for_range(0, batch_classes, name="i") as i: + with ib.for_range(0, batch_classes, name="i", kind="parallel") as i: binary_search(ib, i, num_boxes, scores, score_threshold, valid_count) return ib.get() @@ -718,12 +718,12 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro row_offsets = ib.buffer_ptr(row_offsets) out = ib.buffer_ptr(out) - with ib.for_range(0, batch_classes, name="i") as i: + with ib.for_range(0, batch_classes, name="i", kind="parallel") as i: i = cast(i, "int64") batch_id = i // num_class class_id = i % num_class - with ib.for_range(0, num_detections[i], name="j", kind="parallel") as j: + with ib.for_range(0, num_detections[i], name="j") as j: out[row_offsets[i] + j, 0] = batch_id out[row_offsets[i] + j, 1] = class_id out[row_offsets[i] + j, 2] = cast(selected_indices[i, j], "int64") From 0b263411d29c61fae9a937f98750107f04ddc6e2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 5 Apr 2021 15:08:17 +0900 Subject: [PATCH 29/36] pyformat and lint --- python/tvm/relay/op/strategy/generic.py | 9 ++----- python/tvm/topi/cuda/nms.py | 5 +--- python/tvm/topi/vision/nms.py | 12 +++------ python/tvm/topi/vision/nms_util.py | 28 +++++++++++++------- tests/python/topi/python/test_topi_vision.py | 2 +- 5 files changed, 27 insertions(+), 29 deletions(-) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index d28909b7840e..2721232b9452 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1056,17 +1056,12 @@ def nms_strategy(attrs, inputs, out_type, target): def wrap_compute_all_class_nms(topi_compute): """wrap nms topi compute""" + def _compute_nms(attrs, inputs, out_type): max_output_size = inputs[2] iou_threshold = inputs[3] score_threshold = inputs[4] - return topi_compute( - inputs[0], - inputs[1], - max_output_size, - iou_threshold, - score_threshold - ) + return topi_compute(inputs[0], inputs[1], max_output_size, iou_threshold, score_threshold) return _compute_nms diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 03312079540a..2f48256744c1 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -276,7 +276,6 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): def _nms_loop( ib, batch_size, - num_anchors, top_k, iou_threshold, max_output_size, @@ -589,7 +588,6 @@ def needs_bbox_check(i, j, k): return _nms_loop( ib, 
batch_size, - num_anchors, top_k, iou_threshold, max_output_size, @@ -639,8 +637,7 @@ def _dispatch_sort(scores, ret_type="indices"): or can_use_rocthrust(target, "tvm.contrib.thrust.sort") ): return argsort_thrust(scores, axis=1, is_ascend=False, dtype="int32", ret_type=ret_type) - else: - return argsort(scores, axis=1, is_ascend=False, dtype="int32", ret_type=ret_type) + return argsort(scores, axis=1, is_ascend=False, dtype="int32", ret_type=ret_type) def _get_sorted_indices(data, data_buf, score_index, score_shape): diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 34b38d5bfc36..e8a52eebd9d6 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -20,15 +20,12 @@ from tvm import te from tvm.te import hybrid -from tvm.contrib import nvcc from tvm.tir import if_then_else -from ..sort import argsort +from ..sort import sort, argsort from ..math import cast -from ..utils import ceil_div from ..transform import reshape -from ..reduction import sum -from ..sort import sort, argsort +from .. import reduction from ..scan import cumsum from .nms_util import ( binary_search, @@ -616,7 +613,6 @@ def non_max_suppression( def _nms_loop( ib, batch_size, - num_anchors, top_k, iou_threshold, max_output_size, @@ -709,7 +705,7 @@ def searchsorted_ir(scores, valid_count): def _collect_selected_indices_ir(num_class, selected_indices, num_detections, row_offsets, out): - batch_classes, num_boxes = selected_indices.shape + batch_classes, _ = selected_indices.shape ib = tvm.tir.ir_builder.create() @@ -756,7 +752,7 @@ def all_class_non_max_suppression( row_offsets = cumsum(num_detections, exclusive=True, dtype="int64") - num_total_detections = sum(cast(num_detections, "int64"), axis=1) + num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1) selected_indices = collect_selected_indices( num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py index 34b31786e635..47015f9a808f 100644 --- a/python/tvm/topi/vision/nms_util.py +++ b/python/tvm/topi/vision/nms_util.py @@ -14,11 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+# pylint: disable=invalid-name +"""Common utilities used in Non-maximum suppression operators""" import tvm from tvm import te -def get_boundaries(output, box_idx): +def _get_boundaries(output, box_idx): l = tvm.te.min( output[box_idx], output[box_idx + 2], @@ -40,8 +42,8 @@ def get_boundaries(output, box_idx): def calculate_overlap(out_tensor, box_a_idx, box_b_idx): """Calculate overlap of two boxes.""" - a_l, a_t, a_r, a_b = get_boundaries(out_tensor, box_a_idx) - b_l, b_t, b_r, b_b = get_boundaries(out_tensor, box_b_idx) + a_l, a_t, a_r, a_b = _get_boundaries(out_tensor, box_a_idx) + b_l, b_t, b_r, b_b = _get_boundaries(out_tensor, box_b_idx) # Overlapping width and height w = tvm.te.max(0.0, tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l)) @@ -57,6 +59,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): def binary_search(ib, y, num_boxes, scores, score_threshold, out): + """Binary search for score_threshold on scores sorted in descending order""" lo = ib.allocate("int32", (1,), name="lo", scope="local") hi = ib.allocate("int32", (1,), name="hi", scope="local") @@ -74,6 +77,7 @@ def binary_search(ib, y, num_boxes, scores, score_threshold, out): def collect_selected_indices(num_class, selected_indices, num_detections, row_offsets, ir): + """TODO""" batch_class, num_boxes = selected_indices.shape selected_indices_buf = tvm.tir.decl_buffer( @@ -109,7 +113,7 @@ def _all_class_nms_ir( max_output_size_per_class, box_indices, num_valid_boxes, - nms_loop + nms_loop, ): ib = tvm.tir.ir_builder.create() boxes = ib.buffer_ptr(boxes) @@ -140,16 +144,15 @@ def on_new_valid_box(ib, tid, num_current_valid_box, i, j): with ib.if_scope(tid + 0 == 0): box_indices[i, num_current_valid_box] = sorted_indices[i, j] - def on_new_invalidated_box(i, k): + def on_new_invalidated_box(*_): pass - def needs_bbox_check(i, j, k): + def needs_bbox_check(*_): return tvm.tir.const(True) return nms_loop( ib, batch_class, - num_anchors, tvm.tir.IntImm("int32", -1), # top_k iou_threshold, max_output_size_per_class, @@ -164,8 +167,15 @@ def needs_bbox_check(i, j, k): def run_all_class_nms( - boxes, sorted_scores, sorted_indices, valid_count, max_output_size_per_class, iou_threshold, nms_loop + boxes, + sorted_scores, + sorted_indices, + valid_count, + max_output_size_per_class, + iou_threshold, + nms_loop, ): + """TODO""" batch, num_boxes, _ = boxes.shape batch_class = sorted_scores.shape[0] num_class = batch_class // batch @@ -196,7 +206,7 @@ def run_all_class_nms( max_output_size_per_class, outs[0], # box_indices outs[1], # num_valid_boxes - nms_loop + nms_loop, ), dtype=["int32", "int32"], in_buffers=[ diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index 84a7d6e9e8dd..93cb02e8c6a7 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -657,7 +657,7 @@ def check_device(target): f = tvm.build(s, [boxes, scores, out[0], out[1]], target) f(tvm_boxes, tvm_scores, selected_indices, num_detections) - print(selected_indices.asnumpy()[:num_detections.asnumpy()[0]]) + print(selected_indices.asnumpy()[: num_detections.asnumpy()[0]]) # tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4) for target in ["llvm", "cuda"]: From 236132175329250a776bd5fba7b3ceb5d9ffc621 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 5 Apr 2021 16:56:35 +0900 Subject: [PATCH 30/36] fix relay nms test --- tests/python/relay/test_op_level5.py | 2 -- 1 file changed, 2 deletions(-) diff --git 
a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 2d6c8b50fd37..926650762643 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -371,8 +371,6 @@ def verify_nms( ) if isinstance(z_indices, relay.expr.TupleWrapper): z_indices = z_indices.astuple() - assert "iou_threshold" in z.astext() - assert "iou_threshold" in z_indices.astext() zz = run_infer_type(z) zz_indices = run_infer_type(z_indices) assert zz.checked_type == relay.ty.TensorType(dshape, "float32") From 004145a86e7cf6f897a57305c3247e2a64ad85e8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 8 Apr 2021 18:31:51 +0900 Subject: [PATCH 31/36] doc update for cpp relay --- src/relay/op/vision/nms.cc | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 50e1ada9dcd6..53cd71745d5b 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -179,17 +179,19 @@ TVM_REGISTER_GLOBAL("relay.op.vision._make.all_class_non_max_suppression") .set_body_typed(MakeAllClassNMS); RELAY_REGISTER_OP("vision.all_class_non_max_suppression") - .describe(R"doc(Non-maximum suppression. The input boxes should -be in the format of [class_id, score, left, top, right, bottom] -or [score, left, top, right, bottom]. Set id_index to be -1 to -ignore class_id axis. + .describe(R"doc(Non-maximum suppression operator for object detection, corresponding to ONNX + NonMaxSuppression and TensorFlow combined_non_max_suppression. + NMS is performed for each class separately )doc" TVM_ADD_FILELINE) .set_num_inputs(5) - .add_argument("data", "Tensor", "Input data.") - .add_argument("valid_count", "Tensor", "Number of valid anchor boxes.") - .add_argument("indices", "Tensor", "Corresponding indices in original input tensor.") - .add_argument("max_output_size", "Tensor", "Max number of output valid boxes.") - .add_argument("iou_threshold", "Tensor", "Threshold for box overlap.") + .add_argument("boxes", "Tensor", "The input boxes in the format [batch, num_boxes, 4].") + .add_argument("scores", "Tensor", + "Scores for each box and class in the format [batch, num_classes, num_boxes].") + .add_argument("max_output_boxes_per_class", "Tensor", + "The maximum number of output boxes per class.") + .add_argument("iou_threshold", "Tensor", "The IoU threshold for box the overlap test.") + .add_argument("score_threshold", "Tensor", + "The score threshold to filter out low score boxes early.") .set_support_level(5) .add_type_rel("AllClassNMS", AllClassNMSRel); From d207c4d7e83a26ad3136b07f4934ac5bde17e457 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 8 Apr 2021 18:53:32 +0900 Subject: [PATCH 32/36] updating tests --- tests/python/relay/test_op_level5.py | 96 ++++++++++++++++++++ tests/python/topi/python/test_topi_vision.py | 17 ++-- 2 files changed, 107 insertions(+), 6 deletions(-) diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 926650762643..eca0889feb8d 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -1362,6 +1362,101 @@ def verify_batch_to_space_nd(dshape, block_shape, crops): verify_batch_to_space_nd([8, 1, 3, 1], [2, 2], [[0, 0], [2, 0]]) +@tvm.testing.uses_gpu +def test_all_class_non_max_suppression(): + def verify_all_class_non_max_suppression( + boxes_np, + scores_np, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + expected_indices, + ): + dshape = boxes_np.shape 
+ batch, num_boxes, _ = dshape + _, num_class, _ = scores_np.shape + + boxes = relay.var("boxes", relay.ty.TensorType(dshape, "float32")) + scores = relay.var("scores", relay.ty.TensorType(scores_np.shape, "float32")) + + out = relay.vision.all_class_non_max_suppression( + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + ) + + func_indices = relay.Function([boxes, scores], out) + func_indices = run_infer_type(func_indices) + + for target, dev in tvm.testing.enabled_targets(): + for kind in ["graph", "vm"]: + intrp = relay.create_executor(kind, device=dev, target=target) + selected_indices, num_detections = intrp.evaluate(func_indices)(boxes_np, scores_np) + tvm_res = selected_indices.asnumpy()[: num_detections.asnumpy()[0]] + np.testing.assert_equal(tvm_res, expected_indices) + + boxes = np.array( + [ + [ + [0.0, 0.0, 0.3, 0.3], + [0.0, 0.0, 0.4, 0.4], + [0.0, 0.0, 0.5, 0.5], + [0.5, 0.5, 0.9, 0.9], + [0.5, 0.5, 1.0, 1.0], + ], + [ + [0.0, 0.0, 0.3, 0.3], + [0.0, 0.0, 0.4, 0.4], + [0.5, 0.5, 0.95, 0.95], + [0.5, 0.5, 0.96, 0.96], + [0.5, 0.5, 1.0, 1.0], + ], + ] + ).astype("float32") + + scores = np.array( + [ + [[0.1, 0.2, 0.6, 0.3, 0.9], [0.1, 0.2, 0.6, 0.3, 0.9]], + [[0.1, 0.2, 0.6, 0.3, 0.9], [0.1, 0.2, 0.6, 0.3, 0.9]], + ] + ).astype("float32") + + max_output_boxes_per_class = 2 + iou_threshold = 0.8 + score_threshold = 0.0 + + expected = [] + + verify_all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected + ) + + boxes = np.array( + [ + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.1, 1.0, 1.1], + [0.0, -0.1, 1.0, 0.9], + [0.0, 10.0, 1.0, 11.0], + [0.0, 10.1, 1.0, 11.1], + [0.0, 100.0, 1.0, 101.0], + ] + ] + ).astype(np.float32) + scores = np.array([[[0.9, 0.75, 0.6, 0.95, 0.5, 0.3]]]).astype(np.float32) + max_output_boxes_per_class = 3 + iou_threshold = 0.5 + score_threshold = 0.4 + + expected = [] + + verify_all_class_non_max_suppression( + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected + ) + + if __name__ == "__main__": test_resize_infer_type() test_resize() @@ -1384,3 +1479,4 @@ def verify_batch_to_space_nd(dshape, block_shape, crops): test_affine_grid() test_grid_sample() test_space_to_batch_nd() + test_all_class_non_max_suppression() diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index 93cb02e8c6a7..be8368fde3a3 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -629,7 +629,7 @@ def test_proposal(): def verify_all_class_non_max_suppression( - boxes_np, scores_np, max_output_boxes_per_class, iou_threshold, score_threshold + boxes_np, scores_np, max_output_boxes_per_class, iou_threshold, score_threshold, expected_indices ): dshape = boxes_np.shape batch, num_boxes, _ = dshape @@ -657,10 +657,11 @@ def check_device(target): f = tvm.build(s, [boxes, scores, out[0], out[1]], target) f(tvm_boxes, tvm_scores, selected_indices, num_detections) - print(selected_indices.asnumpy()[: num_detections.asnumpy()[0]]) - # tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4) - for target in ["llvm", "cuda"]: + tvm_res = selected_indices.asnumpy()[: num_detections.asnumpy()[0]] + np.testing.assert_equal(tvm_res, expected_indices) + + for target in ["llvm", "cuda", "opencl", "vulkan"]: check_device(target) @@ -696,8 +697,10 @@ def test_all_class_non_max_suppression(): iou_threshold = 0.8 score_threshold = 0.0 + expected = [] + 
verify_all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected ) boxes = np.array( @@ -717,8 +720,10 @@ def test_all_class_non_max_suppression(): iou_threshold = 0.5 score_threshold = 0.4 + expected = [] + verify_all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected ) From 05fa4151b2b61fc66fc28eee114ef58312da8e4f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 8 Apr 2021 19:09:11 +0900 Subject: [PATCH 33/36] updated tests --- python/tvm/relay/op/strategy/cuda.py | 2 +- python/tvm/relay/op/strategy/generic.py | 2 +- tests/python/relay/test_op_level5.py | 17 ++++++------ tests/python/topi/python/test_topi_vision.py | 27 ++++++++++++-------- 4 files changed, 27 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 198b98d7f039..ef85a37c2175 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -948,7 +948,7 @@ def nms_strategy_cuda(attrs, inputs, out_type, target): @all_class_nms_strategy.register(["cuda", "gpu"]) def all_class_nms_strategy_cuda(attrs, inputs, out_type, target): - """nms cuda strategy""" + """all class nms cuda strategy""" strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_all_class_nms(topi.cuda.all_class_non_max_suppression), diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 2721232b9452..362ab9be0fec 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1055,7 +1055,7 @@ def nms_strategy(attrs, inputs, out_type, target): def wrap_compute_all_class_nms(topi_compute): - """wrap nms topi compute""" + """wrap all class nms topi compute""" def _compute_nms(attrs, inputs, out_type): max_output_size = inputs[2] diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index eca0889feb8d..2a71115f2929 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -21,7 +21,6 @@ import tvm from tvm import te from tvm import relay -from tvm.relay import transform from tvm.relay.testing import run_infer_type import tvm.topi.testing import tvm.testing @@ -1373,8 +1372,6 @@ def verify_all_class_non_max_suppression( expected_indices, ): dshape = boxes_np.shape - batch, num_boxes, _ = dshape - _, num_class, _ = scores_np.shape boxes = relay.var("boxes", relay.ty.TensorType(dshape, "float32")) scores = relay.var("scores", relay.ty.TensorType(scores_np.shape, "float32")) @@ -1387,13 +1384,13 @@ def verify_all_class_non_max_suppression( score_threshold, ) - func_indices = relay.Function([boxes, scores], out) - func_indices = run_infer_type(func_indices) + func = relay.Function([boxes, scores], out.astuple()) + func = run_infer_type(func) for target, dev in tvm.testing.enabled_targets(): - for kind in ["graph", "vm"]: + for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, device=dev, target=target) - selected_indices, num_detections = intrp.evaluate(func_indices)(boxes_np, scores_np) + selected_indices, num_detections = intrp.evaluate(func)(boxes_np, scores_np) tvm_res = selected_indices.asnumpy()[: num_detections.asnumpy()[0]] np.testing.assert_equal(tvm_res, expected_indices) @@ -1427,7 +1424,9 @@ def 
verify_all_class_non_max_suppression( iou_threshold = 0.8 score_threshold = 0.0 - expected = [] + expected = np.array( + [[0, 0, 4], [0, 0, 2], [0, 1, 4], [0, 1, 2], [1, 0, 4], [1, 0, 1], [1, 1, 4], [1, 1, 1]] + ) verify_all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected @@ -1450,7 +1449,7 @@ def verify_all_class_non_max_suppression( iou_threshold = 0.5 score_threshold = 0.4 - expected = [] + expected = np.array([[0, 0, 3], [0, 0, 0]]) verify_all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index be8368fde3a3..86469919cfa6 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -629,7 +629,12 @@ def test_proposal(): def verify_all_class_non_max_suppression( - boxes_np, scores_np, max_output_boxes_per_class, iou_threshold, score_threshold, expected_indices + boxes_np, + scores_np, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + expected_indices, ): dshape = boxes_np.shape batch, num_boxes, _ = dshape @@ -697,7 +702,9 @@ def test_all_class_non_max_suppression(): iou_threshold = 0.8 score_threshold = 0.0 - expected = [] + expected = np.array( + [[0, 0, 4], [0, 0, 2], [0, 1, 4], [0, 1, 2], [1, 0, 4], [1, 0, 1], [1, 1, 4], [1, 1, 1]] + ) verify_all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected @@ -720,7 +727,7 @@ def test_all_class_non_max_suppression(): iou_threshold = 0.5 score_threshold = 0.4 - expected = [] + expected = np.array([[0, 0, 3], [0, 0, 0]]) verify_all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected @@ -728,11 +735,11 @@ def test_all_class_non_max_suppression(): if __name__ == "__main__": - # test_get_valid_counts() - # test_multibox_prior() - # test_multibox_detection() - # test_roi_align() - # test_roi_pool() - # test_proposal() - # test_non_max_suppression() + test_get_valid_counts() + test_multibox_prior() + test_multibox_detection() + test_roi_align() + test_roi_pool() + test_proposal() + test_non_max_suppression() test_all_class_non_max_suppression() From 6d314deb49898967361f62c194c9ea3685961dfd Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 8 Apr 2021 19:43:40 +0900 Subject: [PATCH 34/36] fix converting score_threshold to Expr --- python/tvm/relay/op/vision/nms.py | 41 +++++++++++++++++++- tests/python/relay/test_op_level5.py | 4 +- tests/python/topi/python/test_topi_vision.py | 2 +- 3 files changed, 41 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index 5cbcf0748fa2..615ff6bfa4e7 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -154,13 +154,50 @@ def non_max_suppression( def all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class=-1, iou_threshold=-1.0, score_threshold=-1.0 ): - """ - TODO + """Non-maximum suppression operator for object detection, corresponding to ONNX + NonMaxSuppression and TensorFlow combined_non_max_suppression. + NMS is performed for each class separately. + + Parameters + ---------- + boxes : relay.Expr + 3-D tensor with shape [batch_size, num_anchors, 6] + or [batch_size, num_anchors, 5]. 
+ The last dimension should be in format of + [class_id, score, box_left, box_top, box_right, box_bottom] + or [score, box_left, box_top, box_right, box_bottom]. It could + be the second output out_tensor of get_valid_counts. + + scores: relay.Expr + 2-D tensor with shape [batch_size, num_anchors], represents + the index of box in original data. It could be the third + output out_indices of get_valid_counts. The values in the + second dimension are like the output of arange(num_anchors) + if get_valid_counts is not used before non_max_suppression. + + max_output_boxes_per_class : int or relay.Expr, optional + Max number of output valid boxes for each instance. + Return all valid boxes if the value of max_output_size is less than 0. + + iou_threshold : float or relay.Expr, optionaIl + IoU test threshold + + score_threshold : float or relay.Expr, optional + Score threshold to filter out low score boxes early + + Returns + ------- + out : relay.Tuple + The output is a relay.Tuple of two 2-D tensors, with + shape [batch_size, num_anchors] and [batch_size, num_valid_anchors] respectively. """ if not isinstance(max_output_boxes_per_class, expr.Expr): max_output_boxes_per_class = expr.const(max_output_boxes_per_class, "int32") if not isinstance(iou_threshold, expr.Expr): iou_threshold = expr.const(iou_threshold, "float32") + if not isinstance(score_threshold, expr.Expr): + score_threshold = expr.const(score_threshold, "float32") + out = _make.all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold ) diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 2a71115f2929..466b1b19a582 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -1371,9 +1371,7 @@ def verify_all_class_non_max_suppression( score_threshold, expected_indices, ): - dshape = boxes_np.shape - - boxes = relay.var("boxes", relay.ty.TensorType(dshape, "float32")) + boxes = relay.var("boxes", relay.ty.TensorType(boxes_np.shape, "float32")) scores = relay.var("scores", relay.ty.TensorType(scores_np.shape, "float32")) out = relay.vision.all_class_non_max_suppression( diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index 86469919cfa6..d26026265864 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -475,7 +475,7 @@ def check_device(target): tvm_val = tvm_b.asnumpy() tvm.testing.assert_allclose(tvm_val, b_np, rtol=1e-3, atol=1e-4) - for target in ["cuda"]: + for target in ["llvm", "cuda", "opencl"]: check_device(target) From 56531f7229d3cdb036baa2c8058a485d12a91c04 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 10 Apr 2021 08:14:23 +0900 Subject: [PATCH 35/36] update doc --- python/tvm/relay/op/vision/nms.py | 25 +++++------ python/tvm/topi/cuda/nms.py | 33 ++++++++++++++- python/tvm/topi/vision/nms.py | 33 ++++++++++++++- python/tvm/topi/vision/nms_util.py | 68 ++++++++++++++++++++++++++++-- 4 files changed, 137 insertions(+), 22 deletions(-) diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index 615ff6bfa4e7..3f829e0b1cc7 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -161,23 +161,13 @@ def all_class_non_max_suppression( Parameters ---------- boxes : relay.Expr - 3-D tensor with shape [batch_size, num_anchors, 6] - or [batch_size, num_anchors, 5]. 
- The last dimension should be in format of - [class_id, score, box_left, box_top, box_right, box_bottom] - or [score, box_left, box_top, box_right, box_bottom]. It could - be the second output out_tensor of get_valid_counts. + 3-D tensor with shape (batch_size, num_boxes, 4) scores: relay.Expr - 2-D tensor with shape [batch_size, num_anchors], represents - the index of box in original data. It could be the third - output out_indices of get_valid_counts. The values in the - second dimension are like the output of arange(num_anchors) - if get_valid_counts is not used before non_max_suppression. + 3-D tensor with shape (batch_size, num_classes, num_boxes) max_output_boxes_per_class : int or relay.Expr, optional - Max number of output valid boxes for each instance. - Return all valid boxes if the value of max_output_size is less than 0. + The maximum number of output selected boxes per class iou_threshold : float or relay.Expr, optionaIl IoU test threshold @@ -188,8 +178,13 @@ def all_class_non_max_suppression( Returns ------- out : relay.Tuple - The output is a relay.Tuple of two 2-D tensors, with - shape [batch_size, num_anchors] and [batch_size, num_valid_anchors] respectively. + The output is a relay.Tuple of two tensors, the first is `indices` of size + `(batch_size * num_class * num_boxes, 3)` and the second is a scalar tensor + `num_total_detection` of shape `(1,)` representing the total number of selected boxes. + Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come first, + in descending order of scores, followed by boxes from batch 0, class 1 etc. Out of + `batch_size * num_class * num_boxes` rows of indices, only the first `num_total_detection` + rows are valid. """ if not isinstance(max_output_boxes_per_class, expr.Expr): max_output_boxes_per_class = expr.const(max_output_boxes_per_class, "int32") if not isinstance(iou_threshold, expr.Expr): diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 2f48256744c1..2789452cc10b 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -992,8 +992,37 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro def all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold ): - """ - TODO + """Non-maximum suppression operator for object detection, corresponding to ONNX + NonMaxSuppression and TensorFlow combined_non_max_suppression. + NMS is performed for each class separately. + + Parameters + ---------- + boxes : tvm.te.Tensor + 3-D tensor with shape (batch_size, num_boxes, 4) + + scores: tvm.te.Tensor + 3-D tensor with shape (batch_size, num_classes, num_boxes) + + max_output_boxes_per_class : int or tvm.te.Tensor, optional + The maximum number of output selected boxes per class + + iou_threshold : float or tvm.te.Tensor, optional + IoU test threshold + + score_threshold : float or tvm.te.Tensor, optional + Score threshold to filter out low score boxes early + + Returns + ------- + out : [tvm.te.Tensor, tvm.te.Tensor] + The output is two tensors, the first is `indices` of size + `(batch_size * num_class * num_boxes, 3)` and the second is a scalar tensor + `num_total_detection` of shape `(1,)` representing the total number of selected + boxes. Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come + first, in descending order of scores, followed by boxes from batch 0, class 1 etc. Out of + `batch_size * num_class * num_boxes` rows of indices, only the first `num_total_detection` + rows are valid.
""" batch, num_class, num_boxes = scores.shape diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index e8a52eebd9d6..744c5ef7feda 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -730,8 +730,37 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro def all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold ): - """ - TODO + """Non-maximum suppression operator for object detection, corresponding to ONNX + NonMaxSuppression and TensorFlow combined_non_max_suppression. + NMS is performed for each class separately. + + Parameters + ---------- + boxes : tvm.te.Tensor + 3-D tensor with shape (batch_size, num_boxes, 4) + + scores: tvm.te.Tensor + 3-D tensor with shape (batch_size, num_classes, num_boxes) + + max_output_boxes_per_class : int or tvm.te.Tensor, optional + The maximum number of output selected boxes per class + + iou_threshold : float or tvm.te.Tensor, optional + IoU test threshold + + score_threshold : float or tvm.te.Tensor, optional + Score threshold to filter out low score boxes early + + Returns + ------- + out : [tvm.te.Tensor, tvm.te.Tensor] + The output is two tensors, the first is `indices` of size + `(batch_size * num_class * num_boxes, 3)` and the second is a scalar tensor + `num_total_detection` of shape `(1,)` representing the total number of selected + boxes. Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come + first, in descending order of scores, followed by boxes from batch 0, class 1 etc. Out of + `batch_size * num_class * num_boxes` rows of indices, only the first `num_total_detection` + rows are valid. """ batch, num_class, num_boxes = scores.shape scores = reshape(scores, (batch * num_class, num_boxes)) diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py index 47015f9a808f..aae4c29fd875 100644 --- a/python/tvm/topi/vision/nms_util.py +++ b/python/tvm/topi/vision/nms_util.py @@ -77,7 +77,34 @@ def binary_search(ib, y, num_boxes, scores, score_threshold, out): def collect_selected_indices(num_class, selected_indices, num_detections, row_offsets, ir): - """TODO""" + """Collect selected indices from the core NMS loop into one linear output + + Parameters + ---------- + num_class : int + + selected_indices: tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes), representing the indices + of boxes selected by the core NMS loop. + + num_detections : tvm.te.Tensor + 1-D tensor with shape (batch_size * num_classes,), representing + the number of boxes selected by the core NMS loop, per batch and class + + row_offsets : tvm.te.Tensor + 1-D tensor with shape (batch_size * num_classes,), this should be the exclusive scan + of num_detections + + ir : function + A core NMS loop, see its usage in vision/nms.py and cuda/nms.py + + Returns + ------- + out : tvm.te.Tensor + The output is indices of size (batch_size * num_class * num_boxes, 3). + Rows of indices are ordered such that selected boxes from batch 0, class 0 come + first, in descending order of scores, followed by boxes from batch 0, class 1 etc.
""" batch_class, num_boxes = selected_indices.shape selected_indices_buf = tvm.tir.decl_buffer( @@ -175,7 +202,42 @@ def run_all_class_nms( iou_threshold, nms_loop, ): - """TODO""" + """The core all class NMS routine + + Parameters + ---------- + boxes : tvm.te.Tensor + 3-D tensor with shape (batch_size, num_boxes, 4) + + sorted_scores: tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes) + One of the outputs from argsort + + sorted_indices: tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes) + The other output from argsort + + valid_count: tvm.te.Tensor + 1-D tensor with shape (batch_size * num_classes,), representing + the number of boxes whose score is above score_threshold, per batch and class + + max_output_boxes_per_class : int or tvm.te.Tensor, optional + The maximum number of output selected boxes per class + + iou_threshold : float or tvm.te.Tensor, optional + IoU test threshold + + nms_loop : function + A core NMS loop, see its usage in vision/nms.py and cuda/nms.py + + Returns + ------- + out : [tvm.te.Tensor, tvm.te.Tensor] + The output is two tensors, the first is indices of size + (batch_size * num_class, num_boxes) and the second is a tensor + num_selected_boxes of shape (batch_size * num_class,) representing the total number of + selected boxes per batch and class. + """ batch, num_boxes, _ = boxes.shape batch_class = sorted_scores.shape[0] num_class = batch_class // batch @@ -205,7 +267,7 @@ def run_all_class_nms( iou_threshold, max_output_size_per_class, outs[0], # box_indices - outs[1], # num_valid_boxes + outs[1], # num_selected_boxes nms_loop, ), dtype=["int32", "int32"], From b1749271bd677dbe8e6e431c512e114c1f5fd5bd Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 10 Apr 2021 09:17:06 +0900 Subject: [PATCH 36/36] doc fix --- python/tvm/topi/vision/nms_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py index aae4c29fd875..1147b1687783 100644 --- a/python/tvm/topi/vision/nms_util.py +++ b/python/tvm/topi/vision/nms_util.py @@ -96,7 +96,7 @@ def collect_selected_indices(num_class, selected_indices, num_detections, row_of of num_detections ir : function - A core NMS loop, see its usage in vision/nms.py and cuda/nms.py + A function to generate IR for CPU or GPU, see its usage in vision/nms.py and cuda/nms.py Returns -------