Several type mismatch fixes and checks #12041

Merged: 8 commits, Jul 11, 2022

30 changes: 22 additions & 8 deletions include/tvm/topi/detail/broadcast.h
@@ -42,39 +42,53 @@ struct BroadcastHelper {
   std::deque<tvm::tir::Var> vars2;
 };
 
+static inline DataType CommonType(DataType type1, DataType type2) {
+  ICHECK(type1.is_scalar() && type2.is_scalar());
+  ICHECK(type1.code() == type2.code());
+  return DataType(type1.code(), std::max(type1.bits(), type2.bits()), /*lanes=*/1);
+}
+
 inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1,
                                       const tvm::Array<tvm::PrimExpr>& shape2) {
   BroadcastHelper bh;
   int s1_size = shape1.size();
   int s2_size = shape2.size();
   tvm::PrimExpr one(1);
   int i;
+
+  auto cast_if_needed = [](DataType to_type, PrimExpr expr) {
+    return to_type != expr.dtype() ? cast(to_type, expr) : expr;
+  };
+
   for (i = 1; i <= std::min(s1_size, s2_size); ++i) {
     // TODO(@icemelon9): Need to revisit this part
     const IntImmNode* static_size1 = shape1[s1_size - i].as<IntImmNode>();
     const IntImmNode* static_size2 = shape2[s2_size - i].as<IntImmNode>();
-    bh.all_vars.push_front(tvm::tir::Var());
+    DataType common_type = CommonType(shape1[s1_size - i].dtype(), shape2[s2_size - i].dtype());
+
+    bh.all_vars.push_front(tvm::tir::Var("dim", common_type));
     if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) {
-      bh.common_shape.push_front(shape1[s1_size - i]);
+      bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i]));
       bh.vars1.push_front(bh.all_vars[0]);
       bh.vars2.push_front(bh.all_vars[0]);
     } else if (topi::detail::EqualCheck(one, shape1[s1_size - i])) {
       ICHECK(!topi::detail::EqualCheck(one, shape2[s2_size - i]));
-      bh.common_shape.push_front(shape2[s2_size - i]);
+      bh.common_shape.push_front(cast_if_needed(common_type, shape2[s2_size - i]));
       bh.vars2.push_front(bh.all_vars[0]);
     } else if (topi::detail::EqualCheck(one, shape2[s2_size - i])) {
-      bh.common_shape.push_front(shape1[s1_size - i]);
+      bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i]));
       bh.vars1.push_front(bh.all_vars[0]);
     } else if (!static_size1 && !static_size2) {
-      bh.common_shape.push_front(max(shape1[s1_size - i], shape2[s2_size - i]));
+      bh.common_shape.push_front(
+          cast_if_needed(common_type, max(shape1[s1_size - i], shape2[s2_size - i])));
       bh.vars1.push_front(bh.all_vars[0]);
       bh.vars2.push_front(bh.all_vars[0]);
     } else if (!static_size1) {
-      bh.common_shape.push_front(shape2[s2_size - i]);
+      bh.common_shape.push_front(cast_if_needed(common_type, shape2[s2_size - i]));
       bh.vars2.push_front(bh.all_vars[0]);
       bh.vars1.push_front(bh.all_vars[0]);
     } else if (!static_size2) {
-      bh.common_shape.push_front(shape1[s1_size - i]);
+      bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i]));
       bh.vars1.push_front(bh.all_vars[0]);
       bh.vars2.push_front(bh.all_vars[0]);
     } else {
@@ -89,7 +103,7 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::PrimExpr>& shape1,
   auto& shape = (s1_size > s2_size) ? shape1 : shape2;
   auto& vars = (s1_size > s2_size) ? bh.vars1 : bh.vars2;
   for (; i <= max_size; ++i) {
-    bh.all_vars.push_front(tvm::tir::Var());
+    bh.all_vars.push_front(tvm::tir::Var("v", shape[max_size - 1].dtype()));
     bh.common_shape.push_front(shape[max_size - i]);
     vars.push_front(bh.all_vars[0]);
   }

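The effect is easiest to see from Python. A minimal sketch, assuming a build that includes this patch (the names n, A, B and the int64 constant are illustrative only): previously BroadcastShape could emit a common shape whose extents mixed int32 and int64 expressions; CommonType plus cast_if_needed now promote every output dimension to one common dtype.

    import tvm
    from tvm import te, topi

    # One leading dim is an int32 IntImm (the literal 1), the other an int64
    # IntImm: BroadcastShape now promotes both sides to the wider dtype.
    n = te.var("n", dtype="int64")
    A = te.placeholder((1, n), name="A", dtype="float32")
    B = te.placeholder((tvm.tir.IntImm("int64", 4), n), name="B", dtype="float32")

    C = topi.add(A, B)  # broadcasting goes through BroadcastShape
    print([(dim, dim.dtype) for dim in C.shape])  # all extents share one dtype
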
5 changes: 4 additions & 1 deletion python/tvm/script/parser.py
@@ -526,7 +526,10 @@ def transform_Lambda(self, node):
         # add parameters of the lambda
         arg_vars = []
         for arg in node.params:
-            arg_var = tvm.te.var(arg.name)
+            # Use "void" for dtype here. The actual type is not yet known and will be
+            # determined later. Using void type will allow IRSubstitute to do the
+            # replacement without flagging a type-mismatch error.
+            arg_var = tvm.te.var(arg.name, dtype="")
             arg_vars.append(arg_var)
             self.context.update_symbol(arg.name, arg_var, node)

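A sketch of why the void dtype matters (hypothetical snippet, not part of the patch): a var created with dtype="" can later be substituted by an expression of any concrete dtype, which a typed var cannot once the new check in stmt_functor.cc (below) is in place.

    import tvm
    from tvm import tir

    # Void-typed placeholder, as transform_Lambda now creates for lambda params.
    x = tvm.te.var("x", dtype="")
    stmt = tir.Evaluate(x)

    # IRSubstitute accepts a replacement of any dtype for a void var; a var
    # with a concrete dtype would have to be replaced type-for-type.
    print(tir.stmt_functor.substitute(stmt, {x: tir.const(7, "int64")}))
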
52 changes: 28 additions & 24 deletions src/te/operation/op_utils.cc
@@ -38,18 +38,16 @@ namespace te {
 using namespace arith;
 using namespace tir;
 
-DataType LargerDataType(DataType a, DataType b) { return a.bits() > b.bits() ? a : b; }
-
-std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
-                                             const std::unordered_map<IterVar, Range>& dom_map,
-                                             size_t begin_iter_pos, bool new_loop_var,
-                                             const std::unordered_set<IterVar>& skip_iter,
-                                             std::unordered_map<IterVar, PrimExpr>* p_value_map,
-                                             bool debug_keep_trivial_loop) {
+std::vector<std::vector<Stmt>> MakeLoopNest(const Stage& stage,
+                                            const std::unordered_map<IterVar, Range>& dom_map,
+                                            size_t begin_iter_pos, bool new_loop_var,
+                                            const std::unordered_set<IterVar>& skip_iter,
+                                            std::unordered_map<IterVar, PrimExpr>* p_value_map,
+                                            bool debug_keep_trivial_loop) {
   auto leaf_iter_vars = stage->leaf_iter_vars;
   Stmt no_op = Evaluate(0);
   // create the loop nest
-  std::vector<std::vector<Stmt> > nest;
+  std::vector<std::vector<Stmt>> nest;
   nest.resize(leaf_iter_vars.size() + 1);
   std::unordered_map<IterVar, PrimExpr>& value_map = *p_value_map;
 
@@ -69,6 +67,10 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
 
     Range dom = dom_map.at(iv);
 
+    ICHECK(iv->var.dtype() == dom->min.dtype() && iv->var.dtype() == dom->extent.dtype())
+        << "iter_var type " << iv->var.dtype() << " and domain types (min:" << dom->min.dtype()
+        << ", extent:" << dom->extent.dtype() << ") should all be the same";
+
     // This is a hack to ensure that the replacing expression has the same
     // dtype as the replacing expression. This happens when a thread/block
     // itervar is bound to another itervar. Because the thread/block itervar
@@ -78,7 +80,9 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
     // bound to (in `bind`) but that would require inplace modification of the
     // itervar.
     // XXX: we will get integer overflow if the bound itervar is greater than int32::max.
-    auto promote_to_bound_dtype = [&iv](PrimExpr e) { return cast(iv->var.dtype(), e); };
+    auto promote_to_iv_dtype = [type = iv->var.dtype()](PrimExpr e) {
+      return type != e.dtype() ? cast(type, e) : e;
+    };
 
     // initialize the offset and loop_level
     Var var = bind_iv->var;
@@ -125,15 +129,15 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
        }
      }
      if (!debug_keep_trivial_loop && is_one(dom->extent)) {
-       nest[i + 1].emplace_back(LetStmt(var, promote_to_bound_dtype(dom->min), no_op));
-       value_map[iv] = promote_to_bound_dtype(dom->min);
+       nest[i + 1].emplace_back(LetStmt(var, dom->min, no_op));
+       value_map[iv] = dom->min;
      } else if (is_zero(dom->min)) {
-       nest[i + 1].emplace_back(For(var, 0, promote_to_bound_dtype(dom->extent), kind, no_op));
-       value_map[iv] = promote_to_bound_dtype(var);
+       nest[i + 1].emplace_back(For(var, 0, dom->extent, kind, no_op));
+       value_map[iv] = promote_to_iv_dtype(var);
      } else {
        Var idx(bind_iv->var->name_hint + ".idx", iv->var.dtype());
-       nest[i + 1].emplace_back(For(idx, 0, promote_to_bound_dtype(dom->extent), kind, no_op));
-       PrimExpr new_value = promote_to_bound_dtype(dom->min + idx);
+       nest[i + 1].emplace_back(For(idx, 0, dom->extent, kind, no_op));
+       PrimExpr new_value = dom->min + idx;
        value_map[iv] = new_value;
        nest[i + 1].emplace_back(LetStmt(var, new_value, no_op));
      }
@@ -152,44 +156,44 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
      ICHECK(is_positive_const(dom->extent));
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::virtual_thread, dom->extent, no_op));
-     value_map[iv] = promote_to_bound_dtype(var);
+     value_map[iv] = promote_to_iv_dtype(var);
    } else if (bind_iv->thread_tag == "pipeline") {
      // pipeline marker.
      ICHECK(is_zero(dom->min));
      ICHECK(is_one(dom->extent));
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(
          AttrStmt(bind_iv, tir::attr::pipeline_exec_scope, dom->extent, no_op));
-     value_map[iv] = promote_to_bound_dtype(dom->min);
+     value_map[iv] = dom->min;
    } else {
      // Always restrict threaded IterVar to starts from 0.
      ICHECK(is_zero(dom->min)) << "Itervar " << iv << " must start at zero, but it starts at "
                                << dom->min;
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent, dom->extent, no_op));
      if (!debug_keep_trivial_loop && is_one(dom->extent)) {
-       value_map[iv] = promote_to_bound_dtype(dom->min);
+       value_map[iv] = dom->min;
      } else if (stage->scope == "") {
-       value_map[iv] = promote_to_bound_dtype(var);
+       value_map[iv] = promote_to_iv_dtype(var);
      } else {
        runtime::ThreadScope ts = runtime::ThreadScope::Create(bind_iv->thread_tag);
        runtime::StorageScope ss = runtime::StorageScope::Create(stage->scope);
        if (static_cast<int>(ss.rank) <= ts.rank) {
-         value_map[iv] = promote_to_bound_dtype(var);
+         value_map[iv] = promote_to_iv_dtype(var);
        } else if (stage->scope == "warp" && ts.rank == 1) {
          // To determine whether a thread index is inside or outside a warp, we need
          // to know the thread extent. We leave a warning for now.
          if (ts.dim_index == 0) {
-           value_map[iv] = promote_to_bound_dtype(var);
+           value_map[iv] = promote_to_iv_dtype(var);
          } else {
            LOG(WARNING)
                << "WARNING: threadIdx.y or threadIdx.z accessing warp-scope memory detected. "
                << "TVM assumes only threadIdx.x indicates threads inside a warp, "
                << "while threadIdx.y and threadIdx.z indicates different warps.";
-           value_map[iv] = promote_to_bound_dtype(dom->min);
+           value_map[iv] = dom->min;
          }
        } else {
-         value_map[iv] = promote_to_bound_dtype(dom->min);
+         value_map[iv] = dom->min;
        }
      }
    }

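The observable effect, sketched under the assumption of an int64-shaped schedule: the new ICHECK guarantees the itervar and its inferred domain agree on dtype, so the cast in promote_to_iv_dtype only fires in the bound-itervar corner case rather than wrapping every loop bound.

    import tvm
    from tvm import te

    # With int64 extents end to end, MakeLoopNest keeps loop vars and bounds
    # in int64 and no longer inserts redundant casts around them.
    n = te.var("n", dtype="int64")
    A = te.placeholder((n,), name="A", dtype="float32")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    print(tvm.lower(s, [A, B], simple_mode=True))
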
12 changes: 6 additions & 6 deletions src/te/operation/op_utils.h
@@ -51,12 +51,12 @@ using tir::MergeNest;
  * \param p_value_map The result value of each IterVar.
  * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1
  */
-std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
-                                             const std::unordered_map<IterVar, Range>& dom_map,
-                                             size_t begin_iter_pos, bool new_loop_var,
-                                             const std::unordered_set<IterVar>& skip_iter,
-                                             std::unordered_map<IterVar, PrimExpr>* p_value_map,
-                                             bool debug_keep_trivial_loop);
+std::vector<std::vector<Stmt>> MakeLoopNest(const Stage& stage,
+                                            const std::unordered_map<IterVar, Range>& dom_map,
+                                            size_t begin_iter_pos, bool new_loop_var,
+                                            const std::unordered_set<IterVar>& skip_iter,
+                                            std::unordered_map<IterVar, PrimExpr>* p_value_map,
+                                            bool debug_keep_trivial_loop);
 
 /*!
  * \brief Create a nest of if checking the predicates.

7 changes: 4 additions & 3 deletions src/te/schedule/bound.cc
@@ -247,10 +247,11 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
     }
   }
   for (auto it = ret.begin(); it != ret.end(); it++) {
+    DataType var_type = it->first->var.dtype();
     it->second = Range::FromMinExtent(
-        analyzer.Simplify(it->second->min),
-        // The range associated with each itervar must have the same dtype as it
-        cast(it->first->var.dtype(), analyzer.Simplify(it->second->extent)));
+        // The range associated with each itervar must have the same dtype as the var
+        analyzer.Simplify(cast(var_type, it->second->min)),
+        analyzer.Simplify(cast(var_type, it->second->extent)));
   }
   return Map<IterVar, Range>(ret.begin(), ret.end());
 }

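The invariant can be checked from Python; a small sketch, assuming te.schedule.InferBound on a normalized schedule:

    import tvm
    from tvm import te

    n = te.var("n", dtype="int64")
    A = te.placeholder((n,), name="A", dtype="float32")
    B = te.compute((n,), lambda i: A[i] * 2.0, name="B")
    s = te.create_schedule(B.op).normalize()

    # Both min and extent of every inferred Range now carry the dtype of the
    # corresponding IterVar (previously only the extent was cast).
    for iv, rng in te.schedule.InferBound(s).items():
        assert rng.min.dtype == iv.var.dtype
        assert rng.extent.dtype == iv.var.dtype
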
13 changes: 11 additions & 2 deletions src/tir/ir/stmt_functor.cc
@@ -26,7 +26,7 @@
 
 #include <functional>
 
-#include "./functor_common.h"
+#include "functor_common.h"
 
 namespace tvm {
 namespace tir {
@@ -647,7 +647,16 @@ class IRSubstitute : public StmtExprMutator {
   PrimExpr VisitExpr_(const VarNode* op) final {
     Var var = GetRef<Var>(op);
     auto ret = vmap_(var);
-    if (ret.defined()) return ret.value();
+    if (ret.defined()) {
+      // Allow substitution of void variables with any expression. The TVM script parser
+      // uses void variables for lambda parameters (since exact types are not known yet).
+      if (!var.dtype().is_void()) {
+        PrimExpr ret_ex = Downcast<PrimExpr>(ret.value());
+        ICHECK(ret_ex.dtype() == var.dtype()) << "substituting " << var << ":" << var.dtype()
+                                              << " -> " << ret_ex << ":" << ret_ex.dtype();
+      }
+      return ret.value();
+    }
     return std::move(var);
   }

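A quick illustration of the new check (a sketch; tir.stmt_functor.substitute is the Python entry point into IRSubstitute):

    import tvm
    from tvm import tir

    x = tvm.te.var("x", dtype="int32")
    stmt = tir.Evaluate(x)

    # Same dtype: fine.
    tir.stmt_functor.substitute(stmt, {x: tir.const(1, "int32")})

    # Mismatched dtype on a non-void var: now an ICHECK failure.
    try:
        tir.stmt_functor.substitute(stmt, {x: tir.const(1, "int64")})
    except tvm.TVMError:
        print("type mismatch caught by IRSubstitute")
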
17 changes: 17 additions & 0 deletions tests/python/relay/test_op_level10.py
@@ -262,6 +262,23 @@ def test_broadcast_concat_shape_int64(executor_kind):
         tvm.testing.assert_allclose(op_res.numpy(), ref_res)
 
 
+def test_broadcast_pool2d_shape_int64(executor_kind):
+    x_shape = (1, 3, 32, 32)
+    out_shape = (2, 3, 32, 32)
+    x = relay.var("data", shape=x_shape, dtype="float32")
+    broadcast_to = relay.broadcast_to(x, shape=relay.const([2, 3, 32, 32], dtype="int64"))
+    pool2d = relay.nn.max_pool2d(broadcast_to, pool_size=(3, 3), padding=(1, 1, 1, 1))
+    sub = relay.subtract(broadcast_to, pool2d)
+
+    f = relay.Function([x], sub)
+    x = np.ones(x_shape).astype("float32")
+    ref_res = np.zeros(out_shape).astype("float32")
+
+    for target, dev in tvm.testing.enabled_targets():
+        op_res = relay.create_executor(executor_kind, device=dev, target=target).evaluate(f)(x)
+        tvm.testing.assert_allclose(op_res.numpy(), ref_res)
+
+
 @tvm.testing.uses_gpu
 def test_broadcast_to_like(executor_kind):
     shape = (4, 1, 6)

2 changes: 1 addition & 1 deletion tests/python/unittest/test_arith_canonical_simplify.py
@@ -161,7 +161,7 @@ def test_reduce_combiner_simplify():
     )
     sum_and_prod = comm_reducer(
         lambda x, y: (x[0] + y[0], x[1] * y[1]),
-        lambda t0, t1: (tvm.tir.const(0, t0), tvm.tir.const(5, t0) - tvm.tir.const(4, t0)),
+        lambda t0, t1: (tvm.tir.const(0, t0), tvm.tir.const(5, t1) - tvm.tir.const(4, t1)),
     )
     some_reducer1 = comm_reducer(
         lambda x, y: (

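The one-character fix matters because the identity lambda of comm_reducer receives one dtype argument per reduction component, and component k's identity must be built from tk. A minimal correctly-typed sketch (the test writes the product identity as 5 - 4 purely to exercise the simplifier):

    import tvm
    from tvm import te

    # Sum component identity uses t0; product component identity uses t1.
    # Mixing them up builds identities of the wrong dtype whenever the two
    # components have different types.
    sum_and_prod = te.comm_reducer(
        lambda x, y: (x[0] + y[0], x[1] * y[1]),
        lambda t0, t1: (tvm.tir.const(0, t0), tvm.tir.const(1, t1)),
        name="sum_and_prod",
    )
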