diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 5608c21904e7e..852c7d0d8a985 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -278,6 +278,12 @@ class IRModuleNode : public Object { */ TVM_DLL void Update(const IRModule& other); + /*! + * \brief Create a shallow copy of this IRModule. + * \returns The shallow copy of the IRModule. + */ + TVM_DLL IRModule ShallowCopy(); + /*! * \brief Import Relay code from the file at path. * \param path The path of the Relay code to import. @@ -418,6 +424,13 @@ class IRModule : public ObjectRef { */ TVM_DLL static IRModule FromText(const String& text, const String& source_path); + /*! + * \brief Create a shallow copy of an IRModule. + * \param mod The module to copy. + * \return The copied module. + */ + IRModule ShallowCopyIRModule(IRModule mod); + /*! \brief Declare the container type. */ using ContainerType = IRModuleNode; diff --git a/src/ir/module.cc b/src/ir/module.cc index 97f2d546c5c83..15c441d61a237 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -365,6 +365,11 @@ void IRModuleNode::Update(const IRModule& mod) { } } +IRModule IRModuleNode::ShallowCopy() { + return IRModule(this->functions, this->type_definitions, this->Imports(), this->source_map, + this->attrs); +} + std::pair IRModule::FromExprInContext( const RelayExpr& expr, const tvm::Map& global_funcs, const tvm::Map& type_definitions, diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 4c372d78f5fe6..344d1cae78237 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -133,8 +133,7 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) DLOG(INFO) << "Executing function pass : " << pass_info->name << " with opt level: " << pass_info->opt_level; - IRModule updated_mod = - IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map, mod->attrs); + IRModule updated_mod = mod->ShallowCopy(); std::vector > updates; for (const auto& it : updated_mod->functions) { diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index b339499b4d997..019659b3166ea 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -30,6 +30,7 @@ */ #include +#include #include #include #include @@ -509,8 +510,10 @@ class NameMangleExtFuncs : public MixedModeMutator { // Walk the tree and mangle the functions. Then replace compiler functions // with mangled functions in the module - IRModule new_module = IRModule({}, module_->type_definitions, module_->Imports(), - module_->source_map, module_->attrs); + IRModule new_module = module_->ShallowCopy(); + new_module->functions = {}; + // IRModule new_module = IRModule({}, module_->type_definitions, module_->Imports(), + // module_->source_map, module_->attrs); for (const auto& pair : glob_funcs) { if (auto* fn = pair.second.as()) { auto func = GetRef(fn); diff --git a/src/relay/transforms/to_basic_block_normal_form.cc b/src/relay/transforms/to_basic_block_normal_form.cc index 22e17e6d9ee9f..04158aa02f64d 100644 --- a/src/relay/transforms/to_basic_block_normal_form.cc +++ b/src/relay/transforms/to_basic_block_normal_form.cc @@ -52,8 +52,7 @@ IRModule ToBasicBlockNormalForm(const IRModule& mod) { DLOG(INFO) << "ToBBlock:" << std::endl << mod; // Create a new module by shallow copy. - auto mod_ = - IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map, mod->attrs); + auto mod_ = mod->ShallowCopy(); tvm::Map updates; auto funcs = mod_->functions; diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 8022172f80f9e..6c2371716b167 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -826,8 +826,7 @@ Pass InferType() { [=](IRModule mod, const PassContext& pass_ctx) { DLOG(INFO) << "tvm::relay::transform::InferType"; // Execute the pass function and return a new module. - IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports(), - mod->source_map, mod->attrs); + IRModule updated_mod = mod->ShallowCopy(); pass_ctx->diag_ctx = DiagnosticContext::Default(updated_mod);