Commit 8468d87

add gemm_prepack_oth_int8 support GemmNBias test=develop

sprouteer committed Jan 10, 2023
1 parent 8fa178a commit 8468d87

Showing 5 changed files with 258 additions and 67 deletions.
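For context: this commit threads a new bias_direction argument through gemm_prepack_oth_int8 and its inner int8 kernels so that bias and scale can be applied along the N dimension (one value per output column, the GemmNBias case) instead of only per row (M direction). A minimal scalar sketch of the intended epilogue semantics follows; epilogue_ref is an illustrative name, and the value 2 for GemmNBias is inferred from the "cmp %w[bias_direction], #2" guard in the assembly below, so treat it as an assumption.

#include <cstdint>

// Reference epilogue for an M x N int32 accumulator tile (illustrative only).
void epilogue_ref(const int32_t* acc, float* out, int M, int N,
                  const float* scale, const float* bias, int bias_direction) {
  const int kGemmNBias = 2;  // assumption, inferred from "cmp %w[bias_direction], #2"
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float v = static_cast<float>(acc[m * N + n]);
      int i = (bias_direction == kGemmNBias) ? n : m;  // per-column vs per-row
      out[m * N + n] = bias[i] + v * scale[i];
    }
  }
}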
210 changes: 168 additions & 42 deletions lite/backends/arm/math/gemm_prepacked_int8.cc
@@ -258,7 +258,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
const float32_t* alpha,
int is_relu,
int k,
- int rem);
+ int rem,
+ int bias_direction);
// clang-format off
#ifdef __aarch64__
#define GEMM_INT8_KERNEL \
@@ -802,9 +803,70 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
"fmla v28.4s, v4.4s, v15.s[3]\n" /* 30, mul scale */ \
"fmla v29.4s, v5.4s, v15.s[3]\n" /* 31, mul scale */ \
"fmla v30.4s, v6.4s, v15.s[3]\n" /* 32, mul scale */ \
"fmla v31.4s, v7.4s, v15.s[3]\n" /* 33, mul scale */
"fmla v31.4s, v7.4s, v15.s[3]\n" /* 33, mul scale */ \
"8: \n"

+ #define GEMM_TRANS_INT32_TO_FP32_N_Direction \
+ "cmp %w[bias_direction], #2\n" /* skip N_Direction */ \
+ "bne 7f\n" /* skip N_Direction */ \
+ "ldp q8, q9, [%[bias]]\n" /* load bias */ \
+ "ldp q10, q11, [%[bias], #32]\n" /* load bias */ \
+ "ldp q12, q13, [%[scale]]\n" /* load scale */ \
+ "ldp q14, q15, [%[scale], #32]\n" /* load scale */ \
+ "scvtf v0.4s , v16.4s\n" /* 00, convert to fp32 */ \
+ "scvtf v1.4s , v17.4s\n" /* 01, convert to fp32 */ \
+ "scvtf v2.4s , v18.4s\n" /* 02, convert to fp32 */ \
+ "scvtf v3.4s , v19.4s\n" /* 03, convert to fp32 */ \
+ "scvtf v4.4s , v20.4s\n" /* 10, convert to fp32 */ \
+ "scvtf v5.4s , v21.4s\n" /* 11, convert to fp32 */ \
+ "scvtf v6.4s , v22.4s\n" /* 12, convert to fp32 */ \
+ "scvtf v7.4s , v23.4s\n" /* 13, convert to fp32 */ \
+ /* add bias */ \
+ "mov v16.4s, v8.4s\n" \
+ "mov v17.4s, v9.4s\n" \
+ "mov v18.4s, v10.4s\n" \
+ "mov v19.4s, v11.4s\n" \
+ "mov v20.4s, v8.4s\n" \
+ "mov v21.4s, v9.4s\n" \
+ "mov v22.4s, v10.4s\n" \
+ "mov v23.4s, v11.4s\n" \
+ "fmla v16.4s, v0.4s, v12.4s\n" /* 00, mul scale */ \
+ "fmla v17.4s, v1.4s, v13.4s\n" /* 01, mul scale */ \
+ "fmla v18.4s, v2.4s, v14.4s\n" /* 02, mul scale */ \
+ "fmla v19.4s, v3.4s, v15.4s\n" /* 03, mul scale */ \
+ "fmla v20.4s, v4.4s, v12.4s\n" /* 10, mul scale */ \
+ "fmla v21.4s, v5.4s, v13.4s\n" /* 11, mul scale */ \
+ "fmla v22.4s, v6.4s, v14.4s\n" /* 12, mul scale */ \
+ "fmla v23.4s, v7.4s, v15.4s\n" /* 13, mul scale */ \
+ "scvtf v0.4s , v24.4s\n" /* 20, convert to fp32 */ \
+ "scvtf v1.4s , v25.4s\n" /* 21, convert to fp32 */ \
+ "scvtf v2.4s , v26.4s\n" /* 22, convert to fp32 */ \
+ "scvtf v3.4s , v27.4s\n" /* 23, convert to fp32 */ \
+ "scvtf v4.4s , v28.4s\n" /* 30, convert to fp32 */ \
+ "scvtf v5.4s , v29.4s\n" /* 31, convert to fp32 */ \
+ "scvtf v6.4s , v30.4s\n" /* 32, convert to fp32 */ \
+ "scvtf v7.4s , v31.4s\n" /* 33, convert to fp32 */ \
+ "mov v24.4s, v8.4s\n" \
+ "mov v25.4s, v9.4s\n" \
+ "mov v26.4s, v10.4s\n" \
+ "mov v27.4s, v11.4s\n" \
+ "mov v28.4s, v8.4s\n" \
+ "mov v29.4s, v9.4s\n" \
+ "mov v30.4s, v10.4s\n" \
+ "mov v31.4s, v11.4s\n" \
+ "fmla v24.4s, v0.4s, v12.4s\n" /* 20, mul scale */ \
+ "fmla v25.4s, v1.4s, v13.4s\n" /* 21, mul scale */ \
+ "fmla v26.4s, v2.4s, v14.4s\n" /* 22, mul scale */ \
+ "fmla v27.4s, v3.4s, v15.4s\n" /* 23, mul scale */ \
+ "fmla v28.4s, v4.4s, v12.4s\n" /* 30, mul scale */ \
+ "fmla v29.4s, v5.4s, v13.4s\n" /* 31, mul scale */ \
+ "fmla v30.4s, v6.4s, v14.4s\n" /* 32, mul scale */ \
+ "fmla v31.4s, v7.4s, v15.4s\n" /* 33, mul scale */ \
+ "b 8f \n" \
+ "7: \n"

#define GEMM_INT8_FP32_OUT \
+ GEMM_TRANS_INT32_TO_FP32_N_Direction \
GEMM_TRANS_INT32_TO_FP32 \
GEMM_INT8_RELU \
GEMM_INT8_RELU6 \
@@ -821,6 +883,7 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
"stp q30, q31, [%[c_ptr3]], #32\n"

#define GEMM_INT8_INT8_OUT \
+ GEMM_TRANS_INT32_TO_FP32_N_Direction \
GEMM_TRANS_INT32_TO_FP32 \
GEMM_INT8_RELU \
GEMM_INT8_RELU6 \
@@ -933,7 +996,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
const float32_t* alpha,
int is_relu,
int k,
- int rem) {
+ int rem,
+ int bias_direction) {
// clang-format off
asm volatile(GEMM_INT8_KERNEL GEMM_INT8_FP32_OUT
: [a_ptr] "+r"(a_ptr),
@@ -947,7 +1011,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
[alpha] "r"(alpha),
[bias] "r"(bias),
[rem] "r"(rem),
[scale] "r"(scale)
[scale] "r"(scale),
[bias_direction] "r"(bias_direction)
: "v0","v1","v2","v3","v4","v5","v6","v7","v8",
"v9","v10","v11","v12","v13","v14",
"v15","v16","v17","v18","v19","v20",
Expand All @@ -968,7 +1033,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
const float32_t* alpha,
int is_relu,
int k,
- int rem) {
+ int rem,
+ int bias_direction) {
// clang-format off
float vmax[4] = {-127.0, -127.0, -127.0, -127.0};
asm volatile(GEMM_INT8_KERNEL GEMM_INT8_INT8_OUT
@@ -984,7 +1050,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
[bias] "r"(bias),
[rem] "r"(rem),
[scale] "r"(scale),
[vmax] "r"(vmax)
[vmax] "r"(vmax),
[bias_direction] "r"(bias_direction)
: "v0","v1","v2","v3","v4","v5","v6","v7",
"v8","v9","v10","v11","v12",
"v13","v14","v15","v16","v17",
Expand All @@ -1006,7 +1073,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
const float32_t* alpha,
int is_relu,
int k,
- int rem) {
+ int rem,
+ int bias_direction) {
// clang-format off
float vmax[4] = {-127.0, -127.0, -127.0, -127.0};
asm volatile(GEMM_INT8_KERNEL GEMM_INT8_INT32_OUT
@@ -1022,7 +1090,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
[bias] "r"(bias),
[rem] "r"(rem),
[scale] "r"(scale),
[vmax] "r"(vmax)
[vmax] "r"(vmax),
[bias_direction] "r"(bias_direction)
: "v0","v1","v2","v3","v4","v5","v6","v7",
"v8","v9","v10","v11","v12",
"v13","v14","v15","v16","v17",
@@ -4016,8 +4085,43 @@ inline void gemm_dot_int8_kernel(const int8_t* a_ptr,
"vmla.f32 q2, q10, d13[0]\n" /* r20, mul scale */ \
"vmla.f32 q3, q11, d13[0]\n" /* r21, mul scale */ \
"vmla.f32 q4, q12, d13[1]\n" /* r30, mul scale */ \
"vmla.f32 q5, q13, d13[1]\n" /* r31, mul scale */
"vmla.f32 q5, q13, d13[1]\n" /* r31, mul scale */ \
"8: \n"

+ #define GEMM_INT8_TRANS_INT32_TO_FP32_N_Direction \
+ "cmp %[bias_direction], #2\n" /* skip N_Direction */ \
+ "bne 7f\n" /* skip N_Direction */ \
+ /* write output */ \
+ "vld1.32 {d12-d13}, [%[scale]]!\n" /* load scale */ \
+ "vld1.32 {d14-d15}, [%[bias]]!\n" /* load bias */ \
+ "vcvt.f32.s32 q10, q8\n" /* r00, cvt int32 to fp32*/ \
+ "vcvt.f32.s32 q12, q0\n" /* r10, cvt int32 to fp32*/ \
+ "vmov.32 q8, q6\n" \
+ "vmov.32 q0, q6\n" \
+ "vmla.f32 q8, q10, q7\n" /* r00, mul scale */ \
+ "vmla.f32 q0, q12, q7\n" /* r10, mul scale */ \
+ "vcvt.f32.s32 q10, q2\n" /* r20, cvt int32 to fp32*/ \
+ "vcvt.f32.s32 q12, q4\n" /* r30, cvt int32 to fp32*/ \
+ "vdup.32 q2, d15[0]\n" \
+ "vdup.32 q4, d15[1]\n" \
+ "vmla.f32 q2, q10, d13[0]\n" /* r20, mul scale */ \
+ "vmla.f32 q4, q12, d13[1]\n" /* r30, mul scale */ \
+ "vld1.32 {d12-d13}, [%[scale]]\n" /* load scale */ \
+ "vld1.32 {d14-d15}, [%[bias]]\n" /* load bias */ \
+ "vcvt.f32.s32 q11, q9\n" /* r01, cvt int32 to fp32*/ \
+ "vcvt.f32.s32 q13, q1\n" /* r11, cvt int32 to fp32*/ \
+ "vmov.32 q9, q6\n" \
+ "vmov.32 q1, q6\n" \
+ "vmla.f32 q9, q11, q7\n" /* r01, mul scale */ \
+ "vmla.f32 q1, q13, q7\n" /* r11, mul scale */ \
+ "vcvt.f32.s32 q11, q3\n" /* r21, cvt int32 to fp32*/ \
+ "vcvt.f32.s32 q13, q5\n" /* r31, cvt int32 to fp32*/ \
+ "vdup.32 q3, d15[0]\n" \
+ "vdup.32 q5, d15[1]\n" \
+ "vmla.f32 q3, q11, d13[0]\n" /* r21, mul scale */ \
+ "vmla.f32 q5, q13, d13[1]\n" /* r31, mul scale */ \
+ "b 8f \n" \
+ "7: \n"

#define GEMM_INT8_RELU \
/* do relu */ \
@@ -4141,19 +4245,21 @@ inline void gemm_dot_int8_kernel(const int8_t* a_ptr,
"vmul.f32 q5, q5, q11 \n" \
"9: \n"

- #define GEMM_INT8_FP32_OUT \
- GEMM_INT8_TRANS_INT32_TO_FP32 \
- GEMM_INT8_RELU \
- GEMM_INT8_RELU6 \
- GEMM_INT8_LEAKY_RELU \
- GEMM_INT8_HARD_SWISH \
+ #define GEMM_INT8_FP32_OUT \
+ GEMM_INT8_TRANS_INT32_TO_FP32_N_Direction \
+ GEMM_INT8_TRANS_INT32_TO_FP32 \
+ GEMM_INT8_RELU \
+ GEMM_INT8_RELU6 \
+ GEMM_INT8_LEAKY_RELU \
+ GEMM_INT8_HARD_SWISH \
"vst1.32 {d16-d19}, [%[c_ptr0]]!\n" /* write r0, float32x4 x2 */ \
"vst1.32 {d0-d3}, [%[c_ptr1]]!\n" /* write r1, float32x4 x2 */ \
"vst1.32 {d4-d7}, [%[c_ptr2]]!\n" /* write r2, float32x4 x2 */ \
"vst1.32 {d8-d11}, [%[c_ptr3]]!\n" /* write r3, float32x4 x2 */


#define GEMM_INT8_INT8_OUT \
+ GEMM_INT8_TRANS_INT32_TO_FP32_N_Direction \
GEMM_INT8_TRANS_INT32_TO_FP32 \
GEMM_INT8_RELU \
GEMM_INT8_RELU6 \
@@ -4257,7 +4363,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
const float32_t* alpha,
int is_relu,
int k,
- int rem) {
+ int rem,
+ int bias_direction) {
float new_ptr[16] = {alpha[0],
alpha[1],
alpha[2],
@@ -4287,7 +4394,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
[bias] "r"(bias),
[alpha] "r"(new_ptr),
[rem] "r"(rem),
[scale] "r"(scale)
[scale] "r"(scale),
[bias_direction] "r"(bias_direction)
: "q0",
"q1",
"q2",
@@ -4320,7 +4428,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
const float32_t* alpha,
int is_relu,
int k,
- int rem) {
+ int rem,
+ int bias_direction) {
float new_ptr[16] = {alpha[0],
alpha[1],
alpha[2],
@@ -4350,7 +4459,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
[alpha] "r"(new_ptr),
[bias] "r"(bias),
[rem] "r"(rem),
[scale] "r"(scale)
[scale] "r"(scale),
[bias_direction] "r"(bias_direction)
: "q0",
"q1",
"q2",
@@ -4384,7 +4494,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
const float32_t* alpha,
int is_relu,
int k,
- int rem) {
+ int rem,
+ int bias_direction) {
float new_ptr[16] = {alpha[0],
alpha[1],
alpha[2],
@@ -4511,24 +4622,27 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
Dtype* tmp1 = nullptr;
Dtype* tmp2 = nullptr;
Dtype* tmp3 = nullptr;
- float32_t scale_local[4] = {0, 0, 0, 0};
- float32_t bias_local[4] = {0, 0, 0, 0};
- if (is_bias) {
- if (y + 4 <= M) {
- bias_local[0] = bias[y];
- bias_local[1] = bias[y + 1];
- bias_local[2] = bias[y + 2];
- bias_local[3] = bias[y + 3];
- } else {
- switch (M - y) {
- case 3:
- bias_local[2] = bias[y + 2];
- case 2:
- bias_local[1] = bias[y + 1];
- case 1:
- bias_local[0] = bias[y + 0];
- default:
- break;
- }
- }
- }
+ float32_t scale_local[16] = {
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+ float32_t bias_local[16] = {0};
+ if (bias_direction != GemmNBias) {
+ if (is_bias) {
+ if (y + 4 <= M) {
+ bias_local[0] = bias[y];
+ bias_local[1] = bias[y + 1];
+ bias_local[2] = bias[y + 2];
+ bias_local[3] = bias[y + 3];
+ } else {
+ switch (M - y) {
+ case 3:
+ bias_local[2] = bias[y + 2];
+ case 2:
+ bias_local[1] = bias[y + 1];
+ case 1:
+ bias_local[0] = bias[y + 0];
+ default:
+ break;
+ }
+ }
+ }
+ }
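For the per-row (non-GemmNBias) path above, the fall-through switch intentionally fills one bias lane per remaining row when fewer than four rows are left at the bottom of the panel. An equivalent loop form, as a sketch:

// Equivalent tail fill (sketch): one per-row bias per remaining row.
static void fill_bias_tail(float* bias_local, const float* bias, int y, int M) {
  for (int i = 0; i < M - y; ++i) {  // M - y is 1, 2, or 3 here
    bias_local[i] = bias[y + i];
  }
}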
@@ -4566,6 +4680,18 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
const int8_t* a_ptr_l = A_packed + y * KUP;
const int8_t* b_ptr = b_pannel;
for (int xb = 0; xb < bblocks; xb++) {
+ if (bias_direction == GemmNBias) {
+ if (scale) {
+ for (int j = 0; j < NBLOCK_INT8_OTH; j++) {
+ scale_local[j] = scale[xb * NBLOCK_INT8_OTH + j + x0];
+ }
+ }
+ if (bias) {
+ for (int j = 0; j < NBLOCK_INT8_OTH; j++) {
+ bias_local[j] = bias[xb * NBLOCK_INT8_OTH + j + x0];
+ }
+ }
+ }
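In the GemmNBias case the per-row setup earlier is skipped, and scale/bias are instead re-gathered for every N block: block xb of the panel starting at column x0 covers the next NBLOCK_INT8_OTH columns. A standalone sketch of that gather (gather_nbias is a hypothetical helper, and the block width of 16 is an assumption; the real NBLOCK_INT8_OTH is defined elsewhere in this file):

// Per-block gather of N-direction scale/bias (illustrative standalone form).
void gather_nbias(const float* scale, const float* bias, int xb, int x0,
                  float* scale_local, float* bias_local) {
  const int kNBlock = 16;  // assumption for NBLOCK_INT8_OTH
  const int col0 = x0 + xb * kNBlock;
  for (int j = 0; j < kNBlock; ++j) {
    scale_local[j] = scale[col0 + j];  // per-column scale for this block
    bias_local[j] = bias[col0 + j];    // per-column bias for this block
  }
}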
if (flag_rem && (xb == bblocks - 1)) {
tmp0 = c_ptr0;
tmp1 = c_ptr1;
Expand All @@ -4587,7 +4713,8 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
alpha,
flag_act,
k,
- k_rem);
+ k_rem,
+ bias_direction);
if (flag_rem && (xb == bblocks - 1)) {
for (int i = 0; i < n_rem; ++i) {
*(tmp0++) = out0[i];
@@ -7994,9 +8121,8 @@ GEMM_PREPACK_INT8(float_t);
GEMM_PREPACK_INT8(int32_t);

#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
- #define IN_PARAMS_NO_BIAS_DIRECTION \
- A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, \
- scale, alpha, ctx
+ #define IN_PARAMS_NO_BIAS_DIRECTION \
+ A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx
template <typename dtype>
void gemm_prepack_int8_nopack(const int8_t* A_packed,
const int8_t* B,