[XPU] support roformer relative embedding (PaddlePaddle#9536)
newway committed Nov 19, 2022
1 parent db87f2b commit 62c6b08
Showing 14 changed files with 617 additions and 41 deletions.
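
For context: the __xpu__roformer_relative_embedding op that this commit teaches the fuse pass to match applies RoFormer-style rotary position embedding to the query and key tensors before the attention matmul. Below is a minimal sketch of that transform, assuming the even/odd channel pairing from the RoFormer paper; the XPU kernel's exact layout may differ, and ApplyRotary and its signature are illustrative, not part of this commit.

#include <cstddef>
#include <vector>

// Rotary (RoFormer) relative position embedding for one attention head.
// sin_emb / cos_emb are [max_pos_len, head_dim] lookup tables, matching the
// SinEmbbeding / CosEmbbeding inputs matched by the fuse pass (the doubled
// "bb" follows the op's registered input names). This sketch assumes each
// table repeats the angle across the (even, odd) channel pair.
std::vector<float> ApplyRotary(const std::vector<float>& x,  // [seq_len, head_dim], row-major
                               const std::vector<float>& sin_emb,
                               const std::vector<float>& cos_emb,
                               size_t seq_len,
                               size_t head_dim) {
  std::vector<float> out(x.size());
  for (size_t t = 0; t < seq_len; ++t) {
    for (size_t i = 0; i + 1 < head_dim; i += 2) {
      const size_t k = t * head_dim + i;
      // Rotate each (even, odd) channel pair by the position-dependent angle:
      //   out[2i]     = x[2i] * cos - x[2i+1] * sin
      //   out[2i + 1] = x[2i] * sin + x[2i+1] * cos
      out[k] = x[k] * cos_emb[k] - x[k + 1] * sin_emb[k];
      out[k + 1] = x[k] * sin_emb[k] + x[k + 1] * cos_emb[k];
    }
  }
  return out;
}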
1 change: 1 addition & 0 deletions lite/api/paddle_use_passes.h
@@ -87,6 +87,7 @@ USE_MIR_PASS(__xpu__conv2d_affine_channel_fuse_pass);
USE_MIR_PASS(__xpu__conv2d_fuse_pass);
USE_MIR_PASS(__xpu__softmax_topk_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_adaptive_seqlen_fuse_pass);
USE_MIR_PASS(__xpu__roformer_relative_pos_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_slice_link_fuse_pass);
USE_MIR_PASS(__xpu__generate_sequence_fuse_pass);
USE_MIR_PASS(__xpu__logit_fuse_pass);
155 changes: 130 additions & 25 deletions lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
@@ -50,14 +50,16 @@ class XPUSingleEncoderFuser : public FuseBase {
const std::string& matmul_type = "matmul",
const std::string& mul_type = "mul",
bool with_q_scale = true,
bool norm_before = false,
const std::string& relative_type = "")
: act_type_(act_type),
input_pos_(input_pos),
qkv_ln_2_out_pos_(qkv_ln_2_out_pos),
matmul_type_(matmul_type),
mul_type_(mul_type),
with_q_scale_(with_q_scale),
norm_before_(norm_before),
relative_emb_type_(relative_type) {}

void BuildPattern() override {
auto* input = VarNode("input")
@@ -122,13 +124,36 @@
auto* q_transpose2 = OpNode("q_transpose2", "transpose2")->AsIntermediate();
auto* q_transpose2_out = VarNode("q_transpose2_out")
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input(relative_emb_type_.empty()
? target_op_type
: relative_emb_type_,
"X")
->AsIntermediate();
auto* q_transpose2_xshape =
VarNode("q_transpose2_xshape")
->assert_is_op_output("transpose2", "XShape")
->AsIntermediate();

PMNode* q_relative_emb = nullptr;
PMNode* q_cos_embedding = nullptr;
PMNode* q_sin_embedding = nullptr;
PMNode* q_relative_emb_out = nullptr;
if (relative_emb_type_ == "__xpu__roformer_relative_embedding") {
VLOG(3) << "build q_relative_emb";
q_relative_emb =
OpNode("q_relative_emb", relative_emb_type_)->AsIntermediate();
q_sin_embedding =
VarNode("q_sin_embedding")
->assert_is_op_input(relative_emb_type_, "SinEmbbeding")
->AsInput();
q_cos_embedding =
VarNode("q_cos_embedding")
->assert_is_op_input(relative_emb_type_, "CosEmbbeding")
->AsInput();
q_relative_emb_out = VarNode("q_relative_emb_out")
->assert_is_op_output(relative_emb_type_, "Out")
->assert_is_op_input(target_op_type, "X")
->AsIntermediate();
}
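// When relative_emb_type_ is set, the q path of the pattern is extended from
// transpose2 -> (optional scale) -> qk_matmul to
// transpose2 -> __xpu__roformer_relative_embedding -> (optional scale) -> qk_matmul,
// with the sin/cos tables joining as extra graph inputs. The op node and its
// Out var are AsIntermediate(), so they are absorbed when the encoder is fused.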
PMNode* q_scale = nullptr;
PMNode* q_scale_out = nullptr;
if (with_q_scale_) {
@@ -165,8 +190,23 @@
auto* k_transpose2 = OpNode("k_transpose2", "transpose2")->AsIntermediate();
auto* k_transpose2_out = VarNode("k_transpose2_out")
->assert_is_op_output("transpose2", "Out")
->AsIntermediate();
PMNode* k_relative_emb = nullptr;
PMNode* k_sin_embedding = q_sin_embedding;
PMNode* k_cos_embedding = q_cos_embedding;
PMNode* k_relative_emb_out = nullptr;
if (relative_emb_type_.empty()) {
k_transpose2_out->assert_is_op_input(matmul_type_, "Y");
} else if (relative_emb_type_ == "__xpu__roformer_relative_embedding") {
VLOG(3) << "build k_relative_emb";
k_transpose2_out->assert_is_op_input(relative_emb_type_, "X");
k_relative_emb =
OpNode("k_relative_emb", relative_emb_type_)->AsIntermediate();
k_relative_emb_out = VarNode("k_relative_emb_out")
->assert_is_op_output(relative_emb_type_, "Out")
->assert_is_op_input(matmul_type_, "Y")
->AsIntermediate();
}
auto* k_transpose2_xshape =
VarNode("k_transpose2_xshape")
->assert_is_op_output("transpose2", "XShape")
@@ -377,14 +417,23 @@
} else {
*input >> *q_mul;
}
*q_mul >> *q_mul_out >> *q_add >> *q_add_out >> *q_reshape2 >>
*q_reshape2_out >> *q_transpose2 >> *q_transpose2_out;
PMNode* last_node = q_transpose2_out;
if (relative_emb_type_ == "__xpu__roformer_relative_embedding") {
VLOG(3) << "build q_relative_emb link";
*last_node >> *q_relative_emb >> *q_relative_emb_out;
*q_sin_embedding >> *q_relative_emb;
*q_cos_embedding >> *q_relative_emb;
last_node = q_relative_emb_out;
}
if (with_q_scale_) {
*last_node >> *q_scale >> *q_scale_out;
last_node = q_scale_out;
}
*last_node >> *qk_matmul;
last_node = nullptr;
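// Note: last_node threads the q chain through the optional rotary-embedding
// and scale stages, so all four combinations (with/without q scale,
// with/without relative embedding) share a single linking path into qk_matmul.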

*q_mul_y >> *q_mul;
*q_add_y >> *q_add;
*q_reshape2 >> *q_reshape2_xshape;
@@ -396,7 +445,16 @@
*input >> *k_mul;
}
*k_mul >> *k_mul_out >> *k_add >> *k_add_out >> *k_reshape2 >>
*k_reshape2_out >> *k_transpose2 >> *k_transpose2_out;
last_node = k_transpose2_out;
if (relative_emb_type_ == "__xpu__roformer_relative_embedding") {
VLOG(3) << "build k_relative_emb link";
*last_node >> *k_relative_emb >> *k_relative_emb_out;
*k_sin_embedding >> *k_relative_emb;
*k_cos_embedding >> *k_relative_emb;
last_node = k_relative_emb_out;
}
*last_node >> *qk_matmul;

*k_mul_y >> *k_mul;
*k_add_y >> *k_add;
@@ -476,6 +534,9 @@
matched.at("qkv_add_3_y")->arg()->name,
matched.at("qkv_add_4_y")->arg()->name,
});
VLOG(3) << "matched.at(q_add_y)->arg()->name: "
<< matched.at("q_add_y")->arg()->name;

if (norm_before_) {
op_desc.SetInput("LNScale",
{
@@ -546,7 +607,24 @@
op_desc.SetAttr<int>("hidden_dim", hidden_dim);
op_desc.SetAttr<std::string>("act_type", act_type_);
op_desc.SetAttr<bool>("norm_before", norm_before_);

if (relative_emb_type_ == "__xpu__roformer_relative_embedding") {
// q/k share the rotary embedding
op_desc.SetInput("RoformerEmbedding",
{matched.at("q_cos_embedding")->arg()->name,
matched.at("q_sin_embedding")->arg()->name});
op_desc.SetAttr<int>("relative_type", 1);
auto q_relative_op = matched.at("q_relative_emb")->stmt()->op_info();
auto q_cos_emb_shape =
scope->FindMutableTensor(q_relative_op->Input("CosEmbbeding").front())
->dims();
CHECK_GE(q_cos_emb_shape.size(), 2) << q_cos_emb_shape.size();
CHECK_EQ(size_per_head, q_cos_emb_shape[q_cos_emb_shape.size() - 1]);
int max_pos_len = q_cos_emb_shape[q_cos_emb_shape.size() - 2];
VLOG(3) << "relative embedding max sequence len: " << max_pos_len;
op_desc.SetAttr<int>("max_pos_len", max_pos_len);
} else {
op_desc.SetAttr<int>("relative_type", 0);
}
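// Worked example (illustrative values): a CosEmbbeding tensor of dims
// [1, 512, 64] with size_per_head == 64 passes both checks and gives
// max_pos_len = dims[dims.size() - 2] = 512, i.e. the rotary tables cover
// sequences of up to 512 positions.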
auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph");
auto sub_program_desc = std::make_shared<cpp::ProgramDesc>();
sub_program_desc->AddBlock<cpp::BlockDesc>();
@@ -573,6 +651,10 @@
"qkv_ln_2_scale",
"qkv_ln_2_bias",
};
if (relative_emb_type_ == "__xpu__roformer_relative_embedding") {
froms.push_back("q_cos_embedding");
froms.push_back("q_sin_embedding");
}
if (norm_before_) {
froms.push_back("ln_before_scale");
froms.push_back("ln_before_bias");
@@ -599,6 +681,7 @@
std::string mul_type_;
bool with_q_scale_;
bool norm_before_;
const std::string relative_emb_type_;
// quant_info: mul input_max, output_max * 6 + matmul x_max:y_max, output_max
// * 2
void set_quant_info(Scope* scope,
@@ -916,6 +999,15 @@ class XPUMultiEncoderFuser {
arg_map[arg_name].push_back(name);
}
}
if ((i == 0) && (first_encoder_op_info->HasAttr("relative_type")) &&
(first_encoder_op_info->GetAttr<int>("relative_type") == 1)) {
CHECK_EQ(first_encoder_op_info->Input("RoformerEmbedding").size(), 2);
for (auto name : first_encoder_op_info->Input("RoformerEmbedding")) {
auto* arg_node = graph->RetrieveArgument(name);
DirectedLink(arg_node, first_encoder);
arg_map["RoformerEmbedding"].push_back(name);
}
}
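// Only the first encoder's RoformerEmbedding inputs are re-linked to the fused
// op: q and k already share one sin/cos pair per layer, and the pass appears
// to assume every fused layer reads the same tables, so a single pair
// suffices for the whole multi-encoder.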

auto* cur_out =
graph->RetrieveArgument(op_info->Output("Outputs").front());
@@ -950,6 +1042,14 @@
op_desc.SetInput("Mask", {mask_name});
op_desc.SetOutput("Output", {out_name});
op_desc.SetAttr<int>("xpu", 1);
op_desc.SetAttr<int>(
"relative_type",
first_encoder_op_info->GetAttr<int>("relative_type"));
if (first_encoder_op_info->GetAttr<int>("relative_type") == 1 &&
first_encoder_op_info->HasAttr("max_pos_len")) {
op_desc.SetAttr<int>(
"max_pos_len", first_encoder_op_info->GetAttr<int>("max_pos_len"));
}
op_desc.SetAttr<bool>("norm_before", norm_before_0);
op_desc.SetAttr<bool>("enable_int8", enable_int8);
op_desc.SetAttr<bool>("enable_int16", enable_int16);
@@ -1272,6 +1372,8 @@ class XPUMultiEncoderFusePass : public ProgramPass {
std::vector<std::string> mul_types{"mul", "matmul", "matmul_v2"};
std::vector<bool> with_q_scales{true, false};
std::vector<bool> norm_befores{true, false};
std::vector<std::string> relative_embedding_type{
"", "__xpu__roformer_relative_embedding"};

std::string fc_precision;
bool adaptive_seqlen = false;
@@ -1311,18 +1413,21 @@
for (auto& mul_type : mul_types) {
for (auto with_q_scale : with_q_scales) {
for (auto norm_before : norm_befores) {
for (auto relative_type : relative_embedding_type) {
fusion::XPUSingleEncoderFuser single_encoder_fuser(
act_type,
input_pos,
qkv_ln_2_out_pos,
matmul_type,
mul_type,
with_q_scale,
norm_before,
relative_type);
single_encoder_fuser(graph.get());
fusion::XPUMultiEncoderFuser multi_encoder_fuser(
fc_precision, adaptive_seqlen);
multi_encoder_fuser(graph.get());
}
}
}
}
(Diff for the remaining 12 changed files is not shown.)
