From 66e4c97a15716b94f90f13bf423676c9127b8625 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 8 Jul 2022 07:00:32 -0700 Subject: [PATCH 1/8] Compute common type for shape elements in BroadcastHelper The corresponding dimensions in the input/output tensors in a broadcast operations may have the same value, but different types (e.g. int32 vs int64). When the broadcast helper tries to unify the dimensions it also needs to compute the common type to hold the dimension. --- include/tvm/topi/detail/broadcast.h | 30 +++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/include/tvm/topi/detail/broadcast.h b/include/tvm/topi/detail/broadcast.h index 5c701825840c..c861fbb71b2a 100644 --- a/include/tvm/topi/detail/broadcast.h +++ b/include/tvm/topi/detail/broadcast.h @@ -42,6 +42,12 @@ struct BroadcastHelper { std::deque vars2; }; +static inline DataType CommonType(DataType type1, DataType type2) { + ICHECK(type1.is_scalar() && type2.is_scalar()); + ICHECK(type1.code() == type2.code()); + return DataType(type1.code(), std::max(type1.bits(), type2.bits()), /*lanes=*/1); +} + inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, const tvm::Array& shape2) { BroadcastHelper bh; @@ -49,32 +55,40 @@ inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, int s2_size = shape2.size(); tvm::PrimExpr one(1); int i; + + auto cast_if_needed = [](DataType to_type, PrimExpr expr) { + return to_type != expr.dtype() ? cast(to_type, expr) : expr; + }; + for (i = 1; i <= std::min(s1_size, s2_size); ++i) { // TODO(@icemelon9): Need to revisit this part const IntImmNode* static_size1 = shape1[s1_size - i].as(); const IntImmNode* static_size2 = shape2[s2_size - i].as(); - bh.all_vars.push_front(tvm::tir::Var()); + DataType common_type = CommonType(shape1[s1_size - i].dtype(), shape2[s2_size - i].dtype()); + + bh.all_vars.push_front(tvm::tir::Var("dim", common_type)); if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) { - bh.common_shape.push_front(shape1[s1_size - i]); + bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i])); bh.vars1.push_front(bh.all_vars[0]); bh.vars2.push_front(bh.all_vars[0]); } else if (topi::detail::EqualCheck(one, shape1[s1_size - i])) { ICHECK(!topi::detail::EqualCheck(one, shape2[s2_size - i])); - bh.common_shape.push_front(shape2[s2_size - i]); + bh.common_shape.push_front(cast_if_needed(common_type, shape2[s2_size - i])); bh.vars2.push_front(bh.all_vars[0]); } else if (topi::detail::EqualCheck(one, shape2[s2_size - i])) { - bh.common_shape.push_front(shape1[s1_size - i]); + bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i])); bh.vars1.push_front(bh.all_vars[0]); } else if (!static_size1 && !static_size2) { - bh.common_shape.push_front(max(shape1[s1_size - i], shape2[s2_size - i])); + bh.common_shape.push_front( + cast_if_needed(common_type, max(shape1[s1_size - i], shape2[s2_size - i]))); bh.vars1.push_front(bh.all_vars[0]); bh.vars2.push_front(bh.all_vars[0]); } else if (!static_size1) { - bh.common_shape.push_front(shape2[s2_size - i]); + bh.common_shape.push_front(cast_if_needed(common_type, shape2[s2_size - i])); bh.vars2.push_front(bh.all_vars[0]); bh.vars1.push_front(bh.all_vars[0]); } else if (!static_size2) { - bh.common_shape.push_front(shape1[s1_size - i]); + bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i])); bh.vars1.push_front(bh.all_vars[0]); bh.vars2.push_front(bh.all_vars[0]); } else { @@ -89,7 +103,7 @@ inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, auto& shape = (s1_size > s2_size) ? shape1 : shape2; auto& vars = (s1_size > s2_size) ? bh.vars1 : bh.vars2; for (; i <= max_size; ++i) { - bh.all_vars.push_front(tvm::tir::Var()); + bh.all_vars.push_front(tvm::tir::Var("v", shape[max_size - 1].dtype())); bh.common_shape.push_front(shape[max_size - i]); vars.push_front(bh.all_vars[0]); } From bc0dc0a98a4b187289942ee19a71199c12055d5a Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 8 Jul 2022 07:14:44 -0700 Subject: [PATCH 2/8] Cast and simplify both members of `Range` Only the `min` member was type-casted, which could lead to ranges with different types for `min` and `extent`. Move the casts to the argument of Simplify, so that they can be eliminated if they aren't needed. --- src/te/schedule/bound.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/te/schedule/bound.cc b/src/te/schedule/bound.cc index 87a175a34437..d8abffd6aa06 100644 --- a/src/te/schedule/bound.cc +++ b/src/te/schedule/bound.cc @@ -247,10 +247,11 @@ Map InferBound(const Schedule& sch) { } } for (auto it = ret.begin(); it != ret.end(); it++) { + DataType var_type = it->first->var.dtype(); it->second = Range::FromMinExtent( - analyzer.Simplify(it->second->min), - // The range associated with each itervar must have the same dtype as it - cast(it->first->var.dtype(), analyzer.Simplify(it->second->extent))); + // The range associated with each itervar must have the same dtype as the var + analyzer.Simplify(cast(var_type, it->second->min)), + analyzer.Simplify(cast(var_type, it->second->extent))); } return Map(ret.begin(), ret.end()); } From 72ab34c38067c885e8c0280bb957ce9b0916bda3 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 8 Jul 2022 07:08:52 -0700 Subject: [PATCH 3/8] Type-check iv domain ranges, use cast only if needed in MakeLoopNest In some cases the domain ranges had the `min` and the `extent` values be of different types (e.g. [(int64)0, 32)). This is an error, and it can lead to compilation failures later on. Add a check for equal types here to catch this early. Also, only add the cast operation when the desired type differs from the current one to keep the expressions simpler. --- src/te/operation/op_utils.cc | 52 +++++++++++++++++++----------------- src/te/operation/op_utils.h | 12 ++++----- 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/src/te/operation/op_utils.cc b/src/te/operation/op_utils.cc index fd2a5c89f324..8644e75ff056 100644 --- a/src/te/operation/op_utils.cc +++ b/src/te/operation/op_utils.cc @@ -38,18 +38,16 @@ namespace te { using namespace arith; using namespace tir; -DataType LargerDataType(DataType a, DataType b) { return a.bits() > b.bits() ? a : b; } - -std::vector > MakeLoopNest(const Stage& stage, - const std::unordered_map& dom_map, - size_t begin_iter_pos, bool new_loop_var, - const std::unordered_set& skip_iter, - std::unordered_map* p_value_map, - bool debug_keep_trivial_loop) { +std::vector> MakeLoopNest(const Stage& stage, + const std::unordered_map& dom_map, + size_t begin_iter_pos, bool new_loop_var, + const std::unordered_set& skip_iter, + std::unordered_map* p_value_map, + bool debug_keep_trivial_loop) { auto leaf_iter_vars = stage->leaf_iter_vars; Stmt no_op = Evaluate(0); // create the loop nest - std::vector > nest; + std::vector> nest; nest.resize(leaf_iter_vars.size() + 1); std::unordered_map& value_map = *p_value_map; @@ -69,6 +67,10 @@ std::vector > MakeLoopNest(const Stage& stage, Range dom = dom_map.at(iv); + ICHECK(iv->var.dtype() == dom->min.dtype() && iv->var.dtype() == dom->extent.dtype()) + << "iter_var type " << iv->var.dtype() << " and domain types (min:" << dom->min.dtype() + << ", extent:" << dom->extent.dtype() << ") should all be the same"; + // This is a hack to ensure that the replacing expression has the same // dtype as the replacing expression. This happens when a thread/block // itervar is bound to another itervar. Because the thread/block itervar @@ -78,7 +80,9 @@ std::vector > MakeLoopNest(const Stage& stage, // bound to (in `bind`) but that would require inplace modification of the // itervar. // XXX: we will get integer overflow if the bound itervar is greater than int32::max. - auto promote_to_bound_dtype = [&iv](PrimExpr e) { return cast(iv->var.dtype(), e); }; + auto promote_to_iv_dtype = [type = iv->var.dtype()](PrimExpr e) { + return type != e.dtype() ? cast(type, e) : e; + }; // initialize the offset and loop_level Var var = bind_iv->var; @@ -125,15 +129,15 @@ std::vector > MakeLoopNest(const Stage& stage, } } if (!debug_keep_trivial_loop && is_one(dom->extent)) { - nest[i + 1].emplace_back(LetStmt(var, promote_to_bound_dtype(dom->min), no_op)); - value_map[iv] = promote_to_bound_dtype(dom->min); + nest[i + 1].emplace_back(LetStmt(var, dom->min, no_op)); + value_map[iv] = dom->min; } else if (is_zero(dom->min)) { - nest[i + 1].emplace_back(For(var, 0, promote_to_bound_dtype(dom->extent), kind, no_op)); - value_map[iv] = promote_to_bound_dtype(var); + nest[i + 1].emplace_back(For(var, 0, dom->extent, kind, no_op)); + value_map[iv] = promote_to_iv_dtype(var); } else { Var idx(bind_iv->var->name_hint + ".idx", iv->var.dtype()); - nest[i + 1].emplace_back(For(idx, 0, promote_to_bound_dtype(dom->extent), kind, no_op)); - PrimExpr new_value = promote_to_bound_dtype(dom->min + idx); + nest[i + 1].emplace_back(For(idx, 0, dom->extent, kind, no_op)); + PrimExpr new_value = dom->min + idx; value_map[iv] = new_value; nest[i + 1].emplace_back(LetStmt(var, new_value, no_op)); } @@ -152,7 +156,7 @@ std::vector > MakeLoopNest(const Stage& stage, ICHECK(is_positive_const(dom->extent)); // annotate the extent of the IterVar nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::virtual_thread, dom->extent, no_op)); - value_map[iv] = promote_to_bound_dtype(var); + value_map[iv] = promote_to_iv_dtype(var); } else if (bind_iv->thread_tag == "pipeline") { // pipeline marker. ICHECK(is_zero(dom->min)); @@ -160,7 +164,7 @@ std::vector > MakeLoopNest(const Stage& stage, // annotate the extent of the IterVar nest[i + 1].emplace_back( AttrStmt(bind_iv, tir::attr::pipeline_exec_scope, dom->extent, no_op)); - value_map[iv] = promote_to_bound_dtype(dom->min); + value_map[iv] = dom->min; } else { // Always restrict threaded IterVar to starts from 0. ICHECK(is_zero(dom->min)) << "Itervar " << iv << " must start at zero, but it starts at " @@ -168,28 +172,28 @@ std::vector > MakeLoopNest(const Stage& stage, // annotate the extent of the IterVar nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent, dom->extent, no_op)); if (!debug_keep_trivial_loop && is_one(dom->extent)) { - value_map[iv] = promote_to_bound_dtype(dom->min); + value_map[iv] = dom->min; } else if (stage->scope == "") { - value_map[iv] = promote_to_bound_dtype(var); + value_map[iv] = promote_to_iv_dtype(var); } else { runtime::ThreadScope ts = runtime::ThreadScope::Create(bind_iv->thread_tag); runtime::StorageScope ss = runtime::StorageScope::Create(stage->scope); if (static_cast(ss.rank) <= ts.rank) { - value_map[iv] = promote_to_bound_dtype(var); + value_map[iv] = promote_to_iv_dtype(var); } else if (stage->scope == "warp" && ts.rank == 1) { // To determine whether a thread index is inside or outside a warp, we need // to know the thread extent. We leave a warning for now. if (ts.dim_index == 0) { - value_map[iv] = promote_to_bound_dtype(var); + value_map[iv] = promote_to_iv_dtype(var); } else { LOG(WARNING) << "WARNING: threadIdx.y or threadIdx.z accessing warp-scope memory detected. " << "TVM assumes only threadIdx.x indicates threads inside a warp, " << "while threadIdx.y and threadIdx.z indicates different warps."; - value_map[iv] = promote_to_bound_dtype(dom->min); + value_map[iv] = dom->min; } } else { - value_map[iv] = promote_to_bound_dtype(dom->min); + value_map[iv] = dom->min; } } } diff --git a/src/te/operation/op_utils.h b/src/te/operation/op_utils.h index 02f4a860a01d..f2e5782bf46f 100644 --- a/src/te/operation/op_utils.h +++ b/src/te/operation/op_utils.h @@ -51,12 +51,12 @@ using tir::MergeNest; * \param p_value_map The result value of each IterVar. * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 */ -std::vector > MakeLoopNest(const Stage& stage, - const std::unordered_map& dom_map, - size_t begin_iter_pos, bool new_loop_var, - const std::unordered_set& skip_iter, - std::unordered_map* p_value_map, - bool debug_keep_trivial_loop); +std::vector> MakeLoopNest(const Stage& stage, + const std::unordered_map& dom_map, + size_t begin_iter_pos, bool new_loop_var, + const std::unordered_set& skip_iter, + std::unordered_map* p_value_map, + bool debug_keep_trivial_loop); /*! * \brief Create a nest of if checking the predicates. From ad06e147bacedbf863a56884d56715e9119dc156 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 8 Jul 2022 07:18:40 -0700 Subject: [PATCH 4/8] Check that variable and substituted expression have same types Add a check to IRSubstitute to detect when the type of a variable and the type of the expression to replace it with have different types. --- src/tir/ir/stmt_functor.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 34bbb4b46ba4..b9d3d2dce846 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -26,7 +26,7 @@ #include -#include "./functor_common.h" +#include "functor_common.h" namespace tvm { namespace tir { @@ -647,6 +647,11 @@ class IRSubstitute : public StmtExprMutator { PrimExpr VisitExpr_(const VarNode* op) final { Var var = GetRef(op); auto ret = vmap_(var); + if (ret.defined()) { + PrimExpr ret_ex = Downcast(ret.value()); + ICHECK(ret_ex.dtype() == var.dtype()) << "substituting " << var << ":" << var.dtype() + << " -> " << ret_ex << ":" << ret_ex.dtype(); + } if (ret.defined()) return ret.value(); return std::move(var); } From ab71ffd5021156250ce976520569fff309aa4dfa Mon Sep 17 00:00:00 2001 From: Jiawei Liu Date: Fri, 8 Jul 2022 07:24:38 -0700 Subject: [PATCH 5/8] Add testcase --- tests/python/relay/test_op_level10.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index a2104e79762a..8c30ab27ce18 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -262,6 +262,23 @@ def test_broadcast_concat_shape_int64(executor_kind): tvm.testing.assert_allclose(op_res.numpy(), ref_res) +def test_broadcast_pool2d_shape_int64(executor_kind): + x_shape = (1, 3, 32, 32) + out_shape = (2, 3, 32, 32) + x = relay.var("data", shape=x_shape, dtype="float32") + broadcast_to = relay.broadcast_to(x, shape=relay.const([2, 3, 32, 32], dtype="int64")) + pool2d = relay.nn.max_pool2d(broadcast_to, pool_size=(3, 3), padding=(1, 1, 1, 1)) + sub = relay.subtract(broadcast_to, pool2d) + + f = relay.Function([x], sub) + x = np.ones(x_shape).astype("float32") + ref_res = np.zeros(out_shape).astype("float32") + + for target, dev in tvm.testing.enabled_targets(): + op_res = relay.create_executor(executor_kind, device=dev, target=target).evaluate(f)(x) + tvm.testing.assert_allclose(op_res.numpy(), ref_res) + + @tvm.testing.uses_gpu def test_broadcast_to_like(executor_kind): shape = (4, 1, 6) From c418032822a6125d35c3db34d71d37216bfa6db8 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 8 Jul 2022 13:07:57 -0700 Subject: [PATCH 6/8] [TVMScript] Use void for lambda parameters, allow mismatch in Substitute When the script parser deals with lambdas, it creates Var objects for each parameter. Their actual types are not known at the time, and the properly typed variables are subtituted in the body later. Since the default dtype of a Var is "int32", this could lead to a type mismatch in Substitute. To deal with this scenario, use "void" for newly created Vars in the parser, and add an exception to Substitute to allow replacing void Vars with expressions of any type. --- python/tvm/script/parser.py | 5 ++++- src/tir/ir/stmt_functor.cc | 12 ++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index e4bdd1206506..0932e717bbec 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -526,7 +526,10 @@ def transform_Lambda(self, node): # add parameters of the lambda arg_vars = [] for arg in node.params: - arg_var = tvm.te.var(arg.name) + # Use "void" for dtype here. The actual type is not yet known and will be + # determined later. Using void type will allow IRSubstitute to do the + # replacement without flagging a type-mismatch error. + arg_var = tvm.te.var(arg.name, dtype="") arg_vars.append(arg_var) self.context.update_symbol(arg.name, arg_var, node) diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index b9d3d2dce846..c0abf953eec2 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -648,11 +648,15 @@ class IRSubstitute : public StmtExprMutator { Var var = GetRef(op); auto ret = vmap_(var); if (ret.defined()) { - PrimExpr ret_ex = Downcast(ret.value()); - ICHECK(ret_ex.dtype() == var.dtype()) << "substituting " << var << ":" << var.dtype() - << " -> " << ret_ex << ":" << ret_ex.dtype(); + // Allow substitution of void variables with any expression. The TVM script parser + // uses void variables for lambda parameters (since exact types are not known yet). + if (!var.dtype().is_void()) { + PrimExpr ret_ex = Downcast(ret.value()); + ICHECK(ret_ex.dtype() == var.dtype()) << "substituting " << var << ":" << var.dtype() + << " -> " << ret_ex << ":" << ret_ex.dtype(); + } + return ret.value(); } - if (ret.defined()) return ret.value(); return std::move(var); } From 15326ede439e855c024c011ddb9977b391e3c879 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 8 Jul 2022 14:24:28 -0700 Subject: [PATCH 7/8] Fix type error in test_reduce_combiner_simplify --- tests/python/unittest/test_arith_canonical_simplify.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 74c8bcb5fddf..81a163d0d431 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -161,7 +161,7 @@ def test_reduce_combiner_simplify(): ) sum_and_prod = comm_reducer( lambda x, y: (x[0] + y[0], x[1] * y[1]), - lambda t0, t1: (tvm.tir.const(0, t0), tvm.tir.const(5, t0) - tvm.tir.const(4, t0)), + lambda t0, t1: (tvm.tir.const(0, t0), tvm.tir.const(5, t1) - tvm.tir.const(4, t1)), ) some_reducer1 = comm_reducer( lambda x, y: ( From 305b82d8cf10f87f70783a28227cab86dee7f18c Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 8 Jul 2022 18:48:10 -0500 Subject: [PATCH 8/8] Restart CI