Skip to content

Commit

Permalink
rm convertToSSA API,test=huawei_ascend_npu test=nvidia_tensorrt test=…
Browse files Browse the repository at this point in the history
…verisilicon_timvx (#8988) (#9233)
  • Loading branch information
weishengying authored Jul 9, 2022
1 parent 8f379b4 commit 73750d9
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 23 deletions.
2 changes: 0 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ lite_option(LITE_WITH_XCODE "when debug in xcode, its ON."
lite_option(LITE_WITH_ARM82_FP16 "when compile with arm v8.2 fp16, it's ON." OFF)
lite_option(LITE_WITH_ARM82_INT8_SDOT "when compile with arm v8.2 int8, it's ON." OFF)
lite_option(LITE_WITH_CODE_META_INFO "include git version in the header file." ON)
# whether convert input model which is not a DAG to SSA graph
lite_option(WITH_CONVERT_TO_SSA "whether convert input model which is not a DAG to SSA graph" ON)

# Thirdparty
set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
Expand Down
3 changes: 0 additions & 3 deletions cmake/configure.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,3 @@ if (LITE_WITH_M1)
add_definitions("-DLITE_WITH_M1")
endif(LITE_WITH_M1)

if (WITH_CONVERT_TO_SSA STREQUAL ON)
add_definitions("-DWITH_CONVERT_TO_SSA")
endif(WITH_CONVERT_TO_SSA)
181 changes: 173 additions & 8 deletions lite/core/optimizer/mir/type_target_cast_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ void TypeTargetTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {

// record the copied node.
std::map<std::string, Node*> copied_nodes;
// record the origin node.
std::map<std::string, Node*> input_nodes;
std::vector<std::string> skip_ops = {
"while", "conditional_block", "write_back"};

Expand All @@ -48,8 +50,14 @@ void TypeTargetTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
if (!node->IsStmt() || iter != skip_ops.end()) continue;
auto inlinks = node->inlinks;
for (auto* in : inlinks) {
if (!input_nodes.count(in->AsArg().name))
input_nodes[in->AsArg().name] = in;
ComplementInputs(graph.get(), node, in, &copied_nodes);
}
auto outlinks = node->outlinks;
for (auto* out : outlinks) {
ComplementOutputs(graph.get(), node, out, &input_nodes);
}
}
}

Expand Down Expand Up @@ -78,17 +86,174 @@ void TypeTargetTransformPass::ComplementInputs(
<< " for kernel " << inst.op()->DebugString() << " "
<< *in->AsArg().type << " -> " << *decl_arg_type;
// Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst(*in->AsArg().type,
*decl_arg_type,
in,
graph,
inst_node,
copied_nodes,
valid_places_);
AddInputIoCopyInst(*in->AsArg().type,
*decl_arg_type,
in,
graph,
inst_node,
copied_nodes,
valid_places_);
}
}

// Insert an io_copy instruction AFTER `inst_node` so that data produced with
// target `from.target()` is copied over to `to.target()` for readers that
// expect it there.
//
// Rewrite performed (variable names in parentheses):
//   inst -> out node(new_name) -> io_copy_op -> new_var_node(original name)
// The producing instruction is redirected to write a renamed variable
// ("<name>/target_trans"), and a freshly created io_copy op copies it into a
// new argument node that carries the original variable name on the `to`
// target.
//
// Parameters:
//   from        - type of the value as currently produced (source target).
//   to          - type the consumer expects (destination target).
//   out         - the output argument node to be transformed (must be Arg).
//   graph       - graph being rewritten; new nodes are allocated from it.
//   inst_node   - the statement node producing `out`.
//   valid_places- candidate places used to instantiate io_copy kernels.
void TypeTargetTransformPass::AddOutputIoCopyInst(
    const Type& from,
    const Type& to,
    Node* out,
    SSAGraph* graph,
    Node* inst_node,
    const std::vector<Place>& valid_places) {
  CHECK(!valid_places.empty()) << "valid_place should be set";
  // inst -> out node(new_name) -> io_copy_op -> new_var_node(out->AsArg().name)
  // So there will be a new Argument node and a new IoCopy Statement Node.
  CHECK(out->IsArg());
  auto new_name = string_format("%s/target_trans", out->AsArg().name.c_str());
  // NOTE: the new argument node deliberately takes the ORIGINAL name; the
  // existing `out` node is renamed to `new_name` near the end of this
  // function, so no duplicate-name state persists once the rewrite completes.
  auto* new_var_node = graph->NewArgumentNode(out->AsArg().name);

  // Set the place for new var node, the target should be equal to to.target()
  // The precision and layout should be equal to from.precision(), from.layout()
  bool is_tensor = from.IsTensor();
  if (!is_tensor) {
    CHECK(from.IsTensorList()) << "only support tensor or tensor_array.";
  }
  if (is_tensor) {
    new_var_node->AsArg().type =
        LiteType::GetTensorTy(to.target(), from.precision(), from.layout());
  } else {
    new_var_node->AsArg().type =
        LiteType::GetTensorListTy(to.target(), from.precision(), from.layout());
  }
  auto* io_copy_inst = graph->NewInstructNode();
  std::string io_copy_type = "io_copy";
  // create Op and kernels.
  auto io_copy_op = LiteOpRegistry::Global().Create(io_copy_type);
  // NOTE(review): this streams the op pointer, not the type name; printing
  // `io_copy_type` was probably intended — confirm before changing.
  CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed";
  // CHECK(io_copy_op);
  // Create the new var manually.
  inst_node->AsStmt().op()->scope()->Var(new_name);

  // Create IoCopy Instruction.
  cpp::OpDesc op_desc;
  op_desc.SetType(io_copy_type);
  // Tensor and tensor_array variants of io_copy use different slot names.
  if (is_tensor) {
    op_desc.SetInput("Input", {new_name});
    op_desc.SetOutput("Out", {out->AsArg().name});
  } else {
    op_desc.SetInput("InputArray", {new_name});
    op_desc.SetOutput("OutArray", {out->AsArg().name});
  }
  io_copy_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
  auto kernels = io_copy_op->CreateKernels(valid_places);
  // Pick the first kernel whose input type is compatible with `from` and
  // whose output target is compatible with `to`.
  bool is_found = false;
  std::vector<std::unique_ptr<KernelBase>> selected_kernels;
  for (auto& kernel : kernels) {
    const Type* in_arg_ty = nullptr;
    const Type* out_arg_ty = nullptr;
    if (is_tensor) {
      in_arg_ty = kernel->GetInputDeclType("Input");
      out_arg_ty = kernel->GetOutputDeclType("Out");
    } else {
      in_arg_ty = kernel->GetInputDeclType("InputArray");
      out_arg_ty = kernel->GetOutputDeclType("OutArray");
    }

    VLOG(4) << "------ kernel info -------";
    VLOG(4) << "*in_arg_ty(io_copy kernel input):" << *in_arg_ty;
    VLOG(4) << "from(last kernel output):" << from;
    VLOG(4) << "out_arg_ty(io_copy kernel output):" << *out_arg_ty;
    VLOG(4) << "to:" << to << "\n";

    if (TypeCompatible(*in_arg_ty, from) &&
        TargetCompatibleTo(*out_arg_ty, to)) {
      VLOG(4) << "picked";
      is_found = true;
    }

    if (is_found) {
      selected_kernels.emplace_back(std::move(kernel));
      // we pick the kernel
      io_copy_inst->AsStmt(
          io_copy_type, std::move(selected_kernels), io_copy_op);
      break;
    }
    VLOG(4) << "not picked";
  }

  CHECK(is_found) << "Can't find a io_copy kernel for io_copy op: " << from
                  << ":" << inst_node->AsStmt().op_info()->Type() << " -> "
                  << to << ":" << out->AsArg().name;
  // Add new link, inst -> var -> io_copy_op -> new_var_node
  DirectedLink(out, io_copy_inst);
  DirectedLink(io_copy_inst, new_var_node);

  // Update the original instruction OpDesc.
  // Update its output var name to the io_copy_output_name
  auto* inst_node_op_desc = inst_node->AsStmt().op()->mutable_op_info();
  for (auto& op_output : *inst_node_op_desc->mutable_outputs()) {
    for (auto& var_name : op_output.second)
      if (var_name == out->AsArg().name) var_name = new_name;
  }
  // Update the input name of Ops whose input var is out var node
  // (the freshly linked io_copy op already references `new_name`, so it is
  // unaffected by this rename).
  for (auto& op : out->outlinks) {
    if (!op->IsStmt()) continue;
    auto* op_desc = op->AsStmt().op()->mutable_op_info();
    for (auto& op_input : *op_desc->mutable_inputs())
      for (auto& var_name : op_input.second)
        if (var_name == out->AsArg().name) var_name = new_name;
  }
  // reset opdesc and update kernel information
  // ResetOp would re-run kernel selection; preserve and restore the kernel
  // that was already chosen for this instruction.
  out->AsArg().name = new_name;
  auto original_selected_kernel =
      std::move(inst_node->AsStmt().kernels().front());
  auto update_op_info = *inst_node->AsStmt().op_info();
  inst_node->AsStmt().ResetOp(update_op_info, graph->valid_places());
  inst_node->AsStmt().kernels().clear();
  inst_node->AsStmt().kernels().emplace_back(
      std::move(original_selected_kernel));

  for (auto& kernel : inst_node->AsStmt().kernels()) {
    VLOG(4) << "kernel info: " << kernel->name();
    inst_node->AsStmt().op()->AttachKernel(kernel.get());
  }

  graph->CheckValid();
}

// For the given output argument `out` of `inst_node`, check whether the same
// variable name was recorded as an (origin) input elsewhere in the graph with
// an incompatible target; if so, insert an io_copy after the producer via
// AddOutputIoCopyInst. Otherwise this is a no-op.
void TypeTargetTransformPass::ComplementOutputs(
    SSAGraph* graph,
    Node* inst_node,
    Node* out,
    std::map<std::string, Node*>* input_nodes) {
  // Skip output links that no longer belong to this instruction (they may
  // have been detached by an earlier rewrite).
  auto& links = inst_node->outlinks;
  if (std::find(links.begin(), links.end(), out) == links.end()) return;

  CHECK(inst_node->IsStmt());
  auto& inst = inst_node->AsStmt();
  VLOG(3) << "found Target tensor: " << out->AsArg().name;
  CHECK(out->IsRoleSet());
  CHECK(out->IsArg());
  CHECK(out->AsArg().type);

  // Only variables that were recorded as original inputs are of interest.
  const auto it = input_nodes->find(out->AsArg().name);
  if (it == input_nodes->end()) return;

  const auto* origin_type = it->second->AsArg().type;
  if (TargetCompatibleTo(*out->AsArg().type, *origin_type)) return;

  VLOG(3) << "found Output Target unmatched tensor: " << out->AsArg().name
          << " for kernel " << inst.op()->DebugString() << " "
          << *out->AsArg().type << " -> " << *origin_type;
  AddOutputIoCopyInst(*out->AsArg().type,
                      *origin_type,
                      out,
                      graph,
                      inst_node,
                      valid_places_);
}

void TypeTargetTransformPass::AddIoCopyInst(
void TypeTargetTransformPass::AddInputIoCopyInst(
const Type& from,
const Type& to,
Node* in,
Expand Down
26 changes: 19 additions & 7 deletions lite/core/optimizer/mir/type_target_cast_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,25 @@ class TypeTargetTransformPass : public ProgramPass {
Node* in,
std::map<std::string, Node*>* copied_nodes);

void AddIoCopyInst(const Type& from,
const Type& to,
Node* in,
SSAGraph* graph,
Node* inst_node,
std::map<std::string, Node*>* copied_nodes,
const std::vector<Place>& valid_places);
void ComplementOutputs(SSAGraph* graph,
Node* inst_node,
Node* out,
std::map<std::string, Node*>* input_nodes);

void AddInputIoCopyInst(const Type& from,
const Type& to,
Node* in,
SSAGraph* graph,
Node* inst_node,
std::map<std::string, Node*>* copied_nodes,
const std::vector<Place>& valid_places);

void AddOutputIoCopyInst(const Type& from,
const Type& to,
Node* out,
SSAGraph* graph,
Node* inst_node,
const std::vector<Place>& valid_places);

void SetValidPlaces(const std::vector<Place>& valid_places);

Expand Down
3 changes: 0 additions & 3 deletions lite/model_parser/model_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,6 @@ void LoadModelPb(const std::string &model_dir,
pb::ProgramDesc pb_prog(&pb_proto_prog);
// Transform to cpp::ProgramDesc
TransformProgramDescAnyToCpp(pb_prog, cpp_prog);
#ifdef WITH_CONVERT_TO_SSA
general::ssa::ConvertToSSA(cpp_prog);
#endif

// Load params data from file.
// NOTE: Only main block be used now.
Expand Down

0 comments on commit 73750d9

Please sign in to comment.