Skip to content

Commit

Permalink
[XPU] pad3d and memory pass (#8213)
Browse files Browse the repository at this point in the history
  • Loading branch information
laiou authored Jan 17, 2022
1 parent d3a994e commit 54525ab
Show file tree
Hide file tree
Showing 6 changed files with 377 additions and 82 deletions.
23 changes: 16 additions & 7 deletions lite/core/optimizer/mir/fusion/inplace_fuser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,31 @@
#include "lite/core/optimizer/mir/fusion/inplace_fuser.h"
#include <memory>
#include <vector>
#include "lite/core/optimizer/mir/pattern_matcher_high_api.h"

namespace paddle {
namespace lite {
namespace mir {
namespace fusion {

void InplaceFuser::BuildPattern() { OpNode("inplace", type_); }
void InplaceFuser::BuildPattern() {
auto* input = VarNode("input")
->assert_is_op_input(type_, "X")
->assert_only_one_output()
->AsInput();

auto* op_node = OpNode("inplace", type_)->assert_is_op(type_);

auto* output = VarNode("output")
->assert_is_op_output(type_, "Out")
->assert_only_one_output()
->AsOutput();

*input >> *op_node >> *output;
}

void InplaceFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto out_var_nodes = matched.at("inplace")->outlinks;
bool inplace = true;
for (auto& out_var_node : out_var_nodes) {
if (out_var_node->outlinks.size() > 1) {
inplace = false;
}
}
auto* stmt = matched.at("inplace")->stmt();
auto op = stmt->op();
cpp::OpDesc* op_desc = op->mutable_op_info();
Expand Down
Loading

0 comments on commit 54525ab

Please sign in to comment.