[arm] add fp16 mul implementation #8408

Merged: 2 commits, Feb 15, 2022
1 change: 1 addition & 0 deletions lite/core/optimizer/mir/fp16_attribute_pass.h
@@ -45,6 +45,7 @@ class FP16AttributePass : public ProgramPass {
"elementwise_mul",
"elementwise_div",
"elementwise_sub",
"mul",
"prelu"};
};

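For context, a minimal sketch of how a pass can consult an op-type list like the one above when deciding which ops to mark for fp16 execution. The function name and structure are illustrative assumptions, not the actual FP16AttributePass internals:

#include <algorithm>
#include <string>
#include <vector>

// Hypothetical helper: true if op_type appears in the fp16-capable op list.
bool IsFP16SupportedOp(const std::string& op_type) {
  static const std::vector<std::string> kFP16Ops = {
      "elementwise_mul", "elementwise_div", "elementwise_sub", "mul", "prelu"};
  return std::find(kFP16Ops.begin(), kFP16Ops.end(), op_type) !=
         kFP16Ops.end();
}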
93 changes: 93 additions & 0 deletions lite/kernels/arm/mul_compute.cc
@@ -17,6 +17,9 @@
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#ifdef ENABLE_ARM_FP16
#include "lite/backends/arm/math/fp16/funcs_fp16.h"
#endif

namespace paddle {
namespace lite {
@@ -195,15 +198,105 @@ void MulCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
}
mul_add_n_scale_bias(o_data, scale_.data(), m_, n_);
}
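// FP16 specializations of MulCompute, compiled only when ARM fp16 support is
// enabled at build time.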
#ifdef ENABLE_ARM_FP16
template <>
void MulCompute<PRECISION(kFP16), PRECISION(kFP16)>::PrepareForRun() {
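  // Nothing fp16-specific to prepare yet; the context is fetched but unused.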
auto& ctx = this->ctx_->template As<ARMContext>();
}

template <>
void MulCompute<PRECISION(kFP16), PRECISION(kFP16)>::Run() {
auto& param = Param<param_t>();

const auto* x_data = param.x->data<float16_t>();
const auto* y_data = param.y->data<float16_t>();
auto* o_data = param.output->mutable_data<float16_t>();

m_ = static_cast<int>(
param.x->dims().Slice(0, param.x_num_col_dims).production());
int x_w =
static_cast<int>(param.x->dims()
.Slice(param.x_num_col_dims, param.x->dims().size())
.production());
int y_h = static_cast<int>(
param.y->dims().Slice(0, param.y_num_col_dims).production());
n_ = static_cast<int>(param.y->dims()
.Slice(param.y_num_col_dims, param.y->dims().size())
.production());

CHECK_EQ(x_w, y_h) << "x_w must be equal to y_h";
k_ = x_w;
auto& ctx = this->ctx_->template As<ARMContext>();
operators::ActivationParam act_param;
act_param.has_active = false;
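  // n_ == 1 collapses the matmul to a matrix-vector product, so take the fp16
  // GEMV path; otherwise pack x and run the blocked fp16 GEMM below.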
if (n_ == 1) {
lite::arm::math::fp16::gemv_fp16(x_data,
y_data,
o_data,
false,
m_,
k_,
0.f,
false,
nullptr,
act_param.has_active,
act_param,
&ctx);

} else {
    constexpr bool is_transposed_y = false;
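    // Round m_ up to the GEMM kernel's row-block size and reserve workspace
    // for the packed copy of x.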
int hblock = lite::arm::math::get_hblock(&ctx, m_);
int m_round = hblock * ((m_ + hblock - 1) / hblock);
ctx.ExtendWorkspace(m_round * k_ * sizeof(float16_t));
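    // The packed copy of x lives in the shared workspace, past the LLC-sized
    // scratch region reserved at its start.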

float16_t* packed_x =
static_cast<float16_t*>(ctx.workspace_data<float16_t>()) +
ctx.llc_size() / sizeof(float16_t);
lite::arm::math::fp16::prepackA_fp16(
packed_x, x_data, 1.f, k_, 0, m_, 0, k_, false, &ctx);
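    // ldb is y's leading dimension: n_ for row-major y, k_ if y were transposed.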
int ldb = n_;
    if (is_transposed_y) {
ldb = k_;
}
    lite::arm::math::fp16::gemm_prepack_fp16(is_transposed_y,
m_,
n_,
k_,
packed_x,
y_data,
ldb,
0.f,
o_data,
n_,
nullptr,
false,
act_param,
&ctx);
}
}
#endif

} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle

typedef paddle::lite::kernels::arm::MulCompute<PRECISION(kFloat),
PRECISION(kFloat)>
Mul_f32_f32;

#ifdef ENABLE_ARM_FP16
typedef paddle::lite::kernels::arm::MulCompute<PRECISION(kFP16),
PRECISION(kFP16)>
Mul_f16_f16;
REGISTER_LITE_KERNEL(mul, kARM, kFP16, kNCHW, Mul_f16_f16, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFP16))})
.Finalize();

#endif // ENABLE_ARM_FP16

REGISTER_LITE_KERNEL(mul, kARM, kFloat, kNCHW, Mul_f32_f32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
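To make the shape flattening in Run() concrete, here is a self-contained sketch with hypothetical shapes (the real kernel computes the same quantities via DDim::Slice(...).production()):

#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Product of dims[begin, end), mirroring DDim::Slice(begin, end).production().
int64_t Production(const std::vector<int64_t>& dims, size_t begin, size_t end) {
  return std::accumulate(dims.begin() + begin, dims.begin() + end,
                         int64_t{1}, std::multiplies<int64_t>());
}

int main() {
  // Hypothetical shapes: x is [2, 3, 4, 5], y is [20, 7].
  std::vector<int64_t> x_dims{2, 3, 4, 5};
  std::vector<int64_t> y_dims{20, 7};
  int x_num_col_dims = 2;  // first two dims of x become the rows
  int y_num_col_dims = 1;  // first dim of y becomes the rows

  int64_t m = Production(x_dims, 0, x_num_col_dims);              // 2*3 = 6
  int64_t k = Production(x_dims, x_num_col_dims, x_dims.size());  // 4*5 = 20
  int64_t y_h = Production(y_dims, 0, y_num_col_dims);            // 20
  int64_t n = Production(y_dims, y_num_col_dims, y_dims.size());  // 7

  assert(k == y_h);  // the same invariant as CHECK_EQ(x_w, y_h) in the kernel
  std::cout << "m=" << m << " k=" << k << " n=" << n << std::endl;
}

With these shapes the fp16 kernel would take the GEMM branch, since n is not 1.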
28 changes: 21 additions & 7 deletions lite/tests/unittest_py/op/test_mul_op.py
@@ -26,10 +26,21 @@
class TestMulOp(AutoScanTest):
def __init__(self, *args, **kwargs):
AutoScanTest.__init__(self, *args, **kwargs)
self.enable_testing_on_place(TargetType.ARM, PrecisionType.FP32,
DataLayoutType.NCHW)
self.enable_testing_on_place(TargetType.X86, PrecisionType.FP32,
DataLayoutType.NCHW)
self.enable_testing_on_place(
TargetType.ARM,
PrecisionType.FP32,
DataLayoutType.NCHW,
thread=[1, 4])
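        # Also exercise the ARM fp16 kernel path.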
self.enable_testing_on_place(
TargetType.ARM,
PrecisionType.FP16,
DataLayoutType.NCHW,
thread=[1, 4])
self.enable_testing_on_place(
TargetType.X86,
PrecisionType.FP32,
DataLayoutType.NCHW,
thread=[1, 4])

def is_program_valid(self,
program_config: ProgramConfig,
@@ -90,8 +101,11 @@ def sample_program_configs(self, draw):

program_config = ProgramConfig(
ops=[mul_op],
weights={"input_data_y": TensorConfig(shape=Y_shape)},
inputs={"input_data_x": TensorConfig(shape=X_shape)},
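            # Y is fed as a regular input rather than a persistable weight, so
            # both operands vary across generated test cases.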
weights={},
inputs={
"input_data_x": TensorConfig(shape=X_shape),
"input_data_y": TensorConfig(shape=Y_shape)
},
outputs=["output_data"])

return program_config
@@ -103,7 +117,7 @@ def add_ignore_pass_case(self):
pass

def test(self, *args, **kwargs):
self.run_and_statis(quant=False, max_examples=25)
self.run_and_statis(quant=False, max_examples=250)


if __name__ == "__main__":