Skip to content

Commit

Permalink
working
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 29, 2021
1 parent 7f5c76d commit 4567417
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
7 changes: 5 additions & 2 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,9 +841,12 @@ def _impl(inputs, attr, params, mod):
max_total_size.data.numpy().item(),
output_format="tensorflow",
)
# return _expr.TupleWrapper(
# _expr.Tuple([indices, num_detections, num_detections, num_detections]), 4
# )

nmsed_box_indices = _op.take(indices, _op.const(0), axis=2)
nmsed_classes = _op.take(indices, _op.const(1), axis=2)
nmsed_box_indices = _op.take(indices, _op.const(1), axis=2)
nmsed_classes = _op.take(indices, _op.const(0), axis=2)
nmsed_boxes = _op.gather_nd(boxes, _op.expand_dims(nmsed_box_indices, axis=0), batch_dims=1)

indices_dims = len(_infer_shape(indices, mod))
Expand Down
19 changes: 11 additions & 8 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ..math import cast
from .. import reduction
from ..broadcast import minimum
from ..transform import reshape, strided_slice, gather_nd, expand_dims
from ..transform import reshape, strided_slice, gather_nd, expand_dims, squeeze
from ..vision.nms_util import (
calculate_overlap,
binary_search,
Expand Down Expand Up @@ -1149,10 +1149,13 @@ def all_class_non_max_suppression(
selected_indices, selected_scores = collect_selected_indices_tf(
selected_indices, selected_scores, num_detections_per_batch, row_offsets
)
selected_scores = strided_slice(
selected_scores, begin=[0, 0], end=[batch, reduction.max(num_total_detections)]
)
topk_indices = topk(selected_scores, k=max_detection_per_batch, axis=1, ret_type="indices")
final_indices = gather_nd(selected_indices, expand_dims(topk_indices, axis=0), batch_dims=1)
num_detections = minimum(num_total_detections, max_detection_per_batch)
return [final_indices, num_detections]
# selected_scores = strided_slice(
# selected_scores, begin=[0, 0], end=[batch, reduction.max(num_total_detections)]
# )
topk_indices = topk(selected_scores, k=max_detection_per_batch, axis=1, ret_type="indices")[0]
topk_indices = expand_dims(topk_indices, axis=0)
final_indices = gather_nd(selected_indices, topk_indices, batch_dims=1)
print(final_indices.shape)
print(num_total_detections.shape)
# num_detections = minimum(num_total_detections, max_detection_per_batch)
return [final_indices, num_total_detections]

0 comments on commit 4567417

Please sign in to comment.