diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 8234d4c69486..36e8dd5958c7 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -550,26 +550,6 @@ inline Array split(const Tensor& x, Array split_indices, int a return result; } -inline te::Tensor strided_slice_compute_common(const te::Tensor& x, - const Array& out_shape, - const Array& begin, - const Array& strides, - const Array& axes, const std::string& name, - const std::string& tag) { - return te::compute( - out_shape, - [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]); - for (size_t i = 0; i < axes.size(); ++i) { - PrimExpr ind = indices[axes[i]] * strides[i] + begin[i]; - real_indices.Set(axes[i], ind); - } - return x(real_indices); - }, - name, tag); -} - inline Tensor dynamic_strided_slice(const Tensor& x, const Array& begin, const Array& end, const Array& strides, std::string name = "T_dynamic_strided_slice", @@ -645,34 +625,6 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag); } -inline Tensor strided_slice_dynamic_input(const Tensor& x, const Array& begin, - const Array& end, const Array& strides, - std::string slice_mode = "end", - std::string name = "T_strided_slice_dynamic_input", - std::string tag = kInjective) { - size_t src_tensor_dim = x->shape.size(); - ICHECK(begin.size() == src_tensor_dim) - << "for dynamic inputs, len(begin) must equal the input dimension"; - Array out_shape; - for (size_t i = 0; i < src_tensor_dim; ++i) { - out_shape.push_back(tvm::tir::Var("dim")); - } - Array begin_expr, end_expr, strides_expr; - Array axes; - for (size_t i = 0; i < src_tensor_dim; ++i) { - int64_t begin_i = begin[i]->value; - if (begin_i < 0) { - begin_i += topi::detail::GetConstInt(x->shape[i]); - } - begin_expr.push_back(tir::make_const(begin[0].dtype(), begin_i)); - strides_expr.push_back( - tir::make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), - (i < strides.size() ? strides[i]->value : 1))); - axes.push_back(i); - } - return strided_slice_compute_common(x, out_shape, begin_expr, strides_expr, axes, name, tag); -} - inline Tensor strided_slice_with_axes(const Tensor& x, const Array& begin, const Array& end, const Array& strides, const Array& axes, std::string slice_mode = "end", @@ -729,34 +681,56 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array& beg Array begin_expr, strides_expr; for (size_t i = 0; i < axes.size(); ++i) { int64_t begin_range = stride_vec[i] < 0 ? -1 : 0; - ICHECK(x->shape[axes[i]]->IsInstance()) - << "Input shape at axis " << axes[i] << " is not static"; - int64_t dim_i = GetConstInt(x->shape[axes[i]]); - int64_t end_range = stride_vec[i] < 0 ? dim_i - 1 : dim_i; - // transform negative indices to positive value, clips on the correct range - auto index_canonicalization = [dim_i, begin_range, end_range](int64_t index) { - if (index < 0) { - index += dim_i; - } - return std::min(std::max(index, begin_range), end_range); - }; + if (x->shape[axes[i]]->IsInstance()) { + int64_t dim_i = GetConstInt(x->shape[axes[i]]); + int64_t end_range = stride_vec[i] < 0 ? dim_i - 1 : dim_i; + // transform negative indices to positive value, clips on the correct range + auto index_canonicalization = [dim_i, begin_range, end_range](int64_t index) { + if (index < 0) { + index += dim_i; + } + return std::min(std::max(index, begin_range), end_range); + }; - int64_t begin_i = index_canonicalization(begin_vec[i]); - int64_t end_i = index_canonicalization(end_vec[i]); + int64_t begin_i = index_canonicalization(begin_vec[i]); + int64_t end_i = index_canonicalization(end_vec[i]); - int interval = std::abs(end_i - begin_i); - int slice_size = - static_cast((interval + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i])); - ICHECK(stride_vec[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i)) - << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i] - << "] is invalid for axis=" << i; + int interval = std::abs(end_i - begin_i); + int slice_size = + static_cast((interval + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i])); + ICHECK(stride_vec[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i)) + << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i] + << "] is invalid for axis=" << i; - begin_expr.push_back(make_const(begin[0].dtype(), begin_i)); - strides_expr.push_back( - make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), stride_vec[i])); - out_shape.Set(axes[i], cast(out_shape[i].dtype(), PrimExpr(slice_size))); + begin_expr.push_back(make_const(begin[0].dtype(), begin_i)); + out_shape.Set(axes[i], cast(out_shape[i].dtype(), PrimExpr(slice_size))); + } else { + auto idim = x->shape[axes[i]]; + auto b = tvm::if_then_else(begin[i] < 0, begin[i] + idim, begin[i]); + auto s = strides[i]->value; + if (s < 0) { + b = tvm::min(b, idim - 1); + } else { + b = tvm::if_then_else(b < 0, 0, b); + } + out_shape.Set(axes[i], tvm::tir::Var("dim", out_shape[i]->dtype)); + begin_expr.push_back(b); + } + strides_expr.push_back(make_const(strides[i].dtype(), stride_vec[i])); } - return strided_slice_compute_common(x, out_shape, begin_expr, strides_expr, axes, name, tag); + + return te::compute( + out_shape, + [&](const Array& indices) { + Array real_indices; + for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]); + for (size_t i = 0; i < axes.size(); ++i) { + PrimExpr ind = indices[axes[i]] * strides_expr[i] + begin_expr[i]; + real_indices.Set(axes[i], ind); + } + return x(real_indices); + }, + name, tag); } /*! diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index a1943e8cbfe2..5c80b73b9a05 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2739,9 +2739,6 @@ Array StridedSliceCompute(const Attrs& attrs, const Arrayaxes.value(); return Array{ topi::strided_slice_with_axes(inputs[0], begin, end, strides, axes, param->slice_mode)}; - } else if (IsDynamic(out_type)) { - return Array{ - topi::strided_slice_dynamic_input(inputs[0], begin, end, strides, param->slice_mode)}; } return Array{topi::strided_slice(inputs[0], begin, end, strides, param->slice_mode)}; } diff --git a/src/topi/transform.cc b/src/topi/transform.cc index e19b6da11064..7c6e491dcc26 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -183,11 +183,7 @@ TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* Array begin_static = args[1]; Array end_static = args[2]; Array strides_static = args[3]; - if (IsConstIntArray(x->shape)) { - *rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode); - } else { - *rv = strided_slice_dynamic_input(x, begin_static, end_static, strides_static, slice_mode); - } + *rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode); } else { *rv = dynamic_strided_slice(x, begin, end, strides); }