From 54525ab037f566138dc75409649a7cb7cb565c40 Mon Sep 17 00:00:00 2001
From: laiou <46396912+laiou@users.noreply.github.com>
Date: Mon, 17 Jan 2022 16:12:40 +0800
Subject: [PATCH] [XPU] pad3d and memory pass (#8213)

---
 .../optimizer/mir/fusion/inplace_fuser.cc     |  23 +-
 .../optimizer/mir/xpu_memory_optimize_pass.cc | 282 +++++++++++++-----
 .../optimizer/mir/xpu_memory_optimize_pass.h  |  15 +-
 lite/kernels/xpu/CMakeLists.txt               |   1 +
 lite/kernels/xpu/pad3d_compute.cc             | 101 +++++++
 lite/kernels/xpu/pad3d_compute.h              |  37 +++
 6 files changed, 377 insertions(+), 82 deletions(-)
 create mode 100644 lite/kernels/xpu/pad3d_compute.cc
 create mode 100644 lite/kernels/xpu/pad3d_compute.h

diff --git a/lite/core/optimizer/mir/fusion/inplace_fuser.cc b/lite/core/optimizer/mir/fusion/inplace_fuser.cc
index 89399af8af7..9cef60c09e7 100644
--- a/lite/core/optimizer/mir/fusion/inplace_fuser.cc
+++ b/lite/core/optimizer/mir/fusion/inplace_fuser.cc
@@ -15,22 +15,31 @@
 #include "lite/core/optimizer/mir/fusion/inplace_fuser.h"
 #include <memory>
 #include <vector>
+#include "lite/core/optimizer/mir/pattern_matcher_high_api.h"
 
 namespace paddle {
 namespace lite {
 namespace mir {
 namespace fusion {
 
-void InplaceFuser::BuildPattern() { OpNode("inplace", type_); }
+void InplaceFuser::BuildPattern() {
+  auto* input = VarNode("input")
+                    ->assert_is_op_input(type_, "X")
+                    ->assert_only_one_output()
+                    ->AsInput();
+
+  auto* op_node = OpNode("inplace", type_)->assert_is_op(type_);
+
+  auto* output = VarNode("output")
+                     ->assert_is_op_output(type_, "Out")
+                     ->assert_only_one_output()
+                     ->AsOutput();
+
+  *input >> *op_node >> *output;
+}
 
 void InplaceFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
-  auto out_var_nodes = matched.at("inplace")->outlinks;
   bool inplace = true;
-  for (auto& out_var_node : out_var_nodes) {
-    if (out_var_node->outlinks.size() > 1) {
-      inplace = false;
-    }
-  }
 
   auto* stmt = matched.at("inplace")->stmt();
   auto op = stmt->op();
   cpp::OpDesc* op_desc = op->mutable_op_info();
diff --git a/lite/core/optimizer/mir/xpu_memory_optimize_pass.cc b/lite/core/optimizer/mir/xpu_memory_optimize_pass.cc
index 46c415fcf82..2070fdfa356 100644
--- a/lite/core/optimizer/mir/xpu_memory_optimize_pass.cc
+++ b/lite/core/optimizer/mir/xpu_memory_optimize_pass.cc
@@ -30,11 +30,15 @@ typedef struct {
   int cluster;
   std::pair<int, int> lifetime;
   int life_interval;
+  int mapping;
   std::set<std::string> adj;
 } XPUMemNode;
 
 void XPUMemoryOptimizePass::CollectLifeCycleByDevice(
-    std::map<std::string, lifecycle_map_t>* lifecycles, SSAGraph* graph) {
+    std::map<std::string, lifecycle_map_t>* lifecycles,
+    SSAGraph* graph,
+    std::map<std::string, std::string>* inplaceop_input2output,
+    std::map<std::string, std::string>* inplaceop_output2input) {
   max_lifecycle_ = 0;
 
   auto is_host = [](TargetType x) -> bool {
@@ -93,9 +97,7 @@ void XPUMemoryOptimizePass::CollectLifeCycleByDevice(
     }
   };
 
-  VLOG(4) << "invalid_op_nodes.size();" << invalid_op_nodes.size();
   insert_invalid_op_nodes_for_specific_target(invalid_op_nodes);
-  VLOG(4) << "invalid_op_nodes.size();" << invalid_op_nodes.size();
 
   // Collect the invalid input and output variables that will not be reused.
   std::set<std::string> invalid_var_names;
@@ -116,36 +118,6 @@ void XPUMemoryOptimizePass::CollectLifeCycleByDevice(
       }
       continue;
     }
-    // The specified input and output variables of the Ops whose 'inplace' attr
-    // is true will not be reused, such as reshape/reshape2's X and Out
-    // variables
-    std::map<std::string,
-             std::pair<std::set<std::string>, std::set<std::string>>>
-        inplace_op_nodes = {{"reshape", {{"X"}, {"Out"}}},
-                            {"reshape2", {{"X"}, {"Out"}}},
-                            {"flatten", {{"X"}, {"Out"}}},
-                            {"flatten2", {{"X"}, {"Out"}}},
-                            {"squeeze", {{"X"}, {"Out"}}},
-                            {"squeeze2", {{"X"}, {"Out"}}},
-                            {"unsqueeze", {{"X"}, {"Out"}}},
-                            {"unsqueeze2", {{"X"}, {"Out"}}}};
-    auto inplace_op_node = inplace_op_nodes.find(op_type);
-    if (inplace_op_node != inplace_op_nodes.end()) {
-      bool inplace = false;
-      if (op_info->HasAttr("inplace")) {
-        inplace = op_info->GetAttr<bool>("inplace");
-      }
-      if (inplace) {
-        for (auto& in_param_name : inplace_op_node->second.first) {
-          const auto& in_arg_names = op_info->Input(in_param_name);
-          invalid_var_names.insert(in_arg_names.begin(), in_arg_names.end());
-        }
-        for (auto& out_param_name : inplace_op_node->second.second) {
-          const auto& out_arg_names = op_info->Output(out_param_name);
-          invalid_var_names.insert(out_arg_names.begin(), out_arg_names.end());
-        }
-      }
-    }
   }
 
   // non-tensor(like tensor_array) variables will not be reused
@@ -161,12 +133,35 @@ void XPUMemoryOptimizePass::CollectLifeCycleByDevice(
     if (op_node->AsStmt().op_info()->Type() == "io_copy_once") {
       continue;
     }
+
+    std::map<std::string,
+             std::pair<std::set<std::string>, std::set<std::string>>>
+        inplace_ops = {{"reshape", {{"X"}, {"Out"}}},
+                       {"reshape2", {{"X"}, {"Out"}}},
+                       {"flatten", {{"X"}, {"Out"}}},
+                       {"flatten2", {{"X"}, {"Out"}}},
+                       {"squeeze", {{"X"}, {"Out"}}},
+                       {"squeeze2", {{"X"}, {"Out"}}},
+                       {"unsqueeze", {{"X"}, {"Out"}}},
+                       {"unsqueeze2", {{"X"}, {"Out"}}}};
     VLOG(4) << op_node->AsStmt().op_info()->Type() << " life is "
             << max_lifecycle_;
     std::vector<Node*> var_nodes(op_node->inlinks.begin(),
                                  op_node->inlinks.end());
     var_nodes.insert(
         var_nodes.end(), op_node->outlinks.begin(), op_node->outlinks.end());
+
+    int count = 0;
+
+    bool is_inplace = false;
+
+    if (op_node->AsStmt().op_info()->HasAttr("inplace")) {
+      is_inplace = op_node->AsStmt().op_info()->GetAttr<bool>("inplace");
+    }
+
+    std::string input_host_var_name = " ";
+    std::string input_xpu_var_name = " ";
+
     for (auto* var_node : var_nodes) {
       CHECK(var_node->IsArg());
       auto& arg = var_node->AsArg();
@@ -175,18 +170,59 @@ void XPUMemoryOptimizePass::CollectLifeCycleByDevice(
       VLOG(4) << "OP VAR NAME IS " << var_name;
       if (var_name.find("_xpu_max") != std::string::npos) continue;
       if (invalid_var_names.count(var_name)) continue;
-      TargetType target_type = arg.type->target();
-      if (is_host(target_type)) target_type = TARGET(kHost);
-
-      if (!(*lifecycles)[TargetToStr(target_type)].count(var_name)) {
-        (*lifecycles)[TargetToStr(target_type)].emplace(
-            var_name, std::make_pair(max_lifecycle_, max_lifecycle_));
-      } else {
-        int cur_life =
-            (*lifecycles)[TargetToStr(target_type)][var_name].second;
-        (*lifecycles)[TargetToStr(target_type)][var_name].second =
-            (std::max)(max_lifecycle_, cur_life);
-      }
+      auto find_inplace_op =
+          inplace_ops.find(op_node->AsStmt().op_info()->Type());
+
+      if (find_inplace_op != inplace_ops.end() && count != 2) {
+        TargetType target_type = arg.type->target();
+        if (is_host(target_type)) {
+          target_type = TARGET(kHost);
+          continue;
+        }
+
+        if ((*lifecycles)[TargetToStr(target_type)].count(var_name)) {
+          if (is_host(target_type)) {
+            input_host_var_name = var_name;
+          } else {
+            input_xpu_var_name = var_name;
+            count++;
+            int cur_life =
+                (*lifecycles)[TargetToStr(target_type)][var_name].second;
+            (*lifecycles)[TargetToStr(target_type)][var_name].second =
+                (std::max)(max_lifecycle_, cur_life);
+          }
+        } else if (!(*lifecycles)[TargetToStr(target_type)].count(var_name)) {
+          count++;
+          if (is_host(target_type)) {
+            (*lifecycles)[TargetToStr(target_type)].emplace(
+                var_name,
+                (*lifecycles)[TargetToStr(target_type)][input_host_var_name]);
+          } else {
+            if (is_inplace) {
+              (*lifecycles)[TargetToStr(target_type)].emplace(
+                  var_name, std::make_pair(max_lifecycle_, max_lifecycle_));
+              inplaceop_input2output->emplace(input_xpu_var_name, var_name);
+              inplaceop_output2input->emplace(var_name, input_xpu_var_name);
+            } else {
+              (*lifecycles)[TargetToStr(target_type)].emplace(
+                  var_name, std::make_pair(max_lifecycle_, max_lifecycle_));
+            }
+          }
+        }
+      } else if (find_inplace_op == inplace_ops.end()) {
+        TargetType target_type = arg.type->target();
+        if (is_host(target_type)) target_type = TARGET(kHost);
+
+        if (!(*lifecycles)[TargetToStr(target_type)].count(var_name)) {
+          (*lifecycles)[TargetToStr(target_type)].emplace(
+              var_name, std::make_pair(max_lifecycle_, max_lifecycle_));
+        } else {
+          int cur_life =
+              (*lifecycles)[TargetToStr(target_type)][var_name].second;
+          (*lifecycles)[TargetToStr(target_type)][var_name].second =
+              (std::max)(max_lifecycle_, cur_life);
+        }
+      }  // if else
     }
     ++max_lifecycle_;
   }
@@ -196,7 +232,9 @@ void XPUMemoryOptimizePass::CollectLifeCycleByDevice(
 
 void XPUMemoryOptimizePass::MakeReusePlan(
     const lifecycle_map_t& lifecycles,
-    std::map<std::string, std::string>* node2cluster) {
+    std::map<std::string, std::string>* node2cluster,
+    std::map<std::string, std::string>* inplaceop_input2output,
+    std::map<std::string, std::string>* inplaceop_output2input) {
   std::vector<XPUMemNode> mem_nodes;
   std::vector<std::string> cluster;
   for (auto& data : lifecycles) {
@@ -204,6 +242,7 @@ void XPUMemoryOptimizePass::MakeReusePlan(
     temp_node.name = data.first;
     temp_node.cluster = -1;
    temp_node.lifetime = data.second;
+    temp_node.mapping = 0;
     temp_node.life_interval = data.second.second - data.second.first;
     mem_nodes.push_back(temp_node);
   }
@@ -234,33 +273,129 @@ void XPUMemoryOptimizePass::MakeReusePlan(
       }
     }
   }
+  VLOG(4) << "Step1 get inplace node Cluster: ";
+  for (size_t i = 0; i < mem_nodes.size(); i++) {
+    if (inplaceop_input2output->count(mem_nodes[i].name)) {
+      int cluster_index = cluster.size();
+      mem_nodes[i].cluster = cluster_index;
+      (*node2cluster)[mem_nodes[i].name] = mem_nodes[i].name;
+      VLOG(4) << "Mapping Tensor Cluster: " << mem_nodes[i].name
+              << ", life time is " << mem_nodes[i].lifetime.first << " --> "
+              << mem_nodes[i].lifetime.second << ", cluster name is "
+              << (*node2cluster)[mem_nodes[i].name];
+      std::set<std::string> cluster_adj = mem_nodes[i].adj;
+      for (size_t j = 0; j < mem_nodes.size(); j++) {
+        if (mem_nodes[j].name ==
+            (*inplaceop_input2output)[mem_nodes[i].name]) {
+          (*node2cluster)[mem_nodes[j].name] = mem_nodes[i].name;
+          mem_nodes[j].cluster = cluster_index;
+          VLOG(4) << mem_nodes[j].name << ", life time is "
+                  << mem_nodes[j].lifetime.first << " --> "
+                  << mem_nodes[j].lifetime.second << ", cluster name is "
+                  << (*node2cluster)[mem_nodes[j].name];
+          for (auto& n : mem_nodes[j].adj) {
+            cluster_adj.insert(n);
+          }
+        }
+      }
+    }
+  }
+  VLOG(4) << "Step2 merge inplace node Cluster: ";
+  for (size_t i = 0; i < mem_nodes.size(); i++) {
+    if (inplaceop_input2output->count(mem_nodes[i].name) &&
+        mem_nodes[i].mapping != 1) {
+      int cluster_index = cluster.size();
+      mem_nodes[i].cluster = cluster_index;
+      (*node2cluster)[mem_nodes[i].name] = mem_nodes[i].name;
+      mem_nodes[i].mapping = 1;
+      VLOG(4) << "Mapping Tensor Cluster: "
+              << mem_nodes[i].name << ", life time is "
+              << mem_nodes[i].lifetime.first << " --> "
+              << mem_nodes[i].lifetime.second << ", cluster index is "
+              << mem_nodes[i].cluster << ", cluster name is "
+              << (*node2cluster)[mem_nodes[i].name];
+      cluster.push_back(mem_nodes[i].name);
+
+      std::set<std::string> cluster_adj = mem_nodes[i].adj;
+      for (size_t j = 0; j < mem_nodes.size(); j++) {
+        if (mem_nodes[j].name ==
+            (*inplaceop_input2output)[mem_nodes[i].name]) {
+          mem_nodes[j].cluster = mem_nodes[i].cluster;
+          (*node2cluster)[mem_nodes[j].name] = mem_nodes[i].name;
+          VLOG(4) << mem_nodes[j].name << ", life time is "
+                  << mem_nodes[j].lifetime.first << " --> "
+                  << mem_nodes[j].lifetime.second << ", cluster index is "
+                  << mem_nodes[j].cluster << ", cluster name is "
+                  << (*node2cluster)[mem_nodes[j].name];
+
+          for (auto& m : mem_nodes[j].adj) {
+            cluster_adj.insert(m);
+          }
+        } else if (inplaceop_input2output->count(mem_nodes[j].name) &&
+                   (cluster_adj.find(mem_nodes[j].name) ==
+                    cluster_adj.end()) &&
+                   mem_nodes[j].mapping != 1) {
+          mem_nodes[j].mapping = 1;
+          mem_nodes[j].cluster = mem_nodes[i].cluster;
+          (*node2cluster)[mem_nodes[j].name] = mem_nodes[i].name;
+          VLOG(4) << mem_nodes[j].name << ", life time is "
+                  << mem_nodes[j].lifetime.first << " --> "
+                  << mem_nodes[j].lifetime.second << ", cluster index is "
+                  << mem_nodes[j].cluster << ", cluster name is "
+                  << (*node2cluster)[mem_nodes[j].name];
+
+          for (auto& n : mem_nodes[j].adj) {
+            cluster_adj.insert(n);
+          }
+          for (size_t n = 0; n < mem_nodes.size(); n++) {
+            if (mem_nodes[n].name ==
+                (*inplaceop_input2output)[mem_nodes[j].name]) {
+              mem_nodes[n].cluster = mem_nodes[i].cluster;
+              (*node2cluster)[mem_nodes[n].name] = mem_nodes[i].name;
+              VLOG(4) << mem_nodes[n].name << ", life time is "
+                      << mem_nodes[n].lifetime.first << " --> "
+                      << mem_nodes[n].lifetime.second << ", cluster index is "
+                      << mem_nodes[n].cluster << ", cluster name is "
+                      << (*node2cluster)[mem_nodes[n].name];
-  // Generating XPUMemory Reuse Strategy Based on Greedy Way
-  // The vars can be reused if there is no overlap between them.
+              for (auto& m : mem_nodes[n].adj) {
+                cluster_adj.insert(m);
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  VLOG(4) << "Step3 get others node Cluster : ";
   for (size_t i = 0; i < mem_nodes.size(); i++) {
-    if (mem_nodes[i].cluster >= 0 || mem_nodes[i].life_interval == 0) continue;
-    int cluster_index = cluster.size();
-    mem_nodes[i].cluster = cluster_index;
-    (*node2cluster)[mem_nodes[i].name] = mem_nodes[i].name;
-    VLOG(4) << "Mapping Tensor Cluster: " << mem_nodes[i].name
-            << ", life time is " << mem_nodes[i].lifetime.first << " --> "
-            << mem_nodes[i].lifetime.second;
-    cluster.push_back(mem_nodes[i].name);
-    std::set<std::string> cluster_adj = mem_nodes[i].adj;
-    for (size_t j = i + 1; j < mem_nodes.size(); j++) {
-      if (mem_nodes[j].cluster < 0 &&
-          (cluster_adj.find(mem_nodes[j].name) == cluster_adj.end())) {
-        (*node2cluster)[mem_nodes[j].name] = mem_nodes[i].name;
-        mem_nodes[j].cluster = cluster_index;
-        VLOG(4) << mem_nodes[j].name << ", life time is "
-                << mem_nodes[j].lifetime.first << " --> "
-                << mem_nodes[j].lifetime.second;
-        for (auto& n : mem_nodes[j].adj) {
-          cluster_adj.insert(n);
+    if (!(inplaceop_input2output->count(mem_nodes[i].name)) &&
+        mem_nodes[i].cluster < 0 && mem_nodes[i].life_interval != 0) {
+      int cluster_index = cluster.size();
+      mem_nodes[i].cluster = cluster_index;
+      (*node2cluster)[mem_nodes[i].name] = mem_nodes[i].name;
+      VLOG(4) << "Mapping Tensor Cluster: " << mem_nodes[i].name
+              << ", life time is " << mem_nodes[i].lifetime.first << " --> "
+              << mem_nodes[i].lifetime.second << ", cluster index is "
+              << mem_nodes[i].cluster << ", cluster name is "
+              << (*node2cluster)[mem_nodes[i].name];
+      cluster.push_back(mem_nodes[i].name);
+      std::set<std::string> cluster_adj = mem_nodes[i].adj;
+      for (size_t j = i + 1; j < mem_nodes.size(); j++) {
+        if (!(inplaceop_input2output->count(mem_nodes[j].name)) &&
+            mem_nodes[j].cluster < 0 &&
+            (cluster_adj.find(mem_nodes[j].name) == cluster_adj.end())) {
+          mem_nodes[j].cluster = mem_nodes[i].cluster;
+          (*node2cluster)[mem_nodes[j].name] = mem_nodes[i].name;
+          VLOG(4) << mem_nodes[j].name << ", life time is "
+                  << mem_nodes[j].lifetime.first << " --> "
+                  << mem_nodes[j].lifetime.second << ", cluster index is "
+                  << mem_nodes[j].cluster << ", cluster name is "
+                  << (*node2cluster)[mem_nodes[j].name];
+          for (auto& n : mem_nodes[j].adj) {
+            cluster_adj.insert(n);
+          }
        }
       }
     }
   }
+
   for (auto& name : cluster) {
     LOG(INFO) << "cluster: " << name;
   }
@@ -272,6 +407,7 @@ void XPUMemoryOptimizePass::PerformReusePlan(
   for (auto& op_node : graph->StmtTopologicalOrder()) {
     if (!op_node->IsStmt()) continue;
     auto& stmt = op_node->AsStmt();
+    auto* op_info = stmt.mutable_op_info();
     std::map<std::string, std::vector<std::string>> in_args, out_args;
     // replace the op's input according the reuse table.
@@ -354,13 +490,21 @@ void XPUMemoryOptimizePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   // 3. Perform reuse plan: Replace all var's name in the model according to the
   // mapping table.
   std::map<std::string, lifecycle_map_t> lifecycles;
-  CollectLifeCycleByDevice(&lifecycles, graph.get());
+  std::map<std::string, std::string> inplaceop_input2output;
+  std::map<std::string, std::string> inplaceop_output2input;
+  CollectLifeCycleByDevice(&lifecycles,
+                           graph.get(),
+                           &inplaceop_input2output,
+                           &inplaceop_output2input);
   for (auto& ele : lifecycles) {
     if (ele.first != "xpu") {
       continue;
     }
     std::map<std::string, std::string> node2cluster;
-    MakeReusePlan(ele.second, &node2cluster);
+    MakeReusePlan(ele.second,
+                  &node2cluster,
+                  &inplaceop_input2output,
+                  &inplaceop_output2input);
     PerformReusePlan(graph.get(), node2cluster);
   }
 }
diff --git a/lite/core/optimizer/mir/xpu_memory_optimize_pass.h b/lite/core/optimizer/mir/xpu_memory_optimize_pass.h
index f0d920fadf3..d4bbf9e7f9d 100644
--- a/lite/core/optimizer/mir/xpu_memory_optimize_pass.h
+++ b/lite/core/optimizer/mir/xpu_memory_optimize_pass.h
@@ -31,9 +31,6 @@ namespace paddle {
 namespace lite {
 namespace mir {
 
-/*
- * XPUMemoryOptimizePass will
- */
 class XPUMemoryOptimizePass : public ProgramPass {
  public:
   using lifecycle_t = std::pair<int, int>;
@@ -42,9 +39,15 @@ class XPUMemoryOptimizePass : public ProgramPass {
 
  private:
   void CollectLifeCycleByDevice(
-      std::map<std::string, lifecycle_map_t>* lifecycles, SSAGraph*);
-  void MakeReusePlan(const lifecycle_map_t& lifecycles,
-                     std::map<std::string, std::string>* node2cluster);
+      std::map<std::string, lifecycle_map_t>* lifecycles,
+      SSAGraph*,
+      std::map<std::string, std::string>* inplaceop_input2output,
+      std::map<std::string, std::string>* inplaceop_output2input);
+  void MakeReusePlan(
+      const lifecycle_map_t& lifecycles,
+      std::map<std::string, std::string>* node2cluster,
+      std::map<std::string, std::string>* inplaceop_input2output,
+      std::map<std::string, std::string>* inplaceop_output2input);
   void PerformReusePlan(SSAGraph* graph,
                         const std::map<std::string, std::string>& reuse_table);
 
diff --git a/lite/kernels/xpu/CMakeLists.txt b/lite/kernels/xpu/CMakeLists.txt
index ceabed4fa7a..aab54f71698 100644
--- a/lite/kernels/xpu/CMakeLists.txt
+++ b/lite/kernels/xpu/CMakeLists.txt
@@ -93,6 +93,7 @@ else()
     add_kernel(anchor_generator_compute_xpu XPU extra SRCS anchor_generator_compute.cc)
     add_kernel(box_clip_compute_xpu XPU extra SRCS box_clip_compute.cc)
     add_kernel(pad2d_compute_xpu XPU extra SRCS pad2d_compute.cc)
+    add_kernel(pad3d_compute_xpu XPU extra SRCS pad3d_compute.cc)
     add_kernel(pixel_shuffle_compute_xpu XPU extra SRCS pixel_shuffle_compute.cc)
     add_kernel(correlation_compute_xpu XPU extra SRCS correlation_compute.cc)
     add_kernel(logical_compute_xpu XPU extra SRCS logical_compute.cc)
diff --git a/lite/kernels/xpu/pad3d_compute.cc b/lite/kernels/xpu/pad3d_compute.cc
new file mode 100644
index 00000000000..261090faded
--- /dev/null
+++ b/lite/kernels/xpu/pad3d_compute.cc
@@ -0,0 +1,101 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/xpu/pad3d_compute.h"
+#include <vector>
+#include "lite/backends/xpu/xpu_header_sitter.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+template <typename T, PrecisionType PType>
+void Pad3dCompute<T, PType>::Run() {
+  auto& param = this->template Param<param_t>();
+  auto& ctx = this->ctx_->template As<XPUContext>();
+  auto pads = param.paddings;
+  auto mode = param.mode;
+  auto data_format = param.data_format;
+  T value = static_cast<T>(param.pad_value);
+
+  auto* x = param.X;
+  auto in_dims = x->dims();
+  auto* in_data = x->template data<T>();
+  auto* out = param.Out;
+  T* out_data = out->template mutable_data<T>(TARGET(kXPU));
+
+  if (mode == "reflect" || mode == "constant" || mode == "replicate" ||
+      mode == "circular") {
+    if (data_format == "NCDHW") {
+      std::vector<int> pad_left = {0, 0, pads[4], pads[2], pads[0]};
+      std::vector<int> pad_right = {0, 0, pads[5], pads[3], pads[1]};
+
+      int n_shape = in_dims[0];
+      int c_shape = in_dims[1];
+      int d_shape = in_dims[2];
+      int h_shape = in_dims[3];
+      int w_shape = in_dims[4];
+
+      std::vector<int> xshape = {n_shape, c_shape, d_shape, h_shape, w_shape};
+
+      int r = xdnn::pad(ctx.GetRawContext(),
+                        in_data,
+                        out_data,
+                        xshape,
+                        pad_left,
+                        pad_right,
+                        value);
+      CHECK_EQ(r, 0);
+    } else if (data_format == "NDHWC") {
+      std::vector<int> pad_left = {0, pads[4], pads[2], pads[0], 0};
+      std::vector<int> pad_right = {0, pads[5], pads[3], pads[1], 0};
+
+      int n_shape = in_dims[0];
+      int d_shape = in_dims[1];
+      int h_shape = in_dims[2];
+      int w_shape = in_dims[3];
+      int c_shape = in_dims[4];
+
+      std::vector<int> xshape = {n_shape, d_shape, h_shape, w_shape, c_shape};
+
+      int r = xdnn::pad(ctx.GetRawContext(),
+                        in_data,
+                        out_data,
+                        xshape,
+                        pad_left,
+                        pad_right,
+                        value);
+      CHECK_EQ(r, 0);
+    }
+
+  } else {
+    LOG(FATAL) << "xpu unsupported mode: " << mode;
+  }
+}
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(pad3d,
+                     kXPU,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::xpu::Pad3dCompute<float, PRECISION(kFloat)>,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .Finalize();
diff --git a/lite/kernels/xpu/pad3d_compute.h b/lite/kernels/xpu/pad3d_compute.h
new file mode 100644
index 00000000000..734e01fde5b
--- /dev/null
+++ b/lite/kernels/xpu/pad3d_compute.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "lite/core/kernel.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+template <typename T, PrecisionType PType>
+class Pad3dCompute : public KernelLite<TARGET(kXPU), PType> {
+ public:
+  using param_t = operators::Pad2dParam;
+
+  virtual void Run();
+
+  virtual ~Pad3dCompute() = default;
+};
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle