Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AOT] Remove lookup parameter function in AOT #7988

Merged
merged 3 commits into from
May 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,14 @@ TVM_DLL const Op& tvm_struct_get();
*/
TVM_DLL const Op& tvm_struct_set();

/*!
* \brief See pseudo code
* Type lookup_param(String param_name) {
* return __tvm_param__param_name;
* }
*/
TVM_DLL const Op& lookup_param();

/*!
* \brief See pesudo code
*
Expand Down
25 changes: 4 additions & 21 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,40 +152,23 @@ class AOTExecutorCodegen : public ExprVisitor {
* \return Variable that represents the DLTensor associated with the parameters
*/
tir::Var PackParam(Expr expr) {
// TODO(giuseros): Using call_extern to call into lookup_linked_param. This is because the
// builtin::ret is not supported yet in the c target. Once return is supported we can use
// tvm_call_packed_lowered().
int param_sid = param_storage_ids_[params_by_expr_[expr]];
auto lookup_linked_param_fn = tir::StringImm(::tvm::runtime::symbol::tvm_lookup_linked_param);
auto param_array = te::Var(MakeString("param_", param_sid, "_array"), DataType::Handle());

// Compose the lookup_call using a local stack
Array<tir::Stmt> lookup_call;
auto param_var = te::Var(MakeString("param_", param_sid, "_value"), DataType::Handle());
auto ret_var = te::Var("ret_value", DataType::Handle());
auto ret_code = te::Var("ret_value", DataType::Handle());

lookup_call.push_back(tir::Evaluate(
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
{param_var, 0, tir::builtin::kTVMValueContent, ConstInt32(param_sid)})));
lookup_call.push_back(tir::Evaluate(
tvm::tir::Call(DataType::Handle(), tir::builtin::call_extern(),
{lookup_linked_param_fn, param_var, 0, 0, ret_var, ret_code, 0})));
auto ret_var_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(),
{ret_var, 0, tir::builtin::kTVMValueContent});

// Set the param to the value returned by lookup_call
auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
{tir::StringImm(params_by_expr_[expr])});

tvm::PrimExpr set_param_array =
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
{param_array, 0, tir::builtin::kArrData, ret_var_handle});
{param_array, 0, tir::builtin::kArrData, param_handle});
lookup_call.push_back(tir::Evaluate(set_param_array));

tir::Stmt lookup_body = tir::SeqStmt(lookup_call);

// Allocate the DLTensors on the stack
lookup_body = tir::LetStmt(param_var, StackAlloca("arg_value", 1), lookup_body);
lookup_body = tir::LetStmt(ret_var, StackAlloca("arg_value", 1), lookup_body);
lookup_body = tir::LetStmt(ret_code, StackAlloca("arg_value", 1), lookup_body);
lookup_body = tir::LetStmt(param_array, StackAlloca("arg_value", 1), lookup_body);
stmts_.push_back(lookup_body);
return param_array;
Expand Down
5 changes: 5 additions & 0 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,11 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
os << " != ";
this->PrintExpr(op->args[0], os);
os << ")";
} else if (op->op.same_as(builtin::lookup_param())) {
ICHECK_EQ(op->args.size(), 1);
const StringImmNode* str = op->args[0].as<StringImmNode>();
ICHECK(str != nullptr);
os << "__tvm_param__" << str->value;
} else {
LOG(FATAL) << "Unresolved call " << op->op;
}
Expand Down
39 changes: 23 additions & 16 deletions src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,7 @@ void CodeGenCHost::AddFunction(const PrimFunc& f) {
CodeGenC::AddFunction(f);
}

void CodeGenCHost::LinkParameters(Map<String, LinkedParam> params) {
PrintFuncPrefix();
stream << " " << tvm::runtime::symbol::tvm_lookup_linked_param
<< "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, "
<< "int* out_ret_tcode, void* resource_handle) {\n";
ICHECK_EQ(GetUniqueName(tvm::runtime::symbol::tvm_lookup_linked_param),
tvm::runtime::symbol::tvm_lookup_linked_param)
<< "builtin PackedFunc name already taken: " << tvm::runtime::symbol::tvm_lookup_linked_param;
stream << " switch (((int64_t*) args)[0]) {\n"
<< " default:\n"
<< " out_ret_tcode[0] = " << kTVMNullptr << ";\n"
<< " return 0;\n";

function_names_.push_back(tvm::runtime::symbol::tvm_lookup_linked_param);
void CodeGenCHost::DeclareParameters(Map<String, LinkedParam> params) {
for (auto kv : params) {
decl_stream << "\n"
<< "#ifdef __cplusplus\n"
Expand All @@ -93,6 +80,24 @@ void CodeGenCHost::LinkParameters(Map<String, LinkedParam> params) {
<< "#ifdef __cplusplus\n"
<< "} // extern \"C\"\n"
<< "#endif\n";
}
}

void CodeGenCHost::LinkParameters(Map<String, LinkedParam> params) {
PrintFuncPrefix();
stream << " " << tvm::runtime::symbol::tvm_lookup_linked_param
<< "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, "
<< "int* out_ret_tcode, void* resource_handle) {\n";
ICHECK_EQ(GetUniqueName(tvm::runtime::symbol::tvm_lookup_linked_param),
tvm::runtime::symbol::tvm_lookup_linked_param)
<< "builtin PackedFunc name already taken: " << tvm::runtime::symbol::tvm_lookup_linked_param;
stream << " switch (((int64_t*) args)[0]) {\n"
<< " default:\n"
<< " out_ret_tcode[0] = " << kTVMNullptr << ";\n"
<< " return 0;\n";

function_names_.push_back(tvm::runtime::symbol::tvm_lookup_linked_param);
for (auto kv : params) {
stream << " case " << kv.second->id << ":\n"
<< " ((uint64_t*)out_ret_value)[0] = (uint64_t) (uintptr_t) "
<< ::tvm::runtime::symbol::tvm_param_prefix << kv.first << ";\n"
Expand Down Expand Up @@ -398,12 +403,14 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
cg.AddFunction(f);
}

if (could_have_linked_params) {
if (could_have_linked_params && !aot_executor_fn.defined()) {
ICHECK(found_linked_params) << "-link-params given but none found";
cg.DeclareParameters(linked_params);
cg.LinkParameters(linked_params);
}

if (aot_executor_fn.defined()) {
if (could_have_linked_params && aot_executor_fn.defined()) {
cg.DeclareParameters(linked_params);
cg.AddFunction(aot_executor_fn);
}

Expand Down
1 change: 1 addition & 0 deletions src/target/source/codegen_c_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class CodeGenCHost final : public CodeGenC {
void AddFunction(const PrimFunc& f);

/*! \brief Add linked parameters, if they are present. */
void DeclareParameters(Map<String, LinkedParam> params);
void LinkParameters(Map<String, LinkedParam> params);

void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
Expand Down
4 changes: 4 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_struct_set)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kUpdateState));

TIR_DEFINE_BUILTIN_FUNC(lookup_param)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kUpdateState));

TIR_DEFINE_BUILTIN_FUNC(tvm_throw_last_error)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
Expand Down