diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h
index 6b60009ae41..39c56f38194 100644
--- a/lite/api/paddle_use_passes.h
+++ b/lite/api/paddle_use_passes.h
@@ -87,6 +87,7 @@ USE_MIR_PASS(__xpu__conv2d_affine_channel_fuse_pass);
 USE_MIR_PASS(__xpu__conv2d_fuse_pass);
 USE_MIR_PASS(__xpu__softmax_topk_fuse_pass);
 USE_MIR_PASS(__xpu__multi_encoder_adaptive_seqlen_fuse_pass);
+USE_MIR_PASS(__xpu__roformer_relative_pos_fuse_pass);
 USE_MIR_PASS(__xpu__multi_encoder_slice_link_fuse_pass);
 USE_MIR_PASS(__xpu__generate_sequence_fuse_pass);
 USE_MIR_PASS(__xpu__logit_fuse_pass);
diff --git a/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc b/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
index d3cfaaa5395..69dd0ea80ae 100644
--- a/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
+++ b/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
@@ -50,14 +50,16 @@
                        const std::string& matmul_type = "matmul",
                        const std::string& mul_type = "mul",
                        bool with_q_scale = true,
-                       bool norm_before = false)
+                       bool norm_before = false,
+                       const std::string& relative_type = "")
       : act_type_(act_type),
         input_pos_(input_pos),
         qkv_ln_2_out_pos_(qkv_ln_2_out_pos),
         matmul_type_(matmul_type),
         mul_type_(mul_type),
         with_q_scale_(with_q_scale),
-        norm_before_(norm_before) {}
+        norm_before_(norm_before),
+        relative_emb_type_(relative_type) {}
 
   void BuildPattern() override {
     auto* input = VarNode("input")
@@ -122,13 +124,36 @@
     auto* q_transpose2 = OpNode("q_transpose2", "transpose2")->AsIntermediate();
     auto* q_transpose2_out = VarNode("q_transpose2_out")
                                  ->assert_is_op_output("transpose2", "Out")
-                                 ->assert_is_op_input(target_op_type, "X")
+                                 ->assert_is_op_input(relative_emb_type_.empty()
+                                                          ? target_op_type
+                                                          : relative_emb_type_,
+                                                      "X")
                                  ->AsIntermediate();
     auto* q_transpose2_xshape =
         VarNode("q_transpose2_xshape")
             ->assert_is_op_output("transpose2", "XShape")
             ->AsIntermediate();
-
+    PMNode* q_relative_emb = nullptr;
+    PMNode* q_cos_embedding = nullptr;
+    PMNode* q_sin_embedding = nullptr;
+    PMNode* q_relative_emb_out = nullptr;
+    if (relative_emb_type_ == "__xpu__roformer_relative_embedding") {
+      VLOG(3) << "build q_relative_emb";
+      q_relative_emb =
+          OpNode("q_relative_emb", relative_emb_type_)->AsIntermediate();
+      q_sin_embedding =
+          VarNode("q_sin_embedding")
+              ->assert_is_op_input(relative_emb_type_, "SinEmbbeding")
+              ->AsInput();
+      q_cos_embedding =
+          VarNode("q_cos_embedding")
+              ->assert_is_op_input(relative_emb_type_, "CosEmbbeding")
+              ->AsInput();
+      q_relative_emb_out = VarNode("q_relative_emb_out")
+                               ->assert_is_op_output(relative_emb_type_, "Out")
+                               ->assert_is_op_input(target_op_type, "X")
+                               ->AsIntermediate();
+    }
     PMNode* q_scale = nullptr;
     PMNode* q_scale_out = nullptr;
     if (with_q_scale_) {
@@ -165,8 +190,23 @@
     auto* k_transpose2 = OpNode("k_transpose2", "transpose2")->AsIntermediate();
     auto* k_transpose2_out = VarNode("k_transpose2_out")
                                  ->assert_is_op_output("transpose2", "Out")
-                                 ->assert_is_op_input(matmul_type_, "Y")
                                  ->AsIntermediate();
+    PMNode* k_relative_emb = nullptr;
+    PMNode* k_sin_embedding = q_sin_embedding;
+    PMNode* k_cos_embedding = q_cos_embedding;
+    PMNode* k_relative_emb_out = nullptr;
+    if (relative_emb_type_.empty()) {
+      k_transpose2_out->assert_is_op_input(matmul_type_, "Y");
+    } else if (relative_emb_type_ == "__xpu__roformer_relative_embedding") {
+      VLOG(3) << "build k_relative_emb";
+      k_transpose2_out->assert_is_op_input(relative_emb_type_, "X");
+      k_relative_emb =
+          OpNode("k_relative_emb", relative_emb_type_)->AsIntermediate();
+      k_relative_emb_out = VarNode("k_relative_emb_out")
+                               ->assert_is_op_output(relative_emb_type_, "Out")
+                               ->assert_is_op_input(matmul_type_, "Y")
+                               ->AsIntermediate();
+    }
     auto* k_transpose2_xshape =
         VarNode("k_transpose2_xshape")
             ->assert_is_op_output("transpose2", "XShape")
@@ -377,14 +417,23 @@
     } else {
       *input >> *q_mul;
     }
+    *q_mul >> *q_mul_out >> *q_add >> *q_add_out >> *q_reshape2 >>
+        *q_reshape2_out >> *q_transpose2 >> *q_transpose2_out;
+    PMNode* last_node = q_transpose2_out;
+    if (relative_emb_type_ == "__xpu__roformer_relative_embedding") {
+      VLOG(3) << "build q_relative_emb link";
+      *last_node >> *q_relative_emb >> *q_relative_emb_out;
+      *q_sin_embedding >> *q_relative_emb;
+      *q_cos_embedding >> *q_relative_emb;
+      last_node = q_relative_emb_out;
+    }
     if (with_q_scale_) {
-      *q_mul >> *q_mul_out >> *q_add >> *q_add_out >> *q_reshape2 >>
-          *q_reshape2_out >> *q_transpose2 >> *q_transpose2_out >> *q_scale >>
-          *q_scale_out >> *qk_matmul;
-    } else {
-      *q_mul >> *q_mul_out >> *q_add >> *q_add_out >> *q_reshape2 >>
-          *q_reshape2_out >> *q_transpose2 >> *q_transpose2_out >> *qk_matmul;
+      *last_node >> *q_scale >> *q_scale_out;
+      last_node = q_scale_out;
     }
+    *last_node >> *qk_matmul;
+    last_node = nullptr;
+
     *q_mul_y >> *q_mul;
     *q_add_y >> *q_add;
     *q_reshape2 >> *q_reshape2_xshape;
@@ -396,7 +445,16 @@
       *input >> *k_mul;
     }
     *k_mul >> *k_mul_out >> *k_add >> *k_add_out >> *k_reshape2 >>
-        *k_reshape2_out >> *k_transpose2 >> *k_transpose2_out >> *qk_matmul;
+        *k_reshape2_out >> *k_transpose2 >> *k_transpose2_out;
+    last_node = k_transpose2_out;
+    if (relative_emb_type_ == "__xpu__roformer_relative_embedding") {
+      VLOG(3) << "build k_relative_emb link";
+      *last_node >> *k_relative_emb >> *k_relative_emb_out;
+      *k_sin_embedding >> *k_relative_emb;
+      *k_cos_embedding >> *k_relative_emb;
+      last_node = k_relative_emb_out;
+    }
+    *last_node >> *qk_matmul;
 
     *k_mul_y >> *k_mul;
     *k_add_y >> *k_add;
"__xpu__roformer_relative_embedding") { + VLOG(3) << "build k_relative_emb link"; + *last_node >> *k_relative_emb >> *k_relative_emb_out; + *k_sin_embedding >> *k_relative_emb; + *k_cos_embedding >> *k_relative_emb; + last_node = k_relative_emb_out; + } + *last_node >> *qk_matmul; *k_mul_y >> *k_mul; *k_add_y >> *k_add; @@ -476,6 +534,9 @@ class XPUSingleEncoderFuser : public FuseBase { matched.at("qkv_add_3_y")->arg()->name, matched.at("qkv_add_4_y")->arg()->name, }); + VLOG(3) << "matched.at(q_add_y)->arg()->name: " + << matched.at("q_add_y")->arg()->name; + if (norm_before_) { op_desc.SetInput("LNScale", { @@ -546,7 +607,24 @@ class XPUSingleEncoderFuser : public FuseBase { op_desc.SetAttr("hidden_dim", hidden_dim); op_desc.SetAttr("act_type", act_type_); op_desc.SetAttr("norm_before", norm_before_); - + if (relative_emb_type_ == "__xpu__roformer_relative_embedding") { + // q/k share the rotary embedding + op_desc.SetInput("RoformerEmbedding", + {matched.at("q_cos_embedding")->arg()->name, + matched.at("q_sin_embedding")->arg()->name}); + op_desc.SetAttr("relative_type", 1); + auto q_relative_op = matched.at("q_relative_emb")->stmt()->op_info(); + auto q_cos_emb_shape = + scope->FindMutableTensor(q_relative_op->Input("CosEmbbeding").front()) + ->dims(); + CHECK_GE(q_cos_emb_shape.size(), 2) << q_cos_emb_shape.size(); + CHECK_EQ(size_per_head, q_cos_emb_shape[q_cos_emb_shape.size() - 1]); + int max_pos_len = q_cos_emb_shape[q_cos_emb_shape.size() - 2]; + VLOG(3) << "relative embedding max sequence len: " << max_pos_len; + op_desc.SetAttr("max_pos_len", max_pos_len); + } else { + op_desc.SetAttr("relative_type", 0); + } auto fake_subgraph_op = LiteOpRegistry::Global().Create("subgraph"); auto sub_program_desc = std::make_shared(); sub_program_desc->AddBlock(); @@ -573,6 +651,10 @@ class XPUSingleEncoderFuser : public FuseBase { "qkv_ln_2_scale", "qkv_ln_2_bias", }; + if (relative_emb_type_ == "__xpu__roformer_relative_embedding") { + froms.push_back("q_cos_embedding"); + froms.push_back("q_sin_embedding"); + } if (norm_before_) { froms.push_back("ln_before_scale"); froms.push_back("ln_before_bias"); @@ -599,6 +681,7 @@ class XPUSingleEncoderFuser : public FuseBase { std::string mul_type_; bool with_q_scale_; bool norm_before_; + const std::string relative_emb_type_; // quant_info: mul input_max, output_max * 6 + matmul x_max:y_max, output_max // * 2 void set_quant_info(Scope* scope, @@ -916,6 +999,15 @@ class XPUMultiEncoderFuser { arg_map[arg_name].push_back(name); } } + if ((i == 0) && (first_encoder_op_info->HasAttr("relative_type")) && + (first_encoder_op_info->GetAttr("relative_type") == 1)) { + CHECK_EQ(first_encoder_op_info->Input("RoformerEmbedding").size(), 2); + for (auto name : first_encoder_op_info->Input("RoformerEmbedding")) { + auto* arg_node = graph->RetrieveArgument(name); + DirectedLink(arg_node, first_encoder); + arg_map["RoformerEmbedding"].push_back(name); + } + } auto* cur_out = graph->RetrieveArgument(op_info->Output("Outputs").front()); @@ -950,6 +1042,14 @@ class XPUMultiEncoderFuser { op_desc.SetInput("Mask", {mask_name}); op_desc.SetOutput("Output", {out_name}); op_desc.SetAttr("xpu", 1); + op_desc.SetAttr( + "relative_type", + first_encoder_op_info->GetAttr("relative_type")); + if (first_encoder_op_info->GetAttr("relative_type") == 1 && + first_encoder_op_info->HasAttr("max_pos_len")) { + op_desc.SetAttr( + "max_pos_len", first_encoder_op_info->GetAttr("max_pos_len")); + } op_desc.SetAttr("norm_before", norm_before_0); op_desc.SetAttr("enable_int8", 
     op_desc.SetAttr("enable_int8", enable_int8);
     op_desc.SetAttr("enable_int16", enable_int16);
@@ -1272,6 +1372,8 @@
     std::vector<std::string> mul_types{"mul", "matmul", "matmul_v2"};
     std::vector<bool> with_q_scales{true, false};
     std::vector<bool> norm_befores{true, false};
+    std::vector<std::string> relative_embedding_type{
+        "", "__xpu__roformer_relative_embedding"};
 
     std::string fc_precision;
     bool adaptive_seqlen = false;
@@ -1311,18 +1413,21 @@
       for (auto& mul_type : mul_types) {
         for (auto with_q_scale : with_q_scales) {
           for (auto norm_before : norm_befores) {
-            fusion::XPUSingleEncoderFuser single_encoder_fuser(
-                act_type,
-                input_pos,
-                qkv_ln_2_out_pos,
-                matmul_type,
-                mul_type,
-                with_q_scale,
-                norm_before);
-            single_encoder_fuser(graph.get());
-            fusion::XPUMultiEncoderFuser multi_encoder_fuser(
-                fc_precision, adaptive_seqlen);
-            multi_encoder_fuser(graph.get());
+            for (auto relative_type : relative_embedding_type) {
+              fusion::XPUSingleEncoderFuser single_encoder_fuser(
+                  act_type,
+                  input_pos,
+                  qkv_ln_2_out_pos,
+                  matmul_type,
+                  mul_type,
+                  with_q_scale,
+                  norm_before,
+                  relative_type);
+              single_encoder_fuser(graph.get());
+              fusion::XPUMultiEncoderFuser multi_encoder_fuser(
+                  fc_precision, adaptive_seqlen);
+              multi_encoder_fuser(graph.get());
+            }
           }
         }
       }
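
For orientation before the new pass below: the subgraph it matches (and the fused __xpu__roformer_relative_embedding op that replaces it) computes the standard "rotate-half" rotary position embedding from RoFormer. A minimal scalar sketch of that semantics for one position, assuming an even head_dim and one precomputed cos/sin row per position (a hypothetical reference helper, not part of the patch):

#include <cstddef>
#include <vector>

// out = x * cos + rotate_half(x) * sin, where rotate_half([x1, x2]) = [-x2, x1].
// This mirrors the matched split -> scale(-1) -> concat -> mul/mul -> add chain.
std::vector<float> RopeReference(const std::vector<float>& x,
                                 const std::vector<float>& cos_row,
                                 const std::vector<float>& sin_row) {
  const std::size_t d = x.size();  // head_dim, assumed even
  std::vector<float> rot(d), out(d);
  for (std::size_t i = 0; i < d / 2; ++i) {
    rot[i] = -x[i + d / 2];  // the scale branch negates the second half
    rot[i + d / 2] = x[i];   // concat places the first half after it
  }
  for (std::size_t i = 0; i < d; ++i) {
    out[i] = x[i] * cos_row[i] + rot[i] * sin_row[i];
  }
  return out;
}
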
diff --git a/lite/core/optimizer/mir/fusion/__xpu__roformer_relative_pos_fuse_pass.cc b/lite/core/optimizer/mir/fusion/__xpu__roformer_relative_pos_fuse_pass.cc
new file mode 100644
index 00000000000..218a8496d4e
--- /dev/null
+++ b/lite/core/optimizer/mir/fusion/__xpu__roformer_relative_pos_fuse_pass.cc
@@ -0,0 +1,195 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <string>
+#include "lite/backends/xpu/math.h"
+#include "lite/core/optimizer/mir/pass_registry.h"
+#include "lite/core/optimizer/mir/pattern_matcher_high_api.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+/* support xpu roformer relative pos                      */
+/*                    in_Input --------------------       */
+/*                    |        \                   |      */
+/*                    |         \                  |      */
+/*                  split      shape               |      */
+/*                 /  |           \                |      */
+/*                /   |            \               |      */
+/*               |  scale         slice            |      */
+/*                \   |           /   \            |      */
+/*                 \  |          /     \           |      */
+/*                 concat    slice    slice        |      */
+/*                    |         /        \         |      */
+/*                    |        /          \        |      */
+/*                 elementwise_mul      elementwise_mul   */
+/*                         |              /               */
+/*                         |             /                */
+/*                      elementwise_add                   */
+/*                            |                           */
+/*                            |                           */
+/*                        out_Output                      */
+/*--------------------------------------------------------*/
+/* After the pass is applied:                             */
+/*                     in_Input                           */
+/*          cos_emb       |       sin_emb                 */
+/*                 \      |      /                        */
+/*               xpu_roformer_relative                    */
+/*                        |                               */
+/*                        |                               */
+/*                    out_Output                          */
+/*--------------------------------------------------------*/
+
+class XPURoformerRelativePosFuser : public FuseBase {
+ public:
+  void BuildPattern() override {
+    auto* input = VarNode("input")
+                      ->assert_is_op_input("split", "X")
+                      ->assert_is_op_input("elementwise_mul", "X")
+                      ->assert_is_op_input("shape", "Input")
+                      ->AsInput();
+    auto* split =
+        OpNode("split", "split")
+            ->assert_op_attr<int>("axis", 3)
+            ->assert_op_attr<int>("num", 2)  // TODO: confirm this is required
+            ->AsIntermediate();
+    auto* split_out0 = VarNode("split_out0")
+                           ->assert_is_op_nth_input("concat", "X", 1)
+                           ->assert_is_op_nth_output("split", "Out", 0)
+                           ->AsIntermediate();
+    auto* split_out1 = VarNode("split_out1")
+                           ->assert_is_op_input("scale", "X")
+                           ->assert_is_op_nth_output("split", "Out", 1)
+                           ->AsIntermediate();
+    auto* scale =
+        OpNode("scale", "scale")
+            ->assert_op_attr_satisfied<float>(
+                "scale",
+                [](float attr) { return (std::fabs(attr + 1.0) < 1e-5); })
+            ->AsIntermediate();
+    auto* scale_out = VarNode("scale_out")
+                          ->assert_is_op_input("concat", "X")
+                          ->assert_is_op_output("scale", "Out")
+                          ->AsIntermediate();
+    auto* concat = OpNode("concat", "concat")->AsIntermediate();
+    auto* concat_out = VarNode("concat_out")
+                           ->assert_is_op_input("elementwise_mul", "X")
+                           ->assert_is_op_output("concat", "Out")
+                           ->AsIntermediate();
+    auto* shape = OpNode("shape", "shape")->AsIntermediate();
+    auto* shape_out = VarNode("shape_out")
+                          ->assert_is_op_input("slice", "Input")
+                          ->assert_is_op_output("shape", "Out")
+                          ->AsIntermediate();
+    auto* slice1 = OpNode("slice1", "slice")->AsIntermediate();
+    auto* slice1_out = VarNode("slice1_out")
+                           ->assert_is_op_input("slice", "EndsTensorList")
+                           ->assert_is_op_output("slice", "Out")
+                           ->AsIntermediate();
+    auto* sin_emb =
+        VarNode("sin_emb")->assert_is_op_input("slice", "Input")->AsInput();
+    auto* cos_emb =
+        VarNode("cos_emb")->assert_is_op_input("slice", "Input")->AsInput();
+    auto* slice_sin = OpNode("slice_sin", "slice")->AsIntermediate();
+    auto* slice_sin_out = VarNode("slice_sin_out")
+                              ->assert_is_op_input("elementwise_mul", "Y")
+                              ->assert_is_op_output("slice", "Out")
+                              ->AsIntermediate();
+    auto* ew_mul_sin =
+        OpNode("ew_mul_sin", "elementwise_mul")->AsIntermediate();
+    auto* ew_mul_sin_out = VarNode("ew_mul_sin_out")
+                               ->assert_is_op_input("elementwise_add", "Y")
+                               ->assert_is_op_output("elementwise_mul", "Out")
+                               ->AsIntermediate();
+    auto* ew_add = OpNode("ew_add", "elementwise_add")->AsIntermediate();
+    auto* ew_add_out = VarNode("ew_add_out")
+                           ->assert_is_op_output("elementwise_add", "Out")
+                           ->AsOutput();
+    auto* slice_cos = OpNode("slice_cos", "slice")->AsIntermediate();
+    auto* slice_cos_out = VarNode("slice_cos_out")
+                              ->assert_is_op_input("elementwise_mul", "Y")
+                              ->assert_is_op_output("slice", "Out")
+                              ->AsIntermediate();
+    auto* ew_mul_cos =
+        OpNode("ew_mul_cos", "elementwise_mul")->AsIntermediate();
+    auto* ew_mul_cos_out = VarNode("ew_mul_cos_out")
+                               ->assert_is_op_input("elementwise_add", "X")
+                               ->assert_is_op_output("elementwise_mul", "Out")
+                               ->AsIntermediate();
+
+    *input >> *split >> *split_out1 >> *scale >> *scale_out >> *concat >>
+        *concat_out >> *ew_mul_sin >> *ew_mul_sin_out >> *ew_add >> *ew_add_out;
+    *input >> *ew_mul_cos >> *ew_mul_cos_out >> *ew_add;
+    *input >> *shape >> *shape_out >> *slice1 >> *slice1_out >> *slice_sin >>
+        *slice_sin_out >> *ew_mul_sin;
+    *slice1_out >> *slice_cos >> *slice_cos_out >> *ew_mul_cos;
+    *sin_emb >> *slice_sin;
+    *cos_emb >> *slice_cos;
+    *split >> *split_out0 >> *concat;
+  }
"Y") + ->assert_is_op_output("slice", "Out") + ->AsIntermediate(); + auto* ew_mul_cos = + OpNode("ew_mul_cos", "elementwise_mul")->AsIntermediate(); + auto* ew_mul_cos_out = VarNode("ew_mul_cos_out") + ->assert_is_op_input("elementwise_add", "X") + ->assert_is_op_output("elementwise_mul", "Out") + ->AsIntermediate(); + *input >> *split >> *split_out1 >> *scale >> *scale_out >> *concat >> + *concat_out >> *ew_mul_sin >> *ew_mul_sin_out >> *ew_add >> *ew_add_out; + *input >> *ew_mul_cos >> *ew_mul_cos_out >> *ew_add; + *input >> *shape >> *shape_out >> *slice1 >> *slice1_out >> *slice_sin >> + *slice_sin_out >> *ew_mul_sin; + *slice1_out >> *slice_cos >> *slice_cos_out >> *ew_mul_cos; + *sin_emb >> *slice_sin; + *cos_emb >> *slice_cos; + *split >> *split_out0 >> *concat; + } + + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + cpp::OpDesc op_desc; + op_desc.SetType("__xpu__roformer_relative_embedding"); + // use "X", be consistent with target_op_type_ in multiencoder pass + op_desc.SetInput("X", {matched.at("input")->arg()->name}); + op_desc.SetInput("CosEmbbeding", {matched.at("cos_emb")->arg()->name}); + op_desc.SetInput("SinEmbbeding", {matched.at("sin_emb")->arg()->name}); + op_desc.SetOutput("Out", {matched.at("ew_add_out")->arg()->name}); + auto* scope = matched.at("split")->stmt()->op()->scope(); + + auto cos_emb_name = matched.at("cos_emb")->arg()->name; + auto cos_emb_shape = scope->FindMutableTensor(cos_emb_name)->dims(); + auto sin_emb_name = matched.at("sin_emb")->arg()->name; + auto sin_emb_shape = scope->FindMutableTensor(sin_emb_name)->dims(); + CHECK_EQ(cos_emb_shape.size(), 4) << cos_emb_shape.size(); + CHECK_GT(cos_emb_shape[2], 0) << cos_emb_shape[2]; + CHECK_EQ(sin_emb_shape.size(), 4) << sin_emb_shape.size(); + for (int i = 0; i < sin_emb_shape.size(); ++i) { + CHECK_EQ(sin_emb_shape[i], cos_emb_shape[i]) + << i << " th dim: " << sin_emb_shape[i] << ", " << cos_emb_shape[i]; + } + op_desc.SetAttr("max_pos_len", cos_emb_shape[2]); + + auto& valid_places = matched.at("split")->stmt()->op()->valid_places(); + auto new_op = LiteOpRegistry::Global().Create(op_desc.Type()); + new_op->Attach(op_desc, scope); + auto* new_op_node = graph->GraphCreateInstructNode(new_op, valid_places); + + DirectedLink(matched.at("input"), new_op_node); + DirectedLink(matched.at("cos_emb"), new_op_node); + DirectedLink(matched.at("sin_emb"), new_op_node); + DirectedLink(new_op_node, matched.at("ew_add_out")); + } +}; + +} // namespace fusion + +class XPURoformerRelativePosFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override { + fusion::XPURoformerRelativePosFuser fuser; + fuser(graph.get()); + } +}; + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(__xpu__roformer_relative_pos_fuse_pass, + paddle::lite::mir::XPURoformerRelativePosFusePass) + .BindTargets({TARGET(kXPU)}); diff --git a/lite/core/optimizer/optimizer.cc b/lite/core/optimizer/optimizer.cc index 60f37c11585..65bd23cf167 100644 --- a/lite/core/optimizer/optimizer.cc +++ b/lite/core/optimizer/optimizer.cc @@ -199,6 +199,7 @@ std::unique_ptr RunDefaultOptimizer( "__xpu__squeeze_excitation_fuse_pass", "__xpu__mmdnn_fuse_pass", "__xpu__bigru_fuse_pass", + "__xpu__roformer_relative_pos_fuse_pass", "__xpu__multi_encoder_fuse_pass", "__xpu__embedding_with_eltwise_add_fuse_pass", "__xpu__fc_fuse_pass", diff --git a/lite/kernels/xpu/CMakeLists.txt b/lite/kernels/xpu/CMakeLists.txt index 61aa1871db3..5e746ed5776 100644 --- 
diff --git a/lite/kernels/xpu/CMakeLists.txt b/lite/kernels/xpu/CMakeLists.txt
index 61aa1871db3..5e746ed5776 100644
--- a/lite/kernels/xpu/CMakeLists.txt
+++ b/lite/kernels/xpu/CMakeLists.txt
@@ -117,6 +117,7 @@ add_kernel(__xpu__conv2d_compute_xpu XPU extra SRCS __xpu__conv2d_compute.cc)
 add_kernel(__xpu__softmax_topk_compute_xpu XPU extra SRCS __xpu__softmax_topk_compute.cc)
 add_kernel(__xpu__generate_sequence_compute_xpu XPU extra SRCS __xpu__generate_sequence_compute.cc)
 add_kernel(__xpu__logit_compute_xpu XPU extra SRCS __xpu__logit_compute.cc)
+add_kernel(__xpu__roformer_relative_embedding_compute_xpu XPU extra SRCS __xpu__roformer_relative_embedding_compute.cc)
 add_kernel(__xpu__squeeze_excitation_compute_xpu XPU extra SRCS __xpu__squeeze_excitation_compute.cc)
 add_kernel(__xpu__bigru_compute_xpu XPU extra SRCS __xpu__bigru_compute.cc)
 add_kernel(__xpu__dynamic_lstm_compute_xpu XPU extra SRCS __xpu__dynamic_lstm_compute.cc)
diff --git a/lite/kernels/xpu/__xpu__multi_encoder_compute.cc b/lite/kernels/xpu/__xpu__multi_encoder_compute.cc
index ef3c446afa0..4eef89b09bd 100644
--- a/lite/kernels/xpu/__xpu__multi_encoder_compute.cc
+++ b/lite/kernels/xpu/__xpu__multi_encoder_compute.cc
@@ -155,6 +155,13 @@ void XPUMultiEncoderCompute::PrepareForRun() {
   for (auto* ln_bias : param.ln_bias) {
     arg_ln_bias_.push_back(ln_bias->data<float>());
   }
+  relative_type_ = param.relative_type;
+  // prepare roformer embedding
+  if (relative_type_ == 1) {
+    for (auto* emb : param.roformer_embedding) {
+      roformer_embedding_.push_back(emb->data<float>());
+    }
+  }
   // prepare weights
   local_quant_ =
       GetBoolFromEnv("XPU_LOCAL_QUANT") || lite::TargetWrapperXPU::local_quant;
@@ -226,6 +233,12 @@ void XPUMultiEncoderCompute::run_encoder(const T* in, T* out) {
         param.norm_before, /*is_pre_norm*/
         param.per_channel);
     qkv_attn_param.quant_type_.assign(quant_types_.begin(), quant_types_.end());
+    if (relative_type_ == 1) {
+      qkv_attn_param.relative_type = relative_type_;
+      qkv_attn_param.max_pos_len = param.max_pos_len;
+      qkv_attn_param.relative_pos.assign(roformer_embedding_.begin(),
+                                         roformer_embedding_.end());
+    }
 
     if (std::is_same<T, int8_t>::value) {
       CHECK_GT(fc_input_max_.size(), 0);
@@ -259,8 +272,15 @@ void XPUMultiEncoderCompute::run_encoder(const T* in, T* out) {
         slice_idx,
         true,
         param.hidden_dim,
-        param.norm_before);
+        param.norm_before,
+        param.per_channel);
     qkv_attn_param.quant_type_.assign(quant_types_.begin(), quant_types_.end());
+    if (relative_type_ == 1) {
+      qkv_attn_param.relative_type = relative_type_;
+      qkv_attn_param.max_pos_len = param.max_pos_len;
+      qkv_attn_param.relative_pos.assign(roformer_embedding_.begin(),
+                                         roformer_embedding_.end());
+    }
     int r = xdnn::transformer_encoder<T, TW, TGEMM>(
         ctx.GetRawContext(),
         in,
@@ -367,6 +387,7 @@ REGISTER_LITE_KERNEL(__xpu__multi_encoder,
     .BindInput("FCBias", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindInput("LNScale", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindInput("LNBias", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindInput("RoformerEmbedding", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindInput("Mask", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))})
     .Finalize();
diff --git a/lite/kernels/xpu/__xpu__multi_encoder_compute.h b/lite/kernels/xpu/__xpu__multi_encoder_compute.h
index 886094cf678..9a4b070bf86 100644
--- a/lite/kernels/xpu/__xpu__multi_encoder_compute.h
+++ b/lite/kernels/xpu/__xpu__multi_encoder_compute.h
@@ -43,6 +43,7 @@ class XPUMultiEncoderCompute
   std::vector<const float*> arg_ln_bias_;
   std::vector<float> fc_weight_max_;
   std::vector<float> fc_input_max_;
+  std::vector<const float*> roformer_embedding_;
   std::vector<std::string> quant_types_;
   XPUScratchPadGuard weight_max_guard_;
   XPUScratchPadGuard input_max_guard_;
@@ -50,6 +51,7 @@ class XPUMultiEncoderCompute
   XPUScratchPadGuard cast_out_guard_;
   xdnn::Activation_t qkv_act = xdnn::Activation_t::RELU;
   int slice_idx = -1;
+  int relative_type_ = 0;
   bool local_quant_ = false;
 
   template <typename T>
diff --git a/lite/kernels/xpu/__xpu__roformer_relative_embedding_compute.cc b/lite/kernels/xpu/__xpu__roformer_relative_embedding_compute.cc
new file mode 100644
index 00000000000..d98ebfb4c6e
--- /dev/null
+++ b/lite/kernels/xpu/__xpu__roformer_relative_embedding_compute.cc
@@ -0,0 +1,73 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/xpu/__xpu__roformer_relative_embedding_compute.h"
+#include <vector>
+#include "lite/backends/xpu/xpu_header_sitter.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+void RoformerRelativeEmbeddingCompute::Run() {
+  auto& param = this->template Param<param_t>();
+  auto& ctx = this->ctx_->template As<XPUContext>();
+  auto input_dim = param.input->dims();
+  CHECK_EQ(input_dim.size(), 4);
+  int batch = input_dim[0];
+  int head_num = param.input->dims()[1];
+  int seqlen = param.input->dims()[2];
+  int head_dim = param.input->dims()[3];
+  CHECK_LE(seqlen, param.max_pos_len);
+  std::vector<int> lod;
+  lod.resize(batch + 1);
+  for (int i = 0; i < batch + 1; i++) {
+    lod[i] = i * seqlen;
+  }
+  int r =
+      xdnn::rope<float>(ctx.GetRawContext(),
+                        param.input->data<float>(),
+                        param.output->mutable_data<float>(TARGET(kXPU)),
+                        param.cos_embedding->data<float>(),
+                        param.sin_embedding->data<float>(),
+                        batch,
+                        head_num,
+                        head_dim,
+                        head_num * head_dim,
+                        lod,
+                        param.max_pos_len,
+                        false,  // no vsl
+                        true);  // transpose to [n, seql, head_num, head_dim]
+  CHECK_EQ(r, 0) << "call RoformerRelativeEmbeddingCompute failed";
+}
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(
+    __xpu__roformer_relative_embedding,
+    kXPU,
+    kFloat,
+    kNCHW,
+    paddle::lite::kernels::xpu::RoformerRelativeEmbeddingCompute,
+    def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindInput("CosEmbbeding", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindInput("SinEmbbeding", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .Finalize();
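
The kernel above only reads the CosEmbbeding/SinEmbbeding tables; nothing in this patch builds them. For reference, a hedged sketch of how rotate-half tables are conventionally generated for RoFormer-style models; the exact layout a given model exports is an assumption, since the patch itself only requires rank-4 tensors whose dims 2 and 3 are max_pos_len and head_dim:

#include <cmath>
#include <vector>

// Fills cos/sin tables of logical shape [max_pos_len, head_dim]. Both halves
// of a row share the same frequency, matching the rotate-half convention.
void BuildRopeTables(int max_pos_len,
                     int head_dim,
                     std::vector<float>* cos_tab,
                     std::vector<float>* sin_tab) {
  const int half = head_dim / 2;
  cos_tab->assign(static_cast<size_t>(max_pos_len) * head_dim, 0.f);
  sin_tab->assign(static_cast<size_t>(max_pos_len) * head_dim, 0.f);
  for (int p = 0; p < max_pos_len; ++p) {
    for (int i = 0; i < half; ++i) {
      // theta_i = 10000^(-2i / head_dim)
      const float theta = std::pow(10000.0f, -2.0f * i / head_dim);
      const float c = std::cos(p * theta);
      const float s = std::sin(p * theta);
      (*cos_tab)[p * head_dim + i] = c;
      (*cos_tab)[p * head_dim + i + half] = c;
      (*sin_tab)[p * head_dim + i] = s;
      (*sin_tab)[p * head_dim + i + half] = s;
    }
  }
}
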
diff --git a/lite/kernels/xpu/__xpu__roformer_relative_embedding_compute.h b/lite/kernels/xpu/__xpu__roformer_relative_embedding_compute.h
new file mode 100644
index 00000000000..f2894605171
--- /dev/null
+++ b/lite/kernels/xpu/__xpu__roformer_relative_embedding_compute.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "lite/core/kernel.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+class RoformerRelativeEmbeddingCompute
+    : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::XPURoformerRelativeEmbeddingParam;
+
+  virtual void Run();
+
+  virtual ~RoformerRelativeEmbeddingCompute() = default;
+};
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index b40c63c1328..c0ec43f334f 100755
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -230,6 +230,7 @@ add_operator(__xpu__softmax_topk_op extra SRCS __xpu__softmax_topk_op.cc)
 add_operator(__xpu__multi_encoder_op extra SRCS __xpu__multi_encoder_op.cc)
 add_operator(__xpu__embedding_with_eltwise_add_op extra SRCS __xpu__embedding_with_eltwise_add_op.cc)
 add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc)
+add_operator(__xpu__roformer_relative_embedding_op extra SRCS __xpu__roformer_relative_embedding_op.cc)
 add_operator(__xpu__search_attention_op extra SRCS __xpu__search_attention_op.cc)
 add_operator(__xpu__mmdnn_op extra SRCS __xpu__mmdnn_op.cc)
 add_operator(__xpu__conv2d_op extra SRCS __xpu__conv2d_op.cc)
diff --git a/lite/operators/__xpu__multi_encoder_op.cc b/lite/operators/__xpu__multi_encoder_op.cc
index 019a5e0632d..edce2fd946f 100644
--- a/lite/operators/__xpu__multi_encoder_op.cc
+++ b/lite/operators/__xpu__multi_encoder_op.cc
@@ -71,35 +71,40 @@ bool XPUMultiEncoderOp::InferShapeImpl() const {
 
 bool XPUMultiEncoderOp::AttachImpl(const cpp::OpDesc& op_desc,
                                    lite::Scope* scope) {
-  param_.input = const_cast<lite::Tensor*>(
-      &scope->FindVar(op_desc.Input("Input").front())->Get<lite::Tensor>());
-  param_.output = scope->FindVar(op_desc.Output("Output").front())
-                      ->GetMutable<lite::Tensor>();
+  param_.input =
+      scope->FindVar(op_desc.Input("Input").front())->GetMutable<lite::Tensor>();
+  param_.output =
+      scope->FindVar(op_desc.Output("Output").front())->GetMutable<lite::Tensor>();
+  param_.relative_type = op_desc.GetAttr<int>("relative_type");
   param_.fc_weight.clear();
   for (auto& name : op_desc.Input("FCWeight")) {
-    auto t =
-        const_cast<lite::Tensor*>(&scope->FindVar(name)->Get<lite::Tensor>());
+    auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
     param_.fc_weight.push_back(t);
   }
   param_.fc_bias.clear();
   for (auto& name : op_desc.Input("FCBias")) {
-    auto t =
-        const_cast<lite::Tensor*>(&scope->FindVar(name)->Get<lite::Tensor>());
+    auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
     param_.fc_bias.push_back(t);
   }
   param_.ln_scale.clear();
   for (auto& name : op_desc.Input("LNScale")) {
-    auto t =
-        const_cast<lite::Tensor*>(&scope->FindVar(name)->Get<lite::Tensor>());
+    auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
     param_.ln_scale.push_back(t);
   }
   param_.ln_bias.clear();
   for (auto& name : op_desc.Input("LNBias")) {
-    auto t =
-        const_cast<lite::Tensor*>(&scope->FindVar(name)->Get<lite::Tensor>());
+    auto t = scope->FindVar(name)->GetMutable<lite::Tensor>();
     param_.ln_bias.push_back(t);
   }
+  param_.roformer_embedding.clear();
+  if (param_.relative_type == 1) {
+    param_.max_pos_len = op_desc.GetAttr<int>("max_pos_len");
+    for (auto& name : op_desc.Input("RoformerEmbedding")) {
+      auto t = scope->FindMutableTensor(name);
+      param_.roformer_embedding.push_back(t);
+    }
+  }
 
   std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
   if (std::find(input_arg_names.begin(), input_arg_names.end(), "SeqLod") !=
@@ -108,7 +113,7 @@ bool XPUMultiEncoderOp::AttachImpl(const cpp::OpDesc& op_desc,
     if (arguments.size() > 0) {
       auto arg_var = scope->FindVar(arguments.front());
       if (arg_var != nullptr) {
-        param_.SeqLod = &(arg_var->Get<Tensor>());
+        param_.SeqLod = &(arg_var->Get<lite::Tensor>());
       }
     }
   }
@@ -118,7 +123,7 @@ bool XPUMultiEncoderOp::AttachImpl(const cpp::OpDesc& op_desc,
     if (arguments.size() > 0) {
       auto arg_var = scope->FindVar(arguments.front());
       if (arg_var != nullptr) {
-        param_.PadSeqLen = &(arg_var->Get<Tensor>());
+        param_.PadSeqLen = &(arg_var->Get<lite::Tensor>());
       }
     }
   }
@@ -128,7 +133,7 @@ bool XPUMultiEncoderOp::AttachImpl(const cpp::OpDesc& op_desc,
     if (arguments.size() > 0) {
       auto arg_var = scope->FindVar(arguments.front());
       if (arg_var != nullptr) {
-        param_.mask = &(arg_var->Get<Tensor>());
+        param_.mask = &(arg_var->Get<lite::Tensor>());
       }
     }
   }
diff --git a/lite/operators/__xpu__roformer_relative_embedding_op.cc b/lite/operators/__xpu__roformer_relative_embedding_op.cc
new file mode 100644
index 00000000000..2d656891baf
--- /dev/null
+++ b/lite/operators/__xpu__roformer_relative_embedding_op.cc
@@ -0,0 +1,74 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "lite/operators/__xpu__roformer_relative_embedding_op.h" +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool XPURoformerRelativeEmbeddingOp::CheckShape() const { + CHECK_OR_FALSE(param_.input); + CHECK_OR_FALSE(param_.output); + CHECK_OR_FALSE(param_.cos_embedding); + CHECK_OR_FALSE(param_.sin_embedding); + + const auto input_dims = param_.input->dims(); + const auto cos_emb_dims = param_.cos_embedding->dims(); + const auto sin_emb_dims = param_.sin_embedding->dims(); + CHECK_EQ_OR_FALSE(input_dims.size(), 4UL); + CHECK_EQ_OR_FALSE(cos_emb_dims.size(), 4UL); + CHECK_EQ_OR_FALSE(sin_emb_dims.size(), 4UL); + for (int i = 0; i < cos_emb_dims.size(); ++i) { + CHECK_EQ(cos_emb_dims[i], sin_emb_dims[i]) << i << " dim embedding unmatch " + << cos_emb_dims[i] << ", " + << sin_emb_dims[i]; + } + CHECK_EQ(input_dims[3], cos_emb_dims[3]) << input_dims[3] << ", " + << cos_emb_dims[3]; + return true; +} + +bool XPURoformerRelativeEmbeddingOp::InferShapeImpl() const { + const auto& input_dims = param_.input->dims(); + param_.output->Resize(input_dims); + // share LoD + param_.output->set_lod(param_.input->lod()); + + return true; +} + +bool XPURoformerRelativeEmbeddingOp::AttachImpl(const cpp::OpDesc& op_desc, + lite::Scope* scope) { + param_.input = + scope->FindVar(op_desc.Input("X").front())->GetMutable(); + param_.cos_embedding = scope->FindVar(op_desc.Input("CosEmbbeding").front()) + ->GetMutable(); + param_.sin_embedding = scope->FindVar(op_desc.Input("SinEmbbeding").front()) + ->GetMutable(); + param_.output = + scope->FindVar(op_desc.Output("Out").front())->GetMutable(); + param_.max_pos_len = op_desc.GetAttr("max_pos_len"); + + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(__xpu__roformer_relative_embedding, + paddle::lite::operators::XPURoformerRelativeEmbeddingOp); diff --git a/lite/operators/__xpu__roformer_relative_embedding_op.h b/lite/operators/__xpu__roformer_relative_embedding_op.h new file mode 100644 index 00000000000..3927548ffbe --- /dev/null +++ b/lite/operators/__xpu__roformer_relative_embedding_op.h @@ -0,0 +1,49 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
diff --git a/lite/operators/__xpu__roformer_relative_embedding_op.h b/lite/operators/__xpu__roformer_relative_embedding_op.h
new file mode 100644
index 00000000000..3927548ffbe
--- /dev/null
+++ b/lite/operators/__xpu__roformer_relative_embedding_op.h
@@ -0,0 +1,49 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include "lite/core/op_lite.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class XPURoformerRelativeEmbeddingOp : public OpLite {
+ public:
+  XPURoformerRelativeEmbeddingOp() {}
+
+  explicit XPURoformerRelativeEmbeddingOp(const std::string &op_type)
+      : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShapeImpl() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override {
+    return "XPURoformerRelativeEmbedding";
+  }
+
+ private:
+  mutable XPURoformerRelativeEmbeddingParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index a9d084bce99..bfb909f750a 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -1729,6 +1729,7 @@ struct XPUMultiEncoderParam : ParamBase {
   std::vector<lite::Tensor*> fc_bias;
   std::vector<lite::Tensor*> ln_scale;
   std::vector<lite::Tensor*> ln_bias;
+  std::vector<lite::Tensor*> roformer_embedding;
   const lite::Tensor* mask{nullptr};
   const lite::Tensor* SeqLod{nullptr};
   const lite::Tensor* PadSeqLen{nullptr};
@@ -1746,6 +1747,8 @@ struct XPUMultiEncoderParam : ParamBase {
   int size_per_head{};
   int hidden_dim{};
   std::string act_type{};
+  int relative_type{0};
+  int max_pos_len{512};  // relative embedding [max_pos_len, head_dim]
   std::string precision{};
   bool enable_qkv_fusion{false};
   bool norm_before{false};
@@ -1790,6 +1793,14 @@ struct XPUFcParam : ParamBase {
   float alpha{1.0f};
 };
 
+struct XPURoformerRelativeEmbeddingParam : ParamBase {
+  lite::Tensor* input{nullptr};
+  lite::Tensor* cos_embedding{nullptr};
+  lite::Tensor* sin_embedding{nullptr};
+  lite::Tensor* output{nullptr};
+  int max_pos_len{512};
+};
+
 struct XPUResNetCbamParam : ParamBase {
   lite::Tensor* input{};
   std::vector<lite::Tensor*> filter;