Skip to content

Commit

Permalink
[FIX] Add check for repeated function parameters (apache#21)
Browse files Browse the repository at this point in the history
* Add check for repeated function parameters

* Modify output format and unit test
  • Loading branch information
Ubospica authored and MasterJH5574 committed Nov 19, 2022
1 parent 91d9547 commit bd21748
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@

#include <unordered_set>

#include "../../printer/text_printer.h"

namespace tvm {
namespace relax {

Expand Down Expand Up @@ -124,6 +126,7 @@ class WellFormedChecker : public relax::ExprVisitor {
void VisitExpr_(const FunctionNode* op) {
// save the var_set_ for local function
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> previous_var_set_ = var_set_;
Function func = GetRef<Function>(op);
for (Var param : op->params) {
// register symbolic var defined in the shape annotation of function params
if (param->shape_) {
Expand All @@ -142,6 +145,15 @@ class WellFormedChecker : public relax::ExprVisitor {
}

this->VisitVarDef(param);

if (param_var_func_map_.count(param) == 1) {
Malformed(Diagnostic::Error(param->span)
<< "Relax variable " << param->name_hint()
<< " is repeatedly used as parameters in function:\n"
<< AsRelaxScript(param_var_func_map_[param], false)
<< "\nand function:\n" << AsRelaxScript(func, false));
}
param_var_func_map_.insert({param, func});
}
this->VisitBody(op->body);
var_set_ = previous_var_set_;
Expand Down Expand Up @@ -286,6 +298,8 @@ class WellFormedChecker : public relax::ExprVisitor {
std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> global_var_set_;
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_set_;
std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual> dataflow_var_set_;
std::unordered_map<Var, Function, ObjectPtrHash, ObjectPtrEqual> param_var_func_map_;

PrimExprVisitor prim_expr_visitor_;
};

Expand Down
3 changes: 3 additions & 0 deletions src/relax/backend/vm/codegen_vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
ICHECK(gsymbol.defined()) << "there should be no local functions in Relax VM codegen phase. "
"Did you forget to apply LambdaLift pass?";

// var_register_map_ is local in function scope
var_register_map_.clear();

Array<String> param_names;
for (Var param : func_node->params) {
param_names.push_back(param->name_hint());
Expand Down
15 changes: 15 additions & 0 deletions tests/python/relax/test_transform_well_formed.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,21 @@ def test_dataflow_var():
assert not rx.analysis.well_formed(mod)


def test_param_var():
v0 = rx.Var("v0", [m, n], type_anno)
v1 = rx.Var("v1", [m, n], type_anno)
v2 = rx.Var("v2", [m, n], type_anno)
bb = rx.BlockBuilder()
with bb.function("func1", [v0, v1]):
gv0 = bb.emit(rx.op.add(v0, v1))
bb.emit_func_output(gv0)
with bb.function("func2", [v0, v2]):
gv0 = bb.emit(rx.op.add(v0, v2))
bb.emit_func_output(gv0)
mod = bb.get()
assert not rx.analysis.well_formed(mod)


def test_global_var():
# Error: GlobalVar GlobalVar0 is not defined
gv0 = rx.Var("gv0", [m, n], type_anno)
Expand Down

0 comments on commit bd21748

Please sign in to comment.