Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[THRUST] Faster multi dimensional argsort by segmented sort #7195

Merged
merged 4 commits into from
Jan 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,11 +813,9 @@ def non_max_suppression(
if (
target
and target.kind.name == "cuda"
and tvm.get_global_func("tvm.contrib.thrust.sort_nms", allow_missing=True)
and tvm.get_global_func("tvm.contrib.thrust.sort", allow_missing=True)
):
sort_tensor = argsort_thrust(
score_tensor, valid_count=None, axis=1, is_ascend=False, dtype=valid_count_dtype
)
sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype=valid_count_dtype)
else:
sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype=valid_count_dtype)

Expand Down
73 changes: 2 additions & 71 deletions python/tvm/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,68 +409,6 @@ def sort_by_key_ir(
)


def argsort_nms_thrust(data, valid_count, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array of indices
having same shape as an input array that index data in sorted order.

Parameters
----------
data: tvm.te.Tensor
The input array.

valid_count : tvm.te.Tensor, optional
The number of valid elements to be sorted.

axis : int, optional
Axis along which to sort the input tensor.

is_ascend : boolean, optional
Whether to sort in ascending or descending order.

dtype : string, optional
DType of the output indices.

Returns
-------
out : tvm.te.Tensor
The output of this function.
"""
ndim = len(data.shape)
if axis < 0:
axis = ndim + axis
if axis != ndim - 1:
# Prepare for sorting along axis -1.
axes = swap(list(range(ndim)), axis)
data = transpose(data, axes)

data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
valid_count_buf = tvm.tir.decl_buffer(
valid_count.shape, valid_count.dtype, "valid_count_buf", data_alignment=4
)
out_bufs = [
tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8),
tvm.tir.decl_buffer(data.shape, "int32", "indices_buf", data_alignment=8),
]
out = te.extern(
[data.shape, data.shape],
[data, valid_count],
lambda ins, outs: tvm.tir.call_packed(
"tvm.contrib.thrust.sort_nms", ins[0], ins[1], outs[0], outs[1], is_ascend
),
in_buffers=[data_buf, valid_count_buf],
out_buffers=out_bufs,
dtype=[data.dtype, "int32"],
name="nms_argsort_gpu",
tag="nms_argsort_gpu",
)

if axis != ndim - 1:
axes = swap(list(range(ndim)), axis)
out = [transpose(o, axes) for o in out]

return out[1]


def sort(data, axis=-1, is_ascend=1):
"""Performs sorting along the given axis and returns an array of
sorted values with the same shape as the input data.
Expand Down Expand Up @@ -602,7 +540,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
return out


def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
def argsort_thrust(data, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array of indices
having same shape as an input array that index data in sorted order.

Expand All @@ -611,9 +549,6 @@ def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"
data: tvm.te.Tensor
The input array.

valid_count : tvm.te.Tensor, optional
The number of valid elements to be sorted.

axis : int, optional
Axis along which to sort the input tensor.

Expand All @@ -628,11 +563,7 @@ def argsort_thrust(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"
out : tvm.te.Tensor
The output of this function.
"""
if valid_count is not None:
out = argsort_nms_thrust(data, valid_count, axis, is_ascend, dtype)
else:
out = topk_thrust(data, 0, axis, "indices", is_ascend, dtype)
return out
return topk_thrust(data, 0, axis, "indices", is_ascend, dtype)


def schedule_sort(outs):
Expand Down
117 changes: 68 additions & 49 deletions src/runtime/contrib/thrust/thrust.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
*/

#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/sort.h>
#include <thrust/gather.h>

#include <tvm/runtime/registry.h>
#include <dlpack/dlpack.h>
Expand All @@ -41,85 +43,122 @@ void thrust_sort(DLTensor* input,
DLTensor* out_values,
DLTensor* out_indices,
bool is_ascend,
const std::function<int(int)> &get_sort_len) {
int n_values) {
thrust::device_ptr<DataType> data_ptr(static_cast<DataType *>(input->data));
thrust::device_ptr<DataType> values_ptr(static_cast<DataType *>(out_values->data));
thrust::device_ptr<IndicesType> indices_ptr(static_cast<IndicesType *>(out_indices->data));

int n_values = input->shape[input->ndim - 1];
int n_iter = 1;
for (int i = 0; i < input->ndim - 1; ++i) {
n_iter *= input->shape[i];
size_t size = 1;
for (int i = 0; i < input->ndim; ++i) {
size *= input->shape[i];
}
thrust::copy(data_ptr, data_ptr + size, values_ptr);

thrust::copy(data_ptr, data_ptr + n_iter * n_values, values_ptr);

for (int i = 0 ; i < n_iter; ++i) {
n_values = get_sort_len(i);
if (size == static_cast<size_t>(input->shape[input->ndim - 1])) {
// A fast path for single segment case
thrust::sequence(indices_ptr, indices_ptr + n_values);
if (is_ascend) {
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr);
} else {
thrust::sort_by_key(values_ptr, values_ptr + n_values, indices_ptr,
thrust::greater<DataType>());
}
values_ptr += n_values;
indices_ptr += n_values;
} else {
// segmented sort by key
// Follow the back-to-back stable_sort_by_key strategy explained below
// https://groups.google.com/g/thrust-users/c/BoLsxO6b4FY
thrust::device_vector<int64_t> argsort_order(size);
thrust::sequence(argsort_order.begin(), argsort_order.end());

// First, sort values and store the sorted order in argsort_order.
if (is_ascend) {
thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin());
} else {
thrust::stable_sort_by_key(values_ptr, values_ptr + size, argsort_order.begin(),
thrust::greater<DataType>());
}

// The following is to create the indices array 0, 1, 2, 0, 1, 2 ... 0, 1, 2
// without materializing it
auto counting_iter = thrust::counting_iterator<int64_t>(0);
auto linear_index_to_sort_axis_index = [n_values] __host__ __device__(int64_t i) {
return i % n_values;
}; // NOLINT(*)
auto init_indices_iter = thrust::make_transform_iterator(counting_iter,
linear_index_to_sort_axis_index);

// This will reorder indices 0, 1, 2 ... in the sorted order of values_ptr
thrust::gather(argsort_order.begin(), argsort_order.end(), init_indices_iter, indices_ptr);

thrust::device_vector<int> segment_ids(size);
auto linear_index_to_segment_id = [n_values] __host__ __device__(int64_t i) {
return i / n_values;
}; // NOLINT(*)
// We also reorder segment indices 0, 0, 0, 1, 1, 1 ... in the order of values_ptr
thrust::transform(argsort_order.begin(), argsort_order.end(), segment_ids.begin(),
linear_index_to_segment_id);

// The second sort key-ed by segment_ids would bring segment_ids back to 0, 0, 0, 1, 1, 1 ...
// values_ptr and indices_ptr will also be sorted in the order of segment_ids above
// Since sorting has been done in a stable way, relative orderings of values and indices
// in the segment do not change and hence they remain sorted.
auto key_val_zip = thrust::make_zip_iterator(thrust::make_tuple(values_ptr, indices_ptr));
thrust::stable_sort_by_key(segment_ids.begin(), segment_ids.end(), key_val_zip);
}
}

void thrust_sort_common(DLTensor* input,
DLTensor* values_out,
DLTensor* indices_out,
bool is_ascend,
const std::function<int(int)> &get_sort_len,
int sort_len,
std::string data_dtype,
std::string out_dtype) {
if (data_dtype == "float32") {
if (out_dtype == "int32") {
thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "int64") {
thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<float, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float32") {
thrust_sort<float, float>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<float, float>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float64") {
thrust_sort<float, double>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<float, double>(input, values_out, indices_out, is_ascend, sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "float64") {
if (out_dtype == "int32") {
thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<double, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "int64") {
thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<double, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float32") {
thrust_sort<double, float>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<double, float>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float64") {
thrust_sort<double, double>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<double, double>(input, values_out, indices_out, is_ascend, sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int32") {
if (out_dtype == "int32") {
thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "int64") {
thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<int32_t, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float32") {
thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<int32_t, float>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float64") {
thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend, sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
} else if (data_dtype == "int64") {
if (out_dtype == "int32") {
thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<int64_t, int32_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "int64") {
thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<int64_t, int64_t>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float32") {
thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<int64_t, float>(input, values_out, indices_out, is_ascend, sort_len);
} else if (out_dtype == "float64") {
thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend, get_sort_len);
thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend, sort_len);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
Expand All @@ -128,25 +167,6 @@ void thrust_sort_common(DLTensor* input,
}
}

TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort_nms")
.set_body([](TVMArgs args, TVMRetValue* ret) {
ICHECK_GE(args.num_args, 5);
DLTensor* input = args[0];
DLTensor* valid_count = args[1];
DLTensor* values_out = args[2];
DLTensor* indices_out = args[3];
bool is_ascend = args[4];

auto data_dtype = DLDataType2String(input->dtype);
auto out_dtype = DLDataType2String(indices_out->dtype);

thrust::device_ptr<int> valid_count_ptr(static_cast<int *>(valid_count->data));
auto get_sort_len = [&valid_count_ptr](int i) { return valid_count_ptr[i]; };
thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
data_dtype, out_dtype);
});


TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
.set_body([](TVMArgs args, TVMRetValue* ret) {
ICHECK_GE(args.num_args, 4);
Expand All @@ -159,8 +179,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
auto out_dtype = DLDataType2String(indices_out->dtype);

int n_values = input->shape[input->ndim - 1];
auto get_sort_len = [=](int i) { return n_values; };
thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
thrust_sort_common(input, values_out, indices_out, is_ascend, n_values,
data_dtype, out_dtype);
});

Expand Down