Skip to content

Commit

Permalink
Add ShallowCopy to IRmodule
Browse files Browse the repository at this point in the history
  • Loading branch information
electriclilies committed Sep 7, 2021
1 parent e0420fe commit 2c8d8c2
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 8 deletions.
13 changes: 13 additions & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,12 @@ class IRModuleNode : public Object {
*/
TVM_DLL void Update(const IRModule& other);

/*!
* \brief Create a shallow copy of this IRModule.
* \returns The shallow copy of the IRModule.
*/
TVM_DLL IRModule ShallowCopy();

/*!
* \brief Import Relay code from the file at path.
* \param path The path of the Relay code to import.
Expand Down Expand Up @@ -418,6 +424,13 @@ class IRModule : public ObjectRef {
*/
TVM_DLL static IRModule FromText(const String& text, const String& source_path);

/*!
* \brief Create a shallow copy of an IRModule.
* \param mod The module to copy.
* \return The copied module.
*/
IRModule ShallowCopyIRModule(IRModule mod);

/*! \brief Declare the container type. */
using ContainerType = IRModuleNode;

Expand Down
5 changes: 5 additions & 0 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,11 @@ void IRModuleNode::Update(const IRModule& mod) {
}
}

IRModule IRModuleNode::ShallowCopy() {
return IRModule(this->functions, this->type_definitions, this->Imports(), this->source_map,
this->attrs);
}

std::pair<IRModule, GlobalVar> IRModule::FromExprInContext(
const RelayExpr& expr, const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions,
Expand Down
3 changes: 1 addition & 2 deletions src/relay/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx)
DLOG(INFO) << "Executing function pass : " << pass_info->name
<< " with opt level: " << pass_info->opt_level;

IRModule updated_mod =
IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map, mod->attrs);
IRModule updated_mod = mod->ShallowCopy();

std::vector<std::pair<GlobalVar, Function> > updates;
for (const auto& it : updated_mod->functions) {
Expand Down
7 changes: 5 additions & 2 deletions src/relay/transforms/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
*/

#include <tvm/ir/error.h>
#include <tvm/ir/module.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
Expand Down Expand Up @@ -509,8 +510,10 @@ class NameMangleExtFuncs : public MixedModeMutator {

// Walk the tree and mangle the functions. Then replace compiler functions
// with mangled functions in the module
IRModule new_module = IRModule({}, module_->type_definitions, module_->Imports(),
module_->source_map, module_->attrs);
IRModule new_module = module_->ShallowCopy();
new_module->functions = {};
// IRModule new_module = IRModule({}, module_->type_definitions, module_->Imports(),
// module_->source_map, module_->attrs);
for (const auto& pair : glob_funcs) {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/to_basic_block_normal_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ IRModule ToBasicBlockNormalForm(const IRModule& mod) {
DLOG(INFO) << "ToBBlock:" << std::endl << mod;

// Create a new module by shallow copy.
auto mod_ =
IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map, mod->attrs);
auto mod_ = mod->ShallowCopy();

tvm::Map<GlobalVar, Function> updates;
auto funcs = mod_->functions;
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -826,8 +826,7 @@ Pass InferType() {
[=](IRModule mod, const PassContext& pass_ctx) {
DLOG(INFO) << "tvm::relay::transform::InferType";
// Execute the pass function and return a new module.
IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports(),
mod->source_map, mod->attrs);
IRModule updated_mod = mod->ShallowCopy();

pass_ctx->diag_ctx = DiagnosticContext::Default(updated_mod);

Expand Down

0 comments on commit 2c8d8c2

Please sign in to comment.