From 9298d05e86baf59a0811d9683aff6e44749e9f5e Mon Sep 17 00:00:00 2001 From: nihui Date: Sun, 10 Apr 2022 17:53:34 +0800 Subject: [PATCH] split convolution winograd transform input output (#3688) --- src/layer/arm/convolution_3x3_pack4.h | 590 +---------------- src/layer/arm/convolution_3x3_pack4_bf16s.h | 595 +----------------- src/layer/arm/convolution_3x3_pack4_fp16s.h | 377 +---------- src/layer/arm/convolution_3x3_pack4to1.h | 310 +-------- .../arm/convolution_3x3_pack4to1_bf16s.h | 315 +--------- src/layer/arm/convolution_3x3_pack8_fp16s.h | 595 +----------------- .../arm/convolution_3x3_pack8to1_fp16s.h | 315 +--------- .../arm/convolution_3x3_pack8to4_fp16s.h | 377 +---------- src/layer/arm/convolution_arm.cpp | 7 + .../arm/convolution_winograd_transform.h | 125 ++++ .../convolution_winograd_transform_bf16s.h | 125 ++++ .../convolution_winograd_transform_fp16s.h | 125 ++++ .../convolution_winograd_transform_pack4.h | 535 ++++++++++++++++ ...nvolution_winograd_transform_pack4_bf16s.h | 535 ++++++++++++++++ ...nvolution_winograd_transform_pack4_fp16s.h | 313 +++++++++ ...nvolution_winograd_transform_pack8_fp16s.h | 535 ++++++++++++++++ src/layer/mips/convolution_3x3_pack4.h | 544 +--------------- src/layer/mips/convolution_mips.cpp | 1 + .../convolution_winograd_transform_pack4.h | 560 +++++++++++++++++ src/layer/riscv/convolution_3x3_packn.h | 523 +-------------- src/layer/riscv/convolution_3x3_packn_fp16s.h | 523 +-------------- src/layer/riscv/convolution_riscv.cpp | 2 + .../convolution_winograd_transform_packn.h | 551 ++++++++++++++++ ...nvolution_winograd_transform_packn_fp16s.h | 551 ++++++++++++++++ src/layer/x86/convolution_3x3_pack16.h | 535 +--------------- src/layer/x86/convolution_3x3_pack4.h | 564 +---------------- src/layer/x86/convolution_3x3_pack4to1.h | 274 +------- src/layer/x86/convolution_3x3_pack8.h | 535 +--------------- src/layer/x86/convolution_3x3_pack8to1.h | 272 +------- .../x86/convolution_winograd_transform.h | 125 ++++ .../convolution_winograd_transform_pack16.h | 555 ++++++++++++++++ .../convolution_winograd_transform_pack4.h | 580 +++++++++++++++++ .../convolution_winograd_transform_pack8.h | 555 ++++++++++++++++ src/layer/x86/convolution_x86.cpp | 4 + 34 files changed, 5955 insertions(+), 7073 deletions(-) create mode 100644 src/layer/arm/convolution_winograd_transform.h create mode 100644 src/layer/arm/convolution_winograd_transform_bf16s.h create mode 100644 src/layer/arm/convolution_winograd_transform_fp16s.h create mode 100644 src/layer/arm/convolution_winograd_transform_pack4.h create mode 100644 src/layer/arm/convolution_winograd_transform_pack4_bf16s.h create mode 100644 src/layer/arm/convolution_winograd_transform_pack4_fp16s.h create mode 100644 src/layer/arm/convolution_winograd_transform_pack8_fp16s.h create mode 100644 src/layer/mips/convolution_winograd_transform_pack4.h create mode 100644 src/layer/riscv/convolution_winograd_transform_packn.h create mode 100644 src/layer/riscv/convolution_winograd_transform_packn_fp16s.h create mode 100644 src/layer/x86/convolution_winograd_transform.h create mode 100644 src/layer/x86/convolution_winograd_transform_pack16.h create mode 100644 src/layer/x86/convolution_winograd_transform_pack4.h create mode 100644 src/layer/x86/convolution_winograd_transform_pack8.h diff --git a/src/layer/arm/convolution_3x3_pack4.h b/src/layer/arm/convolution_3x3_pack4.h index a14269803af..779d034cfc5 100644 --- a/src/layer/arm/convolution_3x3_pack4.h +++ b/src/layer/arm/convolution_3x3_pack4.h @@ -244,7 +244,7 @@ static 
void conv3x3s1_winograd64_transform_kernel_pack4_neon(const Mat& kernel, } } -static void conv3x3s1_winograd64_pack4_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd64_pack4_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -266,209 +266,15 @@ static void conv3x3s1_winograd64_pack4_neon(const Mat& bottom_blob, Mat& top_blo h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const float* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - - const int tiles = w_tm / 8 * h_tm / 8; + int w_tiles = outw / 6; + int h_tiles = outh / 6; + int tiles = w_tiles * h_tiles; bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); - - // const float itm[8][8] = { - // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, - // - // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, - // - // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, - // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, - // - // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, - // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, - // - // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} - // }; - - // 0 = r00 - r06 + (r04 - r02) * 5.25 - // 7 = r07 - r01 + (r03 - r05) * 5.25 - - // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) - // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) - - // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) - // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) - - // reuse r04 * 1.25 - // reuse r03 * 2.5 - // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) - // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - float tmp[8][8][4]; - - // tile - for (int i = 0; i < h_tm / 8; i++) - { - for (int j = 0; j < w_tm / 8; j++) - { - const float* r0 = img0.row(i * 6) + (j * 6) * 4; - - for (int m = 0; m < 8; m++) - { - float32x4_t _r00 = vld1q_f32(r0); - float32x4_t _r01 = vld1q_f32(r0 + 4); - float32x4_t _r02 = vld1q_f32(r0 + 8); - float32x4_t _r03 = vld1q_f32(r0 + 12); - float32x4_t _r04 = vld1q_f32(r0 + 16); - float32x4_t _r05 = vld1q_f32(r0 + 20); - float32x4_t _r06 = vld1q_f32(r0 + 24); - float32x4_t _r07 = vld1q_f32(r0 + 28); - - float32x4_t _tmp0m = vmlaq_n_f32(vsubq_f32(_r00, _r06), vsubq_f32(_r04, _r02), 5.25f); - float32x4_t _tmp7m = vmlaq_n_f32(vsubq_f32(_r07, _r01), vsubq_f32(_r03, _r05), 5.25f); - vst1q_f32(tmp[0][m], _tmp0m); - vst1q_f32(tmp[7][m], _tmp7m); - - // tmp[0][m] = r0[0] - r0[6] + (r0[4] - r0[2]) * 5.25; - // tmp[7][m] = r0[7] - r0[1] + (r0[3] - r0[5]) * 5.25; - - float32x4_t _tmp12a = vmlsq_n_f32(vaddq_f32(_r02, _r06), _r04, 4.25f); - float32x4_t _tmp12b = vmlsq_n_f32(vaddq_f32(_r01, _r05), _r03, 4.25f); - - // float tmp12a = (r0[2] + r0[6] - r0[4] * 4.25); - // float tmp12b = (r0[1] + r0[5] - r0[3] * 4.25); - - float32x4_t _tmp1m = vaddq_f32(_tmp12a, _tmp12b); - float32x4_t _tmp2m = vsubq_f32(_tmp12a, _tmp12b); - 
vst1q_f32(tmp[1][m], _tmp1m); - vst1q_f32(tmp[2][m], _tmp2m); - - // tmp[1][m] = tmp12a + tmp12b; - // tmp[2][m] = tmp12a - tmp12b; - - float32x4_t _tmp34a = vmlsq_n_f32(vmlaq_n_f32(_r06, _r02, 0.25f), _r04, 1.25f); - float32x4_t _tmp34b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_r01, 0.5f), _r03, 2.5f), _r05, 2.f); - - // float tmp34a = (r0[6] + r0[2] * 0.25 - r0[4] * 1.25); - // float tmp34b = (r0[1] * 0.5 - r0[3] * 2.5 + r0[5] * 2); - - float32x4_t _tmp3m = vaddq_f32(_tmp34a, _tmp34b); - float32x4_t _tmp4m = vsubq_f32(_tmp34a, _tmp34b); - vst1q_f32(tmp[3][m], _tmp3m); - vst1q_f32(tmp[4][m], _tmp4m); - - // tmp[3][m] = tmp34a + tmp34b; - // tmp[4][m] = tmp34a - tmp34b; - - float32x4_t _tmp56a = vmlaq_n_f32(_r06, vmlsq_n_f32(_r02, _r04, 1.25f), 4.f); - float32x4_t _tmp56b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_r01, 2.f), _r03, 2.5f), _r05, 0.5f); - - // float tmp56a = (r0[6] + (r0[2] - r0[4] * 1.25) * 4); - // float tmp56b = (r0[1] * 2 - r0[3] * 2.5 + r0[5] * 0.5); - - float32x4_t _tmp5m = vaddq_f32(_tmp56a, _tmp56b); - float32x4_t _tmp6m = vsubq_f32(_tmp56a, _tmp56b); - vst1q_f32(tmp[5][m], _tmp5m); - vst1q_f32(tmp[6][m], _tmp6m); - - // tmp[5][m] = tmp56a + tmp56b; - // tmp[6][m] = tmp56a - tmp56b; - - r0 += w * 4; - } - - float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 8 + j) * 4; - float* r0_tm_1 = r0_tm_0 + tiles * 4; - float* r0_tm_2 = r0_tm_0 + tiles * 8; - float* r0_tm_3 = r0_tm_0 + tiles * 12; - float* r0_tm_4 = r0_tm_0 + tiles * 16; - float* r0_tm_5 = r0_tm_0 + tiles * 20; - float* r0_tm_6 = r0_tm_0 + tiles * 24; - float* r0_tm_7 = r0_tm_0 + tiles * 28; - - for (int m = 0; m < 8; m++) - { - float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); - float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); - float32x4_t _tmp02 = vld1q_f32(tmp[m][2]); - float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); - float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); - float32x4_t _tmp05 = vld1q_f32(tmp[m][5]); - float32x4_t _tmp06 = vld1q_f32(tmp[m][6]); - float32x4_t _tmp07 = vld1q_f32(tmp[m][7]); - - float32x4_t _r0tm0 = vmlaq_n_f32(vsubq_f32(_tmp00, _tmp06), vsubq_f32(_tmp04, _tmp02), 5.25f); - float32x4_t _r0tm7 = vmlaq_n_f32(vsubq_f32(_tmp07, _tmp01), vsubq_f32(_tmp03, _tmp05), 5.25f); - - // r0_tm[0] = tmp0[0] - tmp0[6] + (tmp0[4] - tmp0[2]) * 5.25; - // r0_tm[7] = tmp0[7] - tmp0[1] + (tmp0[3] - tmp0[5]) * 5.25; - - float32x4_t _tmp12a = vmlsq_n_f32(vaddq_f32(_tmp02, _tmp06), _tmp04, 4.25f); - float32x4_t _tmp12b = vmlsq_n_f32(vaddq_f32(_tmp01, _tmp05), _tmp03, 4.25f); - - // float tmp12a = (tmp0[2] + tmp0[6] - tmp0[4] * 4.25); - // float tmp12b = (tmp0[1] + tmp0[5] - tmp0[3] * 4.25); - - float32x4_t _r0tm1 = vaddq_f32(_tmp12a, _tmp12b); - float32x4_t _r0tm2 = vsubq_f32(_tmp12a, _tmp12b); - - // r0_tm[1] = tmp12a + tmp12b; - // r0_tm[2] = tmp12a - tmp12b; - - float32x4_t _tmp34a = vmlsq_n_f32(vmlaq_n_f32(_tmp06, _tmp02, 0.25f), _tmp04, 1.25f); - float32x4_t _tmp34b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_tmp01, 0.5f), _tmp03, 2.5f), _tmp05, 2.f); - - // float tmp34a = (tmp0[6] + tmp0[2] * 0.25 - tmp0[4] * 1.25); - // float tmp34b = (tmp0[1] * 0.5 - tmp0[3] * 2.5 + tmp0[5] * 2); - - float32x4_t _r0tm3 = vaddq_f32(_tmp34a, _tmp34b); - float32x4_t _r0tm4 = vsubq_f32(_tmp34a, _tmp34b); - - // r0_tm[3] = tmp34a + tmp34b; - // r0_tm[4] = tmp34a - tmp34b; - - float32x4_t _tmp56a = vmlaq_n_f32(_tmp06, vmlsq_n_f32(_tmp02, _tmp04, 1.25f), 4.f); - float32x4_t _tmp56b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_tmp01, 2.f), _tmp03, 2.5f), _tmp05, 0.5f); - - // float tmp56a = (tmp0[6] + (tmp0[2] - tmp0[4] * 1.25) * 4); - // float tmp56b = (tmp0[1] * 2 - 
tmp0[3] * 2.5 + tmp0[5] * 0.5); - - float32x4_t _r0tm5 = vaddq_f32(_tmp56a, _tmp56b); - float32x4_t _r0tm6 = vsubq_f32(_tmp56a, _tmp56b); - - // r0_tm[5] = tmp56a + tmp56b; - // r0_tm[6] = tmp56a - tmp56b; - - vst1q_f32(r0_tm_0, _r0tm0); - vst1q_f32(r0_tm_1, _r0tm1); - vst1q_f32(r0_tm_2, _r0tm2); - vst1q_f32(r0_tm_3, _r0tm3); - vst1q_f32(r0_tm_4, _r0tm4); - vst1q_f32(r0_tm_5, _r0tm5); - vst1q_f32(r0_tm_6, _r0tm6); - vst1q_f32(r0_tm_7, _r0tm7); - - r0_tm_0 += tiles * 32; - r0_tm_1 += tiles * 32; - r0_tm_2 += tiles * 32; - r0_tm_3 += tiles * 32; - r0_tm_4 += tiles * 32; - r0_tm_5 += tiles * 32; - r0_tm_6 += tiles * 32; - r0_tm_7 += tiles * 32; - } - } - } - } + conv3x3s1_winograd64_transform_input_pack4_neon(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -1874,173 +1680,7 @@ static void conv3x3s1_winograd64_pack4_neon(const Mat& bottom_blob, Mat& top_blo top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator); } { - // const float otm[6][8] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} - // }; - - // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 - // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 - // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 - // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 - // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 - // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) - - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - const int tiles = w_tm / 8 * h_tm / 8; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - float32x4_t _bias0 = bias ? 
vld1q_f32((const float*)bias + p * 4) : vdupq_n_f32(0.f); - - float tmp[6][8][4]; - - // tile - for (int i = 0; i < outh / 6; i++) - { - for (int j = 0; j < outw / 6; j++) - { - // top_blob_tm.create(tiles, 64, outch, elemsize, elempack); - - const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 8 + j) * 4; - const float* output0_tm_1 = output0_tm_0 + tiles * 4; - const float* output0_tm_2 = output0_tm_0 + tiles * 8; - const float* output0_tm_3 = output0_tm_0 + tiles * 12; - const float* output0_tm_4 = output0_tm_0 + tiles * 16; - const float* output0_tm_5 = output0_tm_0 + tiles * 20; - const float* output0_tm_6 = output0_tm_0 + tiles * 24; - const float* output0_tm_7 = output0_tm_0 + tiles * 28; - - float* output0 = out0.row(i * 6) + (j * 6) * 4; - - // TODO neon optimize - for (int m = 0; m < 8; m++) - { - float32x4_t _out0tm0 = vld1q_f32(output0_tm_0); - float32x4_t _out0tm1 = vld1q_f32(output0_tm_1); - float32x4_t _out0tm2 = vld1q_f32(output0_tm_2); - float32x4_t _out0tm3 = vld1q_f32(output0_tm_3); - float32x4_t _out0tm4 = vld1q_f32(output0_tm_4); - float32x4_t _out0tm5 = vld1q_f32(output0_tm_5); - float32x4_t _out0tm6 = vld1q_f32(output0_tm_6); - float32x4_t _out0tm7 = vld1q_f32(output0_tm_7); - - float32x4_t _tmp024a = vaddq_f32(_out0tm1, _out0tm2); - float32x4_t _tmp135a = vsubq_f32(_out0tm1, _out0tm2); - - // float tmp024a = output0_tm[1] + output0_tm[2]; - // float tmp135a = output0_tm[1] - output0_tm[2]; - - float32x4_t _tmp024b = vaddq_f32(_out0tm3, _out0tm4); - float32x4_t _tmp135b = vsubq_f32(_out0tm3, _out0tm4); - - // float tmp024b = output0_tm[3] + output0_tm[4]; - // float tmp135b = output0_tm[3] - output0_tm[4]; - - float32x4_t _tmp024c = vaddq_f32(_out0tm5, _out0tm6); - float32x4_t _tmp135c = vsubq_f32(_out0tm5, _out0tm6); - - // float tmp024c = output0_tm[5] + output0_tm[6]; - // float tmp135c = output0_tm[5] - output0_tm[6]; - - float32x4_t _tmp0m = vaddq_f32(vaddq_f32(_out0tm0, _tmp024a), vmlaq_n_f32(_tmp024b, _tmp024c, 32.f)); - float32x4_t _tmp2m = vmlaq_n_f32(vmlaq_n_f32(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f); - float32x4_t _tmp4m = vmlaq_n_f32(vmlaq_n_f32(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f); - vst1q_f32(tmp[0][m], _tmp0m); - vst1q_f32(tmp[2][m], _tmp2m); - vst1q_f32(tmp[4][m], _tmp4m); - - // tmp[0][m] = output0_tm[0] + tmp024a + tmp024b + tmp024c * 32; - // tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8; - // tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - float32x4_t _tmp1m = vmlaq_n_f32(vmlaq_n_f32(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f); - float32x4_t _tmp3m = vmlaq_n_f32(vmlaq_n_f32(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f); - float32x4_t _tmp5m = vaddq_f32(vaddq_f32(_out0tm7, _tmp135a), vmlaq_n_f32(_tmp135c, _tmp135b, 32.f)); - vst1q_f32(tmp[1][m], _tmp1m); - vst1q_f32(tmp[3][m], _tmp3m); - vst1q_f32(tmp[5][m], _tmp5m); - - // tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16; - // tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4; - // tmp[5][m] = output0_tm[7] + tmp135a + tmp135b * 32 + tmp135c; - - output0_tm_0 += tiles * 32; - output0_tm_1 += tiles * 32; - output0_tm_2 += tiles * 32; - output0_tm_3 += tiles * 32; - output0_tm_4 += tiles * 32; - output0_tm_5 += tiles * 32; - output0_tm_6 += tiles * 32; - output0_tm_7 += tiles * 32; - } - - for (int m = 0; m < 6; m++) - { - float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); - float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); - float32x4_t _tmp02 = vld1q_f32(tmp[m][2]); - float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); - float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); - float32x4_t _tmp05 = 
vld1q_f32(tmp[m][5]); - float32x4_t _tmp06 = vld1q_f32(tmp[m][6]); - float32x4_t _tmp07 = vld1q_f32(tmp[m][7]); - - float32x4_t _tmp024a = vaddq_f32(_tmp01, _tmp02); - float32x4_t _tmp135a = vsubq_f32(_tmp01, _tmp02); - - // float tmp024a = tmp0[1] + tmp0[2]; - // float tmp135a = tmp0[1] - tmp0[2]; - - float32x4_t _tmp024b = vaddq_f32(_tmp03, _tmp04); - float32x4_t _tmp135b = vsubq_f32(_tmp03, _tmp04); - - // float tmp024b = tmp0[3] + tmp0[4]; - // float tmp135b = tmp0[3] - tmp0[4]; - - float32x4_t _tmp024c = vaddq_f32(_tmp05, _tmp06); - float32x4_t _tmp135c = vsubq_f32(_tmp05, _tmp06); - - // float tmp024c = tmp0[5] + tmp0[6]; - // float tmp135c = tmp0[5] - tmp0[6]; - - float32x4_t _out00 = vaddq_f32(_bias0, vaddq_f32(vaddq_f32(_tmp00, _tmp024a), vmlaq_n_f32(_tmp024b, _tmp024c, 32.f))); - float32x4_t _out02 = vaddq_f32(_bias0, vmlaq_n_f32(vmlaq_n_f32(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f)); - float32x4_t _out04 = vaddq_f32(_bias0, vmlaq_n_f32(vmlaq_n_f32(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f)); - vst1q_f32(output0, _out00); - vst1q_f32(output0 + 8, _out02); - vst1q_f32(output0 + 16, _out04); - - // output0[0] = bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32; - // output0[2] = bias0 + tmp024a + tmp024b * 4 + tmp024c * 8; - // output0[4] = bias0 + tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - float32x4_t _out01 = vaddq_f32(_bias0, vmlaq_n_f32(vmlaq_n_f32(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f)); - float32x4_t _out03 = vaddq_f32(_bias0, vmlaq_n_f32(vmlaq_n_f32(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f)); - float32x4_t _out05 = vaddq_f32(_bias0, vaddq_f32(vaddq_f32(_tmp07, _tmp135a), vmlaq_n_f32(_tmp135c, _tmp135b, 32.f))); - vst1q_f32(output0 + 4, _out01); - vst1q_f32(output0 + 12, _out03); - vst1q_f32(output0 + 20, _out05); - - // output0[1] = bias0 + tmp135a + tmp135b + tmp135b + tmp135c * 16; - // output0[3] = bias0 + tmp135a + tmp135b * 8 + tmp135c * 4; - // output0[5] = bias0 + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c; - - output0 += outw * 4; - } - } - } - } + conv3x3s1_winograd64_transform_output_pack4_neon(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output @@ -2277,7 +1917,7 @@ static void conv3x3s1_winograd42_transform_kernel_pack4_neon(const Mat& kernel, } } -static void conv3x3s1_winograd42_pack4_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd42_pack4_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -2299,115 +1939,15 @@ static void conv3x3s1_winograd42_pack4_neon(const Mat& bottom_blob, Mat& top_blo h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const float* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = w_tm / 6 * h_tm / 6; + int w_tiles = outw / 4; + int h_tiles = outh / 4; + int tiles = w_tiles * h_tiles; bottom_blob_tm.create(tiles, 36, inch, elemsize, elempack, opt.workspace_allocator); - - // const float itm[4][4] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r04 + r03 - // 2 = 4 * (r01 - 
r02) + r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - float tmp[6][6][4]; - - // tile - for (int i = 0; i < h_tm / 6; i++) - { - for (int j = 0; j < w_tm / 6; j++) - { - const float* r0 = img0.row(i * 4) + (j * 4) * 4; - - for (int m = 0; m < 6; m++) - { - float32x4_t _r00 = vld1q_f32(r0); - float32x4_t _r01 = vld1q_f32(r0 + 4); - float32x4_t _r02 = vld1q_f32(r0 + 8); - float32x4_t _r03 = vld1q_f32(r0 + 12); - float32x4_t _r04 = vld1q_f32(r0 + 16); - float32x4_t _r05 = vld1q_f32(r0 + 20); - - float32x4_t _tmp0m = vmlsq_n_f32(vmlaq_n_f32(_r04, _r00, 4.f), _r02, 5.f); - float32x4_t _tmp1m = vmlsq_n_f32(vaddq_f32(_r04, _r03), vaddq_f32(_r01, _r02), 4.f); - float32x4_t _tmp2m = vmlaq_n_f32(vsubq_f32(_r04, _r03), vsubq_f32(_r01, _r02), 4.f); - float32x4_t _tmp3m = vmlsq_n_f32(vsubq_f32(_r04, _r02), vsubq_f32(_r01, _r03), 2.f); - float32x4_t _tmp4m = vmlaq_n_f32(vsubq_f32(_r04, _r02), vsubq_f32(_r01, _r03), 2.f); - float32x4_t _tmp5m = vmlsq_n_f32(vmlaq_n_f32(_r05, _r01, 4.f), _r03, 5.f); - - vst1q_f32(tmp[0][m], _tmp0m); - vst1q_f32(tmp[1][m], _tmp1m); - vst1q_f32(tmp[2][m], _tmp2m); - vst1q_f32(tmp[3][m], _tmp3m); - vst1q_f32(tmp[4][m], _tmp4m); - vst1q_f32(tmp[5][m], _tmp5m); - - r0 += w * 4; - } - - float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 6 + j) * 4; - float* r0_tm_1 = r0_tm_0 + tiles * 4; - float* r0_tm_2 = r0_tm_0 + tiles * 8; - float* r0_tm_3 = r0_tm_0 + tiles * 12; - float* r0_tm_4 = r0_tm_0 + tiles * 16; - float* r0_tm_5 = r0_tm_0 + tiles * 20; - - for (int m = 0; m < 6; m++) - { - float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); - float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); - float32x4_t _tmp02 = vld1q_f32(tmp[m][2]); - float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); - float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); - float32x4_t _tmp05 = vld1q_f32(tmp[m][5]); - - float32x4_t _r0tm0 = vmlsq_n_f32(vmlaq_n_f32(_tmp04, _tmp00, 4.f), _tmp02, 5.f); - float32x4_t _r0tm1 = vmlsq_n_f32(vaddq_f32(_tmp04, _tmp03), vaddq_f32(_tmp01, _tmp02), 4.f); - float32x4_t _r0tm2 = vmlaq_n_f32(vsubq_f32(_tmp04, _tmp03), vsubq_f32(_tmp01, _tmp02), 4.f); - float32x4_t _r0tm3 = vmlsq_n_f32(vsubq_f32(_tmp04, _tmp02), vsubq_f32(_tmp01, _tmp03), 2.f); - float32x4_t _r0tm4 = vmlaq_n_f32(vsubq_f32(_tmp04, _tmp02), vsubq_f32(_tmp01, _tmp03), 2.f); - float32x4_t _r0tm5 = vmlsq_n_f32(vmlaq_n_f32(_tmp05, _tmp01, 4.f), _tmp03, 5.f); - - vst1q_f32(r0_tm_0, _r0tm0); - vst1q_f32(r0_tm_1, _r0tm1); - vst1q_f32(r0_tm_2, _r0tm2); - vst1q_f32(r0_tm_3, _r0tm3); - vst1q_f32(r0_tm_4, _r0tm4); - vst1q_f32(r0_tm_5, _r0tm5); - - r0_tm_0 += tiles * 24; - r0_tm_1 += tiles * 24; - r0_tm_2 += tiles * 24; - r0_tm_3 += tiles * 24; - r0_tm_4 += tiles * 24; - r0_tm_5 += tiles * 24; - } - } - } - } + conv3x3s1_winograd42_transform_input_pack4_neon(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -3813,113 +3353,7 @@ static void conv3x3s1_winograd42_pack4_neon(const Mat& bottom_blob, Mat& top_blo top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator); } { - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - 
// 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - const int tiles = w_tm / 6 * h_tm / 6; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - float32x4_t _bias0 = bias ? vld1q_f32((const float*)bias + p * 4) : vdupq_n_f32(0.f); - - float tmp[4][6][4]; - - // tile - for (int i = 0; i < outh / 4; i++) - { - for (int j = 0; j < outw / 4; j++) - { - // top_blob_tm.create(tiles, 36, outch, elemsize, elempack); - - const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 6 + j) * 4; - const float* output0_tm_1 = output0_tm_0 + tiles * 4; - const float* output0_tm_2 = output0_tm_0 + tiles * 8; - const float* output0_tm_3 = output0_tm_0 + tiles * 12; - const float* output0_tm_4 = output0_tm_0 + tiles * 16; - const float* output0_tm_5 = output0_tm_0 + tiles * 20; - - float* output0 = out0.row(i * 4) + (j * 4) * 4; - - // TODO neon optimize - for (int m = 0; m < 6; m++) - { - float32x4_t _out0tm0 = vld1q_f32(output0_tm_0); - float32x4_t _out0tm1 = vld1q_f32(output0_tm_1); - float32x4_t _out0tm2 = vld1q_f32(output0_tm_2); - float32x4_t _out0tm3 = vld1q_f32(output0_tm_3); - float32x4_t _out0tm4 = vld1q_f32(output0_tm_4); - float32x4_t _out0tm5 = vld1q_f32(output0_tm_5); - - float32x4_t _tmp02a = vaddq_f32(_out0tm1, _out0tm2); - float32x4_t _tmp13a = vsubq_f32(_out0tm1, _out0tm2); - - float32x4_t _tmp02b = vaddq_f32(_out0tm3, _out0tm4); - float32x4_t _tmp13b = vsubq_f32(_out0tm3, _out0tm4); - - float32x4_t _tmp0m = vaddq_f32(vaddq_f32(_out0tm0, _tmp02a), _tmp02b); - float32x4_t _tmp1m = vmlaq_n_f32(_tmp13a, _tmp13b, 2.f); - float32x4_t _tmp2m = vmlaq_n_f32(_tmp02a, _tmp02b, 4.f); - float32x4_t _tmp3m = vmlaq_n_f32(vaddq_f32(_out0tm5, _tmp13a), _tmp13b, 8.f); - - vst1q_f32(tmp[0][m], _tmp0m); - vst1q_f32(tmp[1][m], _tmp1m); - vst1q_f32(tmp[2][m], _tmp2m); - vst1q_f32(tmp[3][m], _tmp3m); - - output0_tm_0 += tiles * 24; - output0_tm_1 += tiles * 24; - output0_tm_2 += tiles * 24; - output0_tm_3 += tiles * 24; - output0_tm_4 += tiles * 24; - output0_tm_5 += tiles * 24; - } - - for (int m = 0; m < 4; m++) - { - float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); - float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); - float32x4_t _tmp02 = vld1q_f32(tmp[m][2]); - float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); - float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); - float32x4_t _tmp05 = vld1q_f32(tmp[m][5]); - - float32x4_t _tmp02a = vaddq_f32(_tmp01, _tmp02); - float32x4_t _tmp13a = vsubq_f32(_tmp01, _tmp02); - - float32x4_t _tmp02b = vaddq_f32(_tmp03, _tmp04); - float32x4_t _tmp13b = vsubq_f32(_tmp03, _tmp04); - - float32x4_t _out00 = vaddq_f32(_bias0, vaddq_f32(vaddq_f32(_tmp00, _tmp02a), _tmp02b)); - float32x4_t _out01 = vaddq_f32(_bias0, vmlaq_n_f32(_tmp13a, _tmp13b, 2.f)); - float32x4_t _out02 = vaddq_f32(_bias0, vmlaq_n_f32(_tmp02a, _tmp02b, 4.f)); - float32x4_t _out03 = vaddq_f32(_bias0, vmlaq_n_f32(vaddq_f32(_tmp05, _tmp13a), _tmp13b, 8.f)); - - vst1q_f32(output0, _out00); - vst1q_f32(output0 + 4, _out01); - vst1q_f32(output0 + 8, _out02); - vst1q_f32(output0 + 12, _out03); - - output0 += outw * 4; - } - } - } - } + conv3x3s1_winograd42_transform_output_pack4_neon(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output diff --git a/src/layer/arm/convolution_3x3_pack4_bf16s.h 
b/src/layer/arm/convolution_3x3_pack4_bf16s.h index 3cb25e43c18..4eb0870cc0c 100644 --- a/src/layer/arm/convolution_3x3_pack4_bf16s.h +++ b/src/layer/arm/convolution_3x3_pack4_bf16s.h @@ -12,7 +12,7 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -static void conv3x3s1_winograd64_pack4_bf16s_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd64_pack4_bf16s_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -34,210 +34,15 @@ static void conv3x3s1_winograd64_pack4_bf16s_neon(const Mat& bottom_blob, Mat& t h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const float* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - - const int tiles = w_tm / 8 * h_tm / 8; - - // bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); - bottom_blob_tm.create(tiles, 64, inch, 4u * elempack, elempack, opt.workspace_allocator); - - // const float itm[8][8] = { - // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, - // - // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, - // - // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, - // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, - // - // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, - // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, - // - // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} - // }; - - // 0 = r00 - r06 + (r04 - r02) * 5.25 - // 7 = r07 - r01 + (r03 - r05) * 5.25 - - // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) - // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) - - // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) - // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) - - // reuse r04 * 1.25 - // reuse r03 * 2.5 - // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) - // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - float tmp[8][8][4]; - - // tile - for (int i = 0; i < h_tm / 8; i++) - { - for (int j = 0; j < w_tm / 8; j++) - { - const unsigned short* r0 = img0.row(i * 6) + (j * 6) * 4; - - for (int m = 0; m < 8; m++) - { - float32x4_t _r00 = vcvt_f32_bf16(vld1_u16(r0)); - float32x4_t _r01 = vcvt_f32_bf16(vld1_u16(r0 + 4)); - float32x4_t _r02 = vcvt_f32_bf16(vld1_u16(r0 + 8)); - float32x4_t _r03 = vcvt_f32_bf16(vld1_u16(r0 + 12)); - float32x4_t _r04 = vcvt_f32_bf16(vld1_u16(r0 + 16)); - float32x4_t _r05 = vcvt_f32_bf16(vld1_u16(r0 + 20)); - float32x4_t _r06 = vcvt_f32_bf16(vld1_u16(r0 + 24)); - float32x4_t _r07 = vcvt_f32_bf16(vld1_u16(r0 + 28)); - - float32x4_t _tmp0m = vmlaq_n_f32(vsubq_f32(_r00, _r06), vsubq_f32(_r04, _r02), 5.25f); - float32x4_t _tmp7m = vmlaq_n_f32(vsubq_f32(_r07, _r01), vsubq_f32(_r03, _r05), 5.25f); - vst1q_f32(tmp[0][m], _tmp0m); - vst1q_f32(tmp[7][m], _tmp7m); - - // tmp[0][m] = r0[0] - r0[6] + (r0[4] - r0[2]) * 5.25; - 
// tmp[7][m] = r0[7] - r0[1] + (r0[3] - r0[5]) * 5.25; - - float32x4_t _tmp12a = vmlsq_n_f32(vaddq_f32(_r02, _r06), _r04, 4.25f); - float32x4_t _tmp12b = vmlsq_n_f32(vaddq_f32(_r01, _r05), _r03, 4.25f); - - // float tmp12a = (r0[2] + r0[6] - r0[4] * 4.25); - // float tmp12b = (r0[1] + r0[5] - r0[3] * 4.25); + int w_tiles = outw / 6; + int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; - float32x4_t _tmp1m = vaddq_f32(_tmp12a, _tmp12b); - float32x4_t _tmp2m = vsubq_f32(_tmp12a, _tmp12b); - vst1q_f32(tmp[1][m], _tmp1m); - vst1q_f32(tmp[2][m], _tmp2m); - - // tmp[1][m] = tmp12a + tmp12b; - // tmp[2][m] = tmp12a - tmp12b; - - float32x4_t _tmp34a = vmlsq_n_f32(vmlaq_n_f32(_r06, _r02, 0.25f), _r04, 1.25f); - float32x4_t _tmp34b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_r01, 0.5f), _r03, 2.5f), _r05, 2.f); - - // float tmp34a = (r0[6] + r0[2] * 0.25 - r0[4] * 1.25); - // float tmp34b = (r0[1] * 0.5 - r0[3] * 2.5 + r0[5] * 2); - - float32x4_t _tmp3m = vaddq_f32(_tmp34a, _tmp34b); - float32x4_t _tmp4m = vsubq_f32(_tmp34a, _tmp34b); - vst1q_f32(tmp[3][m], _tmp3m); - vst1q_f32(tmp[4][m], _tmp4m); - - // tmp[3][m] = tmp34a + tmp34b; - // tmp[4][m] = tmp34a - tmp34b; - - float32x4_t _tmp56a = vmlaq_n_f32(_r06, vmlsq_n_f32(_r02, _r04, 1.25f), 4.f); - float32x4_t _tmp56b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_r01, 2.f), _r03, 2.5f), _r05, 0.5f); - - // float tmp56a = (r0[6] + (r0[2] - r0[4] * 1.25) * 4); - // float tmp56b = (r0[1] * 2 - r0[3] * 2.5 + r0[5] * 0.5); - - float32x4_t _tmp5m = vaddq_f32(_tmp56a, _tmp56b); - float32x4_t _tmp6m = vsubq_f32(_tmp56a, _tmp56b); - vst1q_f32(tmp[5][m], _tmp5m); - vst1q_f32(tmp[6][m], _tmp6m); - - // tmp[5][m] = tmp56a + tmp56b; - // tmp[6][m] = tmp56a - tmp56b; - - r0 += w * 4; - } - - float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 8 + j) * 4; - float* r0_tm_1 = r0_tm_0 + tiles * 4; - float* r0_tm_2 = r0_tm_0 + tiles * 8; - float* r0_tm_3 = r0_tm_0 + tiles * 12; - float* r0_tm_4 = r0_tm_0 + tiles * 16; - float* r0_tm_5 = r0_tm_0 + tiles * 20; - float* r0_tm_6 = r0_tm_0 + tiles * 24; - float* r0_tm_7 = r0_tm_0 + tiles * 28; - - for (int m = 0; m < 8; m++) - { - float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); - float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); - float32x4_t _tmp02 = vld1q_f32(tmp[m][2]); - float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); - float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); - float32x4_t _tmp05 = vld1q_f32(tmp[m][5]); - float32x4_t _tmp06 = vld1q_f32(tmp[m][6]); - float32x4_t _tmp07 = vld1q_f32(tmp[m][7]); - - float32x4_t _r0tm0 = vmlaq_n_f32(vsubq_f32(_tmp00, _tmp06), vsubq_f32(_tmp04, _tmp02), 5.25f); - float32x4_t _r0tm7 = vmlaq_n_f32(vsubq_f32(_tmp07, _tmp01), vsubq_f32(_tmp03, _tmp05), 5.25f); - - // r0_tm[0] = tmp0[0] - tmp0[6] + (tmp0[4] - tmp0[2]) * 5.25; - // r0_tm[7] = tmp0[7] - tmp0[1] + (tmp0[3] - tmp0[5]) * 5.25; - - float32x4_t _tmp12a = vmlsq_n_f32(vaddq_f32(_tmp02, _tmp06), _tmp04, 4.25f); - float32x4_t _tmp12b = vmlsq_n_f32(vaddq_f32(_tmp01, _tmp05), _tmp03, 4.25f); - - // float tmp12a = (tmp0[2] + tmp0[6] - tmp0[4] * 4.25); - // float tmp12b = (tmp0[1] + tmp0[5] - tmp0[3] * 4.25); - - float32x4_t _r0tm1 = vaddq_f32(_tmp12a, _tmp12b); - float32x4_t _r0tm2 = vsubq_f32(_tmp12a, _tmp12b); - - // r0_tm[1] = tmp12a + tmp12b; - // r0_tm[2] = tmp12a - tmp12b; - - float32x4_t _tmp34a = vmlsq_n_f32(vmlaq_n_f32(_tmp06, _tmp02, 0.25f), _tmp04, 1.25f); - float32x4_t _tmp34b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_tmp01, 0.5f), _tmp03, 2.5f), _tmp05, 2.f); - - // float tmp34a = (tmp0[6] + tmp0[2] * 0.25 - tmp0[4] * 1.25); - // float tmp34b = 
(tmp0[1] * 0.5 - tmp0[3] * 2.5 + tmp0[5] * 2); - - float32x4_t _r0tm3 = vaddq_f32(_tmp34a, _tmp34b); - float32x4_t _r0tm4 = vsubq_f32(_tmp34a, _tmp34b); - - // r0_tm[3] = tmp34a + tmp34b; - // r0_tm[4] = tmp34a - tmp34b; - - float32x4_t _tmp56a = vmlaq_n_f32(_tmp06, vmlsq_n_f32(_tmp02, _tmp04, 1.25f), 4.f); - float32x4_t _tmp56b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_tmp01, 2.f), _tmp03, 2.5f), _tmp05, 0.5f); - - // float tmp56a = (tmp0[6] + (tmp0[2] - tmp0[4] * 1.25) * 4); - // float tmp56b = (tmp0[1] * 2 - tmp0[3] * 2.5 + tmp0[5] * 0.5); - - float32x4_t _r0tm5 = vaddq_f32(_tmp56a, _tmp56b); - float32x4_t _r0tm6 = vsubq_f32(_tmp56a, _tmp56b); - - // r0_tm[5] = tmp56a + tmp56b; - // r0_tm[6] = tmp56a - tmp56b; - - vst1q_f32(r0_tm_0, _r0tm0); - vst1q_f32(r0_tm_1, _r0tm1); - vst1q_f32(r0_tm_2, _r0tm2); - vst1q_f32(r0_tm_3, _r0tm3); - vst1q_f32(r0_tm_4, _r0tm4); - vst1q_f32(r0_tm_5, _r0tm5); - vst1q_f32(r0_tm_6, _r0tm6); - vst1q_f32(r0_tm_7, _r0tm7); - - r0_tm_0 += tiles * 32; - r0_tm_1 += tiles * 32; - r0_tm_2 += tiles * 32; - r0_tm_3 += tiles * 32; - r0_tm_4 += tiles * 32; - r0_tm_5 += tiles * 32; - r0_tm_6 += tiles * 32; - r0_tm_7 += tiles * 32; - } - } - } - } + bottom_blob_tm.create(tiles, 64, inch, 16u, elempack, opt.workspace_allocator); + conv3x3s1_winograd64_transform_input_pack4_bf16s_neon(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -1643,173 +1448,7 @@ static void conv3x3s1_winograd64_pack4_bf16s_neon(const Mat& bottom_blob, Mat& t top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator); } { - // const float otm[6][8] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} - // }; - - // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 - // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 - // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 - // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 - // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 - // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) - - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - const int tiles = w_tm / 8 * h_tm / 8; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - float32x4_t _bias0 = bias ? 
vld1q_f32((const float*)bias + p * 4) : vdupq_n_f32(0.f); - - float tmp[6][8][4]; - - // tile - for (int i = 0; i < outh / 6; i++) - { - for (int j = 0; j < outw / 6; j++) - { - // top_blob_tm.create(tiles, 64, outch, elemsize, elempack); - - const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 8 + j) * 4; - const float* output0_tm_1 = output0_tm_0 + tiles * 4; - const float* output0_tm_2 = output0_tm_0 + tiles * 8; - const float* output0_tm_3 = output0_tm_0 + tiles * 12; - const float* output0_tm_4 = output0_tm_0 + tiles * 16; - const float* output0_tm_5 = output0_tm_0 + tiles * 20; - const float* output0_tm_6 = output0_tm_0 + tiles * 24; - const float* output0_tm_7 = output0_tm_0 + tiles * 28; - - unsigned short* output0 = out0.row(i * 6) + (j * 6) * 4; - - // TODO neon optimize - for (int m = 0; m < 8; m++) - { - float32x4_t _out0tm0 = vld1q_f32(output0_tm_0); - float32x4_t _out0tm1 = vld1q_f32(output0_tm_1); - float32x4_t _out0tm2 = vld1q_f32(output0_tm_2); - float32x4_t _out0tm3 = vld1q_f32(output0_tm_3); - float32x4_t _out0tm4 = vld1q_f32(output0_tm_4); - float32x4_t _out0tm5 = vld1q_f32(output0_tm_5); - float32x4_t _out0tm6 = vld1q_f32(output0_tm_6); - float32x4_t _out0tm7 = vld1q_f32(output0_tm_7); - - float32x4_t _tmp024a = vaddq_f32(_out0tm1, _out0tm2); - float32x4_t _tmp135a = vsubq_f32(_out0tm1, _out0tm2); - - // float tmp024a = output0_tm[1] + output0_tm[2]; - // float tmp135a = output0_tm[1] - output0_tm[2]; - - float32x4_t _tmp024b = vaddq_f32(_out0tm3, _out0tm4); - float32x4_t _tmp135b = vsubq_f32(_out0tm3, _out0tm4); - - // float tmp024b = output0_tm[3] + output0_tm[4]; - // float tmp135b = output0_tm[3] - output0_tm[4]; - - float32x4_t _tmp024c = vaddq_f32(_out0tm5, _out0tm6); - float32x4_t _tmp135c = vsubq_f32(_out0tm5, _out0tm6); - - // float tmp024c = output0_tm[5] + output0_tm[6]; - // float tmp135c = output0_tm[5] - output0_tm[6]; - - float32x4_t _tmp0m = vaddq_f32(vaddq_f32(_out0tm0, _tmp024a), vmlaq_n_f32(_tmp024b, _tmp024c, 32.f)); - float32x4_t _tmp2m = vmlaq_n_f32(vmlaq_n_f32(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f); - float32x4_t _tmp4m = vmlaq_n_f32(vmlaq_n_f32(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f); - vst1q_f32(tmp[0][m], _tmp0m); - vst1q_f32(tmp[2][m], _tmp2m); - vst1q_f32(tmp[4][m], _tmp4m); - - // tmp[0][m] = output0_tm[0] + tmp024a + tmp024b + tmp024c * 32; - // tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8; - // tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - float32x4_t _tmp1m = vmlaq_n_f32(vmlaq_n_f32(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f); - float32x4_t _tmp3m = vmlaq_n_f32(vmlaq_n_f32(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f); - float32x4_t _tmp5m = vaddq_f32(vaddq_f32(_out0tm7, _tmp135a), vmlaq_n_f32(_tmp135c, _tmp135b, 32.f)); - vst1q_f32(tmp[1][m], _tmp1m); - vst1q_f32(tmp[3][m], _tmp3m); - vst1q_f32(tmp[5][m], _tmp5m); - - // tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16; - // tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4; - // tmp[5][m] = output0_tm[7] + tmp135a + tmp135b * 32 + tmp135c; - - output0_tm_0 += tiles * 32; - output0_tm_1 += tiles * 32; - output0_tm_2 += tiles * 32; - output0_tm_3 += tiles * 32; - output0_tm_4 += tiles * 32; - output0_tm_5 += tiles * 32; - output0_tm_6 += tiles * 32; - output0_tm_7 += tiles * 32; - } - - for (int m = 0; m < 6; m++) - { - float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); - float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); - float32x4_t _tmp02 = vld1q_f32(tmp[m][2]); - float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); - float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); - float32x4_t _tmp05 
= vld1q_f32(tmp[m][5]); - float32x4_t _tmp06 = vld1q_f32(tmp[m][6]); - float32x4_t _tmp07 = vld1q_f32(tmp[m][7]); - - float32x4_t _tmp024a = vaddq_f32(_tmp01, _tmp02); - float32x4_t _tmp135a = vsubq_f32(_tmp01, _tmp02); - - // float tmp024a = tmp0[1] + tmp0[2]; - // float tmp135a = tmp0[1] - tmp0[2]; - - float32x4_t _tmp024b = vaddq_f32(_tmp03, _tmp04); - float32x4_t _tmp135b = vsubq_f32(_tmp03, _tmp04); - - // float tmp024b = tmp0[3] + tmp0[4]; - // float tmp135b = tmp0[3] - tmp0[4]; - - float32x4_t _tmp024c = vaddq_f32(_tmp05, _tmp06); - float32x4_t _tmp135c = vsubq_f32(_tmp05, _tmp06); - - // float tmp024c = tmp0[5] + tmp0[6]; - // float tmp135c = tmp0[5] - tmp0[6]; - - float32x4_t _out00 = vaddq_f32(_bias0, vaddq_f32(vaddq_f32(_tmp00, _tmp024a), vmlaq_n_f32(_tmp024b, _tmp024c, 32.f))); - float32x4_t _out02 = vaddq_f32(_bias0, vmlaq_n_f32(vmlaq_n_f32(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f)); - float32x4_t _out04 = vaddq_f32(_bias0, vmlaq_n_f32(vmlaq_n_f32(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f)); - vst1_u16(output0, vcvt_bf16_f32(_out00)); - vst1_u16(output0 + 8, vcvt_bf16_f32(_out02)); - vst1_u16(output0 + 16, vcvt_bf16_f32(_out04)); - - // output0[0] = bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32; - // output0[2] = bias0 + tmp024a + tmp024b * 4 + tmp024c * 8; - // output0[4] = bias0 + tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - float32x4_t _out01 = vaddq_f32(_bias0, vmlaq_n_f32(vmlaq_n_f32(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f)); - float32x4_t _out03 = vaddq_f32(_bias0, vmlaq_n_f32(vmlaq_n_f32(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f)); - float32x4_t _out05 = vaddq_f32(_bias0, vaddq_f32(vaddq_f32(_tmp07, _tmp135a), vmlaq_n_f32(_tmp135c, _tmp135b, 32.f))); - vst1_u16(output0 + 4, vcvt_bf16_f32(_out01)); - vst1_u16(output0 + 12, vcvt_bf16_f32(_out03)); - vst1_u16(output0 + 20, vcvt_bf16_f32(_out05)); - - // output0[1] = bias0 + tmp135a + tmp135b + tmp135b + tmp135c * 16; - // output0[3] = bias0 + tmp135a + tmp135b * 8 + tmp135c * 4; - // output0[5] = bias0 + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c; - - output0 += outw * 4; - } - } - } - } + conv3x3s1_winograd64_transform_output_pack4_bf16s_neon(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output @@ -1817,7 +1456,7 @@ static void conv3x3s1_winograd64_pack4_bf16s_neon(const Mat& bottom_blob, Mat& t copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt); } -static void conv3x3s1_winograd42_pack4_bf16s_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd42_pack4_bf16s_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -1839,115 +1478,15 @@ static void conv3x3s1_winograd42_pack4_bf16s_neon(const Mat& bottom_blob, Mat& t h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const float* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = w_tm / 6 * h_tm / 6; - - bottom_blob_tm.create(tiles, 36, inch, 4u * elempack, elempack, opt.workspace_allocator); + int w_tiles = outw / 4; + int h_tiles = outh / 4; + const int tiles = w_tiles * h_tiles; - // const float itm[4][4] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 
1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r04 + r03 - // 2 = 4 * (r01 - r02) + r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - float tmp[6][6][4]; - - // tile - for (int i = 0; i < h_tm / 6; i++) - { - for (int j = 0; j < w_tm / 6; j++) - { - const unsigned short* r0 = img0.row(i * 4) + (j * 4) * 4; - - for (int m = 0; m < 6; m++) - { - float32x4_t _r00 = vcvt_f32_bf16(vld1_u16(r0)); - float32x4_t _r01 = vcvt_f32_bf16(vld1_u16(r0 + 4)); - float32x4_t _r02 = vcvt_f32_bf16(vld1_u16(r0 + 8)); - float32x4_t _r03 = vcvt_f32_bf16(vld1_u16(r0 + 12)); - float32x4_t _r04 = vcvt_f32_bf16(vld1_u16(r0 + 16)); - float32x4_t _r05 = vcvt_f32_bf16(vld1_u16(r0 + 20)); - - float32x4_t _tmp0m = vmlsq_n_f32(vmlaq_n_f32(_r04, _r00, 4.f), _r02, 5.f); - float32x4_t _tmp1m = vmlsq_n_f32(vaddq_f32(_r04, _r03), vaddq_f32(_r01, _r02), 4.f); - float32x4_t _tmp2m = vmlaq_n_f32(vsubq_f32(_r04, _r03), vsubq_f32(_r01, _r02), 4.f); - float32x4_t _tmp3m = vmlsq_n_f32(vsubq_f32(_r04, _r02), vsubq_f32(_r01, _r03), 2.f); - float32x4_t _tmp4m = vmlaq_n_f32(vsubq_f32(_r04, _r02), vsubq_f32(_r01, _r03), 2.f); - float32x4_t _tmp5m = vmlsq_n_f32(vmlaq_n_f32(_r05, _r01, 4.f), _r03, 5.f); - - vst1q_f32(tmp[0][m], _tmp0m); - vst1q_f32(tmp[1][m], _tmp1m); - vst1q_f32(tmp[2][m], _tmp2m); - vst1q_f32(tmp[3][m], _tmp3m); - vst1q_f32(tmp[4][m], _tmp4m); - vst1q_f32(tmp[5][m], _tmp5m); - - r0 += w * 4; - } - - float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 6 + j) * 4; - float* r0_tm_1 = r0_tm_0 + tiles * 4; - float* r0_tm_2 = r0_tm_0 + tiles * 8; - float* r0_tm_3 = r0_tm_0 + tiles * 12; - float* r0_tm_4 = r0_tm_0 + tiles * 16; - float* r0_tm_5 = r0_tm_0 + tiles * 20; - - for (int m = 0; m < 6; m++) - { - float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); - float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); - float32x4_t _tmp02 = vld1q_f32(tmp[m][2]); - float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); - float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); - float32x4_t _tmp05 = vld1q_f32(tmp[m][5]); - - float32x4_t _r0tm0 = vmlsq_n_f32(vmlaq_n_f32(_tmp04, _tmp00, 4.f), _tmp02, 5.f); - float32x4_t _r0tm1 = vmlsq_n_f32(vaddq_f32(_tmp04, _tmp03), vaddq_f32(_tmp01, _tmp02), 4.f); - float32x4_t _r0tm2 = vmlaq_n_f32(vsubq_f32(_tmp04, _tmp03), vsubq_f32(_tmp01, _tmp02), 4.f); - float32x4_t _r0tm3 = vmlsq_n_f32(vsubq_f32(_tmp04, _tmp02), vsubq_f32(_tmp01, _tmp03), 2.f); - float32x4_t _r0tm4 = vmlaq_n_f32(vsubq_f32(_tmp04, _tmp02), vsubq_f32(_tmp01, _tmp03), 2.f); - float32x4_t _r0tm5 = vmlsq_n_f32(vmlaq_n_f32(_tmp05, _tmp01, 4.f), _tmp03, 5.f); - - vst1q_f32(r0_tm_0, _r0tm0); - vst1q_f32(r0_tm_1, _r0tm1); - vst1q_f32(r0_tm_2, _r0tm2); - vst1q_f32(r0_tm_3, _r0tm3); - vst1q_f32(r0_tm_4, _r0tm4); - vst1q_f32(r0_tm_5, _r0tm5); - - r0_tm_0 += tiles * 24; - r0_tm_1 += tiles * 24; - r0_tm_2 += tiles * 24; - r0_tm_3 += tiles * 24; - r0_tm_4 += tiles * 24; - r0_tm_5 += tiles * 24; - } - } - } - } + bottom_blob_tm.create(tiles, 36, inch, 16u, elempack, opt.workspace_allocator); + conv3x3s1_winograd42_transform_input_pack4_bf16s_neon(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -3353,113 +2892,7 @@ static 
void conv3x3s1_winograd42_pack4_bf16s_neon(const Mat& bottom_blob, Mat& t top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator); } { - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - const int tiles = w_tm / 6 * h_tm / 6; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - float32x4_t _bias0 = bias ? vld1q_f32((const float*)bias + p * 4) : vdupq_n_f32(0.f); - - float tmp[4][6][4]; - - // tile - for (int i = 0; i < outh / 4; i++) - { - for (int j = 0; j < outw / 4; j++) - { - // top_blob_tm.create(tiles, 36, outch, elemsize, elempack); - - const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 6 + j) * 4; - const float* output0_tm_1 = output0_tm_0 + tiles * 4; - const float* output0_tm_2 = output0_tm_0 + tiles * 8; - const float* output0_tm_3 = output0_tm_0 + tiles * 12; - const float* output0_tm_4 = output0_tm_0 + tiles * 16; - const float* output0_tm_5 = output0_tm_0 + tiles * 20; - - unsigned short* output0 = out0.row(i * 4) + (j * 4) * 4; - - // TODO neon optimize - for (int m = 0; m < 6; m++) - { - float32x4_t _out0tm0 = vld1q_f32(output0_tm_0); - float32x4_t _out0tm1 = vld1q_f32(output0_tm_1); - float32x4_t _out0tm2 = vld1q_f32(output0_tm_2); - float32x4_t _out0tm3 = vld1q_f32(output0_tm_3); - float32x4_t _out0tm4 = vld1q_f32(output0_tm_4); - float32x4_t _out0tm5 = vld1q_f32(output0_tm_5); - - float32x4_t _tmp02a = vaddq_f32(_out0tm1, _out0tm2); - float32x4_t _tmp13a = vsubq_f32(_out0tm1, _out0tm2); - - float32x4_t _tmp02b = vaddq_f32(_out0tm3, _out0tm4); - float32x4_t _tmp13b = vsubq_f32(_out0tm3, _out0tm4); - - float32x4_t _tmp0m = vaddq_f32(vaddq_f32(_out0tm0, _tmp02a), _tmp02b); - float32x4_t _tmp1m = vmlaq_n_f32(_tmp13a, _tmp13b, 2.f); - float32x4_t _tmp2m = vmlaq_n_f32(_tmp02a, _tmp02b, 4.f); - float32x4_t _tmp3m = vmlaq_n_f32(vaddq_f32(_out0tm5, _tmp13a), _tmp13b, 8.f); - - vst1q_f32(tmp[0][m], _tmp0m); - vst1q_f32(tmp[1][m], _tmp1m); - vst1q_f32(tmp[2][m], _tmp2m); - vst1q_f32(tmp[3][m], _tmp3m); - - output0_tm_0 += tiles * 24; - output0_tm_1 += tiles * 24; - output0_tm_2 += tiles * 24; - output0_tm_3 += tiles * 24; - output0_tm_4 += tiles * 24; - output0_tm_5 += tiles * 24; - } - - for (int m = 0; m < 4; m++) - { - float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); - float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); - float32x4_t _tmp02 = vld1q_f32(tmp[m][2]); - float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); - float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); - float32x4_t _tmp05 = vld1q_f32(tmp[m][5]); - - float32x4_t _tmp02a = vaddq_f32(_tmp01, _tmp02); - float32x4_t _tmp13a = vsubq_f32(_tmp01, _tmp02); - - float32x4_t _tmp02b = vaddq_f32(_tmp03, _tmp04); - float32x4_t _tmp13b = vsubq_f32(_tmp03, _tmp04); - - float32x4_t _out00 = vaddq_f32(_bias0, vaddq_f32(vaddq_f32(_tmp00, _tmp02a), _tmp02b)); - float32x4_t _out01 = vaddq_f32(_bias0, vmlaq_n_f32(_tmp13a, _tmp13b, 2.f)); - float32x4_t _out02 = vaddq_f32(_bias0, vmlaq_n_f32(_tmp02a, _tmp02b, 4.f)); - float32x4_t _out03 = vaddq_f32(_bias0, 
vmlaq_n_f32(vaddq_f32(_tmp05, _tmp13a), _tmp13b, 8.f)); - - vst1_u16(output0, vcvt_bf16_f32(_out00)); - vst1_u16(output0 + 4, vcvt_bf16_f32(_out01)); - vst1_u16(output0 + 8, vcvt_bf16_f32(_out02)); - vst1_u16(output0 + 12, vcvt_bf16_f32(_out03)); - - output0 += outw * 4; - } - } - } - } + conv3x3s1_winograd42_transform_output_pack4_bf16s_neon(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output diff --git a/src/layer/arm/convolution_3x3_pack4_fp16s.h b/src/layer/arm/convolution_3x3_pack4_fp16s.h index 673a5a741ad..1e7100fa1e9 100644 --- a/src/layer/arm/convolution_3x3_pack4_fp16s.h +++ b/src/layer/arm/convolution_3x3_pack4_fp16s.h @@ -234,12 +234,12 @@ static void conv3x3s1_winograd64_transform_kernel_pack4_fp16sa_neon(const Mat& k } } -static void conv3x3s1_winograd64_pack4_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd64_pack4_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; int inch = bottom_blob.c; - //size_t elemsize = bottom_blob.elemsize; + size_t elemsize = bottom_blob.elemsize; int elempack = bottom_blob.elempack; int outw = top_blob.w; @@ -256,210 +256,15 @@ static void conv3x3s1_winograd64_pack4_fp16sa_neon(const Mat& bottom_blob, Mat& h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const float* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - - const int tiles = w_tm / 8 * h_tm / 8; - - // bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); - bottom_blob_tm.create(tiles, 64, inch, 2u * elempack, elempack, opt.workspace_allocator); - - // const float itm[8][8] = { - // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, - // - // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, - // - // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, - // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, - // - // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, - // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, - // - // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} - // }; - - // 0 = r00 - r06 + (r04 - r02) * 5.25 - // 7 = r07 - r01 + (r03 - r05) * 5.25 - - // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) - // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) - - // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) - // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) - - // reuse r04 * 1.25 - // reuse r03 * 2.5 - // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) - // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - __fp16 tmp[8][8][4]; - - // tile - for (int i = 0; i < h_tm / 8; i++) - { - for (int j = 0; j < w_tm / 8; j++) - { - const __fp16* r0 = img0.row(i * 6) + (j * 6) * 4; - - for (int m = 0; m < 8; m++) - { - float16x4_t _r00 = vld1_f16(r0); - float16x4_t _r01 = vld1_f16(r0 + 4); - float16x4_t _r02 = vld1_f16(r0 + 8); - float16x4_t _r03 = vld1_f16(r0 + 12); - 
float16x4_t _r04 = vld1_f16(r0 + 16); - float16x4_t _r05 = vld1_f16(r0 + 20); - float16x4_t _r06 = vld1_f16(r0 + 24); - float16x4_t _r07 = vld1_f16(r0 + 28); - - float16x4_t _tmp0m = vfma_n_f16(vsub_f16(_r00, _r06), vsub_f16(_r04, _r02), 5.25f); - float16x4_t _tmp7m = vfma_n_f16(vsub_f16(_r07, _r01), vsub_f16(_r03, _r05), 5.25f); - vst1_f16(tmp[0][m], _tmp0m); - vst1_f16(tmp[7][m], _tmp7m); - - // tmp[0][m] = r0[0] - r0[6] + (r0[4] - r0[2]) * 5.25; - // tmp[7][m] = r0[7] - r0[1] + (r0[3] - r0[5]) * 5.25; - - float16x4_t _tmp12a = vfms_n_f16(vadd_f16(_r02, _r06), _r04, 4.25f); - float16x4_t _tmp12b = vfms_n_f16(vadd_f16(_r01, _r05), _r03, 4.25f); - - // float tmp12a = (r0[2] + r0[6] - r0[4] * 4.25); - // float tmp12b = (r0[1] + r0[5] - r0[3] * 4.25); - - float16x4_t _tmp1m = vadd_f16(_tmp12a, _tmp12b); - float16x4_t _tmp2m = vsub_f16(_tmp12a, _tmp12b); - vst1_f16(tmp[1][m], _tmp1m); - vst1_f16(tmp[2][m], _tmp2m); - - // tmp[1][m] = tmp12a + tmp12b; - // tmp[2][m] = tmp12a - tmp12b; - - float16x4_t _tmp34a = vfms_n_f16(vfma_n_f16(_r06, _r02, 0.25f), _r04, 1.25f); - float16x4_t _tmp34b = vfma_n_f16(vfms_n_f16(vmul_n_f16(_r01, 0.5f), _r03, 2.5f), _r05, 2.f); - - // float tmp34a = (r0[6] + r0[2] * 0.25 - r0[4] * 1.25); - // float tmp34b = (r0[1] * 0.5 - r0[3] * 2.5 + r0[5] * 2); - - float16x4_t _tmp3m = vadd_f16(_tmp34a, _tmp34b); - float16x4_t _tmp4m = vsub_f16(_tmp34a, _tmp34b); - vst1_f16(tmp[3][m], _tmp3m); - vst1_f16(tmp[4][m], _tmp4m); - - // tmp[3][m] = tmp34a + tmp34b; - // tmp[4][m] = tmp34a - tmp34b; - - float16x4_t _tmp56a = vfma_n_f16(_r06, vfms_n_f16(_r02, _r04, 1.25f), 4.f); - float16x4_t _tmp56b = vfma_n_f16(vfms_n_f16(vmul_n_f16(_r01, 2.f), _r03, 2.5f), _r05, 0.5f); - - // float tmp56a = (r0[6] + (r0[2] - r0[4] * 1.25) * 4); - // float tmp56b = (r0[1] * 2 - r0[3] * 2.5 + r0[5] * 0.5); - - float16x4_t _tmp5m = vadd_f16(_tmp56a, _tmp56b); - float16x4_t _tmp6m = vsub_f16(_tmp56a, _tmp56b); - vst1_f16(tmp[5][m], _tmp5m); - vst1_f16(tmp[6][m], _tmp6m); - - // tmp[5][m] = tmp56a + tmp56b; - // tmp[6][m] = tmp56a - tmp56b; + int w_tiles = outw / 6; + int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; - r0 += w * 4; - } - - __fp16* r0_tm_0 = (__fp16*)img0_tm + (i * w_tm / 8 + j) * 4; - __fp16* r0_tm_1 = r0_tm_0 + tiles * 4; - __fp16* r0_tm_2 = r0_tm_0 + tiles * 8; - __fp16* r0_tm_3 = r0_tm_0 + tiles * 12; - __fp16* r0_tm_4 = r0_tm_0 + tiles * 16; - __fp16* r0_tm_5 = r0_tm_0 + tiles * 20; - __fp16* r0_tm_6 = r0_tm_0 + tiles * 24; - __fp16* r0_tm_7 = r0_tm_0 + tiles * 28; - - for (int m = 0; m < 8; m++) - { - float16x4_t _tmp00 = vld1_f16(tmp[m][0]); - float16x4_t _tmp01 = vld1_f16(tmp[m][1]); - float16x4_t _tmp02 = vld1_f16(tmp[m][2]); - float16x4_t _tmp03 = vld1_f16(tmp[m][3]); - float16x4_t _tmp04 = vld1_f16(tmp[m][4]); - float16x4_t _tmp05 = vld1_f16(tmp[m][5]); - float16x4_t _tmp06 = vld1_f16(tmp[m][6]); - float16x4_t _tmp07 = vld1_f16(tmp[m][7]); - - float16x4_t _r0tm0 = vfma_n_f16(vsub_f16(_tmp00, _tmp06), vsub_f16(_tmp04, _tmp02), 5.25f); - float16x4_t _r0tm7 = vfma_n_f16(vsub_f16(_tmp07, _tmp01), vsub_f16(_tmp03, _tmp05), 5.25f); - - // r0_tm[0] = tmp0[0] - tmp0[6] + (tmp0[4] - tmp0[2]) * 5.25; - // r0_tm[7] = tmp0[7] - tmp0[1] + (tmp0[3] - tmp0[5]) * 5.25; - - float16x4_t _tmp12a = vfms_n_f16(vadd_f16(_tmp02, _tmp06), _tmp04, 4.25f); - float16x4_t _tmp12b = vfms_n_f16(vadd_f16(_tmp01, _tmp05), _tmp03, 4.25f); - - // float tmp12a = (tmp0[2] + tmp0[6] - tmp0[4] * 4.25); - // float tmp12b = (tmp0[1] + tmp0[5] - tmp0[3] * 4.25); - - float16x4_t _r0tm1 = vadd_f16(_tmp12a, 
_tmp12b); - float16x4_t _r0tm2 = vsub_f16(_tmp12a, _tmp12b); - - // r0_tm[1] = tmp12a + tmp12b; - // r0_tm[2] = tmp12a - tmp12b; - - float16x4_t _tmp34a = vfms_n_f16(vfma_n_f16(_tmp06, _tmp02, 0.25f), _tmp04, 1.25f); - float16x4_t _tmp34b = vfma_n_f16(vfms_n_f16(vmul_n_f16(_tmp01, 0.5f), _tmp03, 2.5f), _tmp05, 2.f); - - // float tmp34a = (tmp0[6] + tmp0[2] * 0.25 - tmp0[4] * 1.25); - // float tmp34b = (tmp0[1] * 0.5 - tmp0[3] * 2.5 + tmp0[5] * 2); - - float16x4_t _r0tm3 = vadd_f16(_tmp34a, _tmp34b); - float16x4_t _r0tm4 = vsub_f16(_tmp34a, _tmp34b); - - // r0_tm[3] = tmp34a + tmp34b; - // r0_tm[4] = tmp34a - tmp34b; - - float16x4_t _tmp56a = vfma_n_f16(_tmp06, vfms_n_f16(_tmp02, _tmp04, 1.25f), 4.f); - float16x4_t _tmp56b = vfma_n_f16(vfms_n_f16(vmul_n_f16(_tmp01, 2.f), _tmp03, 2.5f), _tmp05, 0.5f); - - // float tmp56a = (tmp0[6] + (tmp0[2] - tmp0[4] * 1.25) * 4); - // float tmp56b = (tmp0[1] * 2 - tmp0[3] * 2.5 + tmp0[5] * 0.5); - - float16x4_t _r0tm5 = vadd_f16(_tmp56a, _tmp56b); - float16x4_t _r0tm6 = vsub_f16(_tmp56a, _tmp56b); - - // r0_tm[5] = tmp56a + tmp56b; - // r0_tm[6] = tmp56a - tmp56b; - - vst1_f16(r0_tm_0, _r0tm0); - vst1_f16(r0_tm_1, _r0tm1); - vst1_f16(r0_tm_2, _r0tm2); - vst1_f16(r0_tm_3, _r0tm3); - vst1_f16(r0_tm_4, _r0tm4); - vst1_f16(r0_tm_5, _r0tm5); - vst1_f16(r0_tm_6, _r0tm6); - vst1_f16(r0_tm_7, _r0tm7); - - r0_tm_0 += tiles * 32; - r0_tm_1 += tiles * 32; - r0_tm_2 += tiles * 32; - r0_tm_3 += tiles * 32; - r0_tm_4 += tiles * 32; - r0_tm_5 += tiles * 32; - r0_tm_6 += tiles * 32; - r0_tm_7 += tiles * 32; - } - } - } - } + bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); + conv3x3s1_winograd64_transform_input_pack4_fp16sa_neon(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -979,173 +784,7 @@ static void conv3x3s1_winograd64_pack4_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob_bordered.create(outw, outh, outch, 2u * 4, 4, opt.workspace_allocator); } { - // const float otm[6][8] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} - // }; - - // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 - // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 - // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 - // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 - // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 - // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) - - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - const int tiles = w_tm / 8 * h_tm / 8; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - float16x4_t _bias0 = bias ? 
vld1_f16((const __fp16*)bias + p * 4) : vdup_n_f16(0.f); - - __fp16 tmp[6][8][4]; - - // tile - for (int i = 0; i < outh / 6; i++) - { - for (int j = 0; j < outw / 6; j++) - { - // top_blob_tm.create(tiles, 64, outch, elemsize, elempack); - - const __fp16* output0_tm_0 = (const __fp16*)out0_tm + (i * w_tm / 8 + j) * 4; - const __fp16* output0_tm_1 = output0_tm_0 + tiles * 4; - const __fp16* output0_tm_2 = output0_tm_0 + tiles * 8; - const __fp16* output0_tm_3 = output0_tm_0 + tiles * 12; - const __fp16* output0_tm_4 = output0_tm_0 + tiles * 16; - const __fp16* output0_tm_5 = output0_tm_0 + tiles * 20; - const __fp16* output0_tm_6 = output0_tm_0 + tiles * 24; - const __fp16* output0_tm_7 = output0_tm_0 + tiles * 28; - - __fp16* output0 = out0.row<__fp16>(i * 6) + (j * 6) * 4; - - // TODO neon optimize - for (int m = 0; m < 8; m++) - { - float16x4_t _out0tm0 = vld1_f16(output0_tm_0); - float16x4_t _out0tm1 = vld1_f16(output0_tm_1); - float16x4_t _out0tm2 = vld1_f16(output0_tm_2); - float16x4_t _out0tm3 = vld1_f16(output0_tm_3); - float16x4_t _out0tm4 = vld1_f16(output0_tm_4); - float16x4_t _out0tm5 = vld1_f16(output0_tm_5); - float16x4_t _out0tm6 = vld1_f16(output0_tm_6); - float16x4_t _out0tm7 = vld1_f16(output0_tm_7); - - float16x4_t _tmp024a = vadd_f16(_out0tm1, _out0tm2); - float16x4_t _tmp135a = vsub_f16(_out0tm1, _out0tm2); - - // float tmp024a = output0_tm[1] + output0_tm[2]; - // float tmp135a = output0_tm[1] - output0_tm[2]; - - float16x4_t _tmp024b = vadd_f16(_out0tm3, _out0tm4); - float16x4_t _tmp135b = vsub_f16(_out0tm3, _out0tm4); - - // float tmp024b = output0_tm[3] + output0_tm[4]; - // float tmp135b = output0_tm[3] - output0_tm[4]; - - float16x4_t _tmp024c = vadd_f16(_out0tm5, _out0tm6); - float16x4_t _tmp135c = vsub_f16(_out0tm5, _out0tm6); - - // float tmp024c = output0_tm[5] + output0_tm[6]; - // float tmp135c = output0_tm[5] - output0_tm[6]; - - float16x4_t _tmp0m = vadd_f16(vadd_f16(_out0tm0, _tmp024a), vfma_n_f16(_tmp024b, _tmp024c, 32.f)); - float16x4_t _tmp2m = vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f); - float16x4_t _tmp4m = vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f); - vst1_f16(tmp[0][m], _tmp0m); - vst1_f16(tmp[2][m], _tmp2m); - vst1_f16(tmp[4][m], _tmp4m); - - // tmp[0][m] = output0_tm[0] + tmp024a + tmp024b + tmp024c * 32; - // tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8; - // tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - float16x4_t _tmp1m = vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f); - float16x4_t _tmp3m = vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f); - float16x4_t _tmp5m = vadd_f16(vadd_f16(_out0tm7, _tmp135a), vfma_n_f16(_tmp135c, _tmp135b, 32.f)); - vst1_f16(tmp[1][m], _tmp1m); - vst1_f16(tmp[3][m], _tmp3m); - vst1_f16(tmp[5][m], _tmp5m); - - // tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16; - // tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4; - // tmp[5][m] = output0_tm[7] + tmp135a + tmp135b * 32 + tmp135c; - - output0_tm_0 += tiles * 32; - output0_tm_1 += tiles * 32; - output0_tm_2 += tiles * 32; - output0_tm_3 += tiles * 32; - output0_tm_4 += tiles * 32; - output0_tm_5 += tiles * 32; - output0_tm_6 += tiles * 32; - output0_tm_7 += tiles * 32; - } - - for (int m = 0; m < 6; m++) - { - float16x4_t _tmp00 = vld1_f16(tmp[m][0]); - float16x4_t _tmp01 = vld1_f16(tmp[m][1]); - float16x4_t _tmp02 = vld1_f16(tmp[m][2]); - float16x4_t _tmp03 = vld1_f16(tmp[m][3]); - float16x4_t _tmp04 = vld1_f16(tmp[m][4]); - float16x4_t _tmp05 = vld1_f16(tmp[m][5]); - 
float16x4_t _tmp06 = vld1_f16(tmp[m][6]); - float16x4_t _tmp07 = vld1_f16(tmp[m][7]); - - float16x4_t _tmp024a = vadd_f16(_tmp01, _tmp02); - float16x4_t _tmp135a = vsub_f16(_tmp01, _tmp02); - - // float tmp024a = tmp0[1] + tmp0[2]; - // float tmp135a = tmp0[1] - tmp0[2]; - - float16x4_t _tmp024b = vadd_f16(_tmp03, _tmp04); - float16x4_t _tmp135b = vsub_f16(_tmp03, _tmp04); - - // float tmp024b = tmp0[3] + tmp0[4]; - // float tmp135b = tmp0[3] - tmp0[4]; - - float16x4_t _tmp024c = vadd_f16(_tmp05, _tmp06); - float16x4_t _tmp135c = vsub_f16(_tmp05, _tmp06); - - // float tmp024c = tmp0[5] + tmp0[6]; - // float tmp135c = tmp0[5] - tmp0[6]; - - float16x4_t _out00 = vadd_f16(_bias0, vadd_f16(vadd_f16(_tmp00, _tmp024a), vfma_n_f16(_tmp024b, _tmp024c, 32.f))); - float16x4_t _out02 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f)); - float16x4_t _out04 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f)); - vst1_f16(output0, _out00); - vst1_f16(output0 + 8, _out02); - vst1_f16(output0 + 16, _out04); - - // output0[0] = bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32; - // output0[2] = bias0 + tmp024a + tmp024b * 4 + tmp024c * 8; - // output0[4] = bias0 + tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - float16x4_t _out01 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f)); - float16x4_t _out03 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f)); - float16x4_t _out05 = vadd_f16(_bias0, vadd_f16(vadd_f16(_tmp07, _tmp135a), vfma_n_f16(_tmp135c, _tmp135b, 32.f))); - vst1_f16(output0 + 4, _out01); - vst1_f16(output0 + 12, _out03); - vst1_f16(output0 + 20, _out05); - - // output0[1] = bias0 + tmp135a + tmp135b + tmp135b + tmp135c * 16; - // output0[3] = bias0 + tmp135a + tmp135b * 8 + tmp135c * 4; - // output0[5] = bias0 + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c; - - output0 += outw * 4; - } - } - } - } + conv3x3s1_winograd64_transform_output_pack4_fp16sa_neon(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output diff --git a/src/layer/arm/convolution_3x3_pack4to1.h b/src/layer/arm/convolution_3x3_pack4to1.h index 76980a2717a..e05dc931261 100644 --- a/src/layer/arm/convolution_3x3_pack4to1.h +++ b/src/layer/arm/convolution_3x3_pack4to1.h @@ -270,7 +270,7 @@ static void conv3x3s1_winograd64_transform_kernel_pack4to1_neon(const Mat& kerne } } -static void conv3x3s1_winograd64_pack4to1_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd64_pack4to1_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -292,209 +292,15 @@ static void conv3x3s1_winograd64_pack4to1_neon(const Mat& bottom_blob, Mat& top_ h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const float* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - - const int tiles = w_tm / 8 * h_tm / 8; + int w_tiles = outw / 6; + int h_tiles = outh / 6; + int tiles = w_tiles * h_tiles; bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); - - // const float itm[8][8] = { - // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, - // - // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 
0.0f}, - // - // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, - // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, - // - // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, - // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, - // - // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} - // }; - - // 0 = r00 - r06 + (r04 - r02) * 5.25 - // 7 = r07 - r01 + (r03 - r05) * 5.25 - - // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) - // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) - - // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) - // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) - - // reuse r04 * 1.25 - // reuse r03 * 2.5 - // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) - // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - float tmp[8][8][4]; - - // tile - for (int i = 0; i < h_tm / 8; i++) - { - for (int j = 0; j < w_tm / 8; j++) - { - const float* r0 = img0.row(i * 6) + (j * 6) * 4; - - for (int m = 0; m < 8; m++) - { - float32x4_t _r00 = vld1q_f32(r0); - float32x4_t _r01 = vld1q_f32(r0 + 4); - float32x4_t _r02 = vld1q_f32(r0 + 8); - float32x4_t _r03 = vld1q_f32(r0 + 12); - float32x4_t _r04 = vld1q_f32(r0 + 16); - float32x4_t _r05 = vld1q_f32(r0 + 20); - float32x4_t _r06 = vld1q_f32(r0 + 24); - float32x4_t _r07 = vld1q_f32(r0 + 28); - - float32x4_t _tmp0m = vmlaq_n_f32(vsubq_f32(_r00, _r06), vsubq_f32(_r04, _r02), 5.25f); - float32x4_t _tmp7m = vmlaq_n_f32(vsubq_f32(_r07, _r01), vsubq_f32(_r03, _r05), 5.25f); - vst1q_f32(tmp[0][m], _tmp0m); - vst1q_f32(tmp[7][m], _tmp7m); - - // tmp[0][m] = r0[0] - r0[6] + (r0[4] - r0[2]) * 5.25; - // tmp[7][m] = r0[7] - r0[1] + (r0[3] - r0[5]) * 5.25; - - float32x4_t _tmp12a = vmlsq_n_f32(vaddq_f32(_r02, _r06), _r04, 4.25f); - float32x4_t _tmp12b = vmlsq_n_f32(vaddq_f32(_r01, _r05), _r03, 4.25f); - - // float tmp12a = (r0[2] + r0[6] - r0[4] * 4.25); - // float tmp12b = (r0[1] + r0[5] - r0[3] * 4.25); - - float32x4_t _tmp1m = vaddq_f32(_tmp12a, _tmp12b); - float32x4_t _tmp2m = vsubq_f32(_tmp12a, _tmp12b); - vst1q_f32(tmp[1][m], _tmp1m); - vst1q_f32(tmp[2][m], _tmp2m); - - // tmp[1][m] = tmp12a + tmp12b; - // tmp[2][m] = tmp12a - tmp12b; - - float32x4_t _tmp34a = vmlsq_n_f32(vmlaq_n_f32(_r06, _r02, 0.25f), _r04, 1.25f); - float32x4_t _tmp34b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_r01, 0.5f), _r03, 2.5f), _r05, 2.f); - - // float tmp34a = (r0[6] + r0[2] * 0.25 - r0[4] * 1.25); - // float tmp34b = (r0[1] * 0.5 - r0[3] * 2.5 + r0[5] * 2); - - float32x4_t _tmp3m = vaddq_f32(_tmp34a, _tmp34b); - float32x4_t _tmp4m = vsubq_f32(_tmp34a, _tmp34b); - vst1q_f32(tmp[3][m], _tmp3m); - vst1q_f32(tmp[4][m], _tmp4m); - - // tmp[3][m] = tmp34a + tmp34b; - // tmp[4][m] = tmp34a - tmp34b; - - float32x4_t _tmp56a = vmlaq_n_f32(_r06, vmlsq_n_f32(_r02, _r04, 1.25f), 4.f); - float32x4_t _tmp56b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_r01, 2.f), _r03, 2.5f), _r05, 0.5f); - - // float tmp56a = (r0[6] + (r0[2] - r0[4] * 1.25) * 4); - // float tmp56b = (r0[1] * 2 - r0[3] * 2.5 + r0[5] * 0.5); - - float32x4_t _tmp5m = vaddq_f32(_tmp56a, _tmp56b); - float32x4_t _tmp6m = vsubq_f32(_tmp56a, _tmp56b); - vst1q_f32(tmp[5][m], _tmp5m); - vst1q_f32(tmp[6][m], _tmp6m); - - // tmp[5][m] = tmp56a + tmp56b; - // tmp[6][m] = tmp56a - tmp56b; - - 
r0 += w * 4; - } - - float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 8 + j) * 4; - float* r0_tm_1 = r0_tm_0 + tiles * 4; - float* r0_tm_2 = r0_tm_0 + tiles * 8; - float* r0_tm_3 = r0_tm_0 + tiles * 12; - float* r0_tm_4 = r0_tm_0 + tiles * 16; - float* r0_tm_5 = r0_tm_0 + tiles * 20; - float* r0_tm_6 = r0_tm_0 + tiles * 24; - float* r0_tm_7 = r0_tm_0 + tiles * 28; - - for (int m = 0; m < 8; m++) - { - float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); - float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); - float32x4_t _tmp02 = vld1q_f32(tmp[m][2]); - float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); - float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); - float32x4_t _tmp05 = vld1q_f32(tmp[m][5]); - float32x4_t _tmp06 = vld1q_f32(tmp[m][6]); - float32x4_t _tmp07 = vld1q_f32(tmp[m][7]); - - float32x4_t _r0tm0 = vmlaq_n_f32(vsubq_f32(_tmp00, _tmp06), vsubq_f32(_tmp04, _tmp02), 5.25f); - float32x4_t _r0tm7 = vmlaq_n_f32(vsubq_f32(_tmp07, _tmp01), vsubq_f32(_tmp03, _tmp05), 5.25f); - - // r0_tm[0] = tmp0[0] - tmp0[6] + (tmp0[4] - tmp0[2]) * 5.25; - // r0_tm[7] = tmp0[7] - tmp0[1] + (tmp0[3] - tmp0[5]) * 5.25; - - float32x4_t _tmp12a = vmlsq_n_f32(vaddq_f32(_tmp02, _tmp06), _tmp04, 4.25f); - float32x4_t _tmp12b = vmlsq_n_f32(vaddq_f32(_tmp01, _tmp05), _tmp03, 4.25f); - - // float tmp12a = (tmp0[2] + tmp0[6] - tmp0[4] * 4.25); - // float tmp12b = (tmp0[1] + tmp0[5] - tmp0[3] * 4.25); - - float32x4_t _r0tm1 = vaddq_f32(_tmp12a, _tmp12b); - float32x4_t _r0tm2 = vsubq_f32(_tmp12a, _tmp12b); - - // r0_tm[1] = tmp12a + tmp12b; - // r0_tm[2] = tmp12a - tmp12b; - - float32x4_t _tmp34a = vmlsq_n_f32(vmlaq_n_f32(_tmp06, _tmp02, 0.25f), _tmp04, 1.25f); - float32x4_t _tmp34b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_tmp01, 0.5f), _tmp03, 2.5f), _tmp05, 2.f); - - // float tmp34a = (tmp0[6] + tmp0[2] * 0.25 - tmp0[4] * 1.25); - // float tmp34b = (tmp0[1] * 0.5 - tmp0[3] * 2.5 + tmp0[5] * 2); - - float32x4_t _r0tm3 = vaddq_f32(_tmp34a, _tmp34b); - float32x4_t _r0tm4 = vsubq_f32(_tmp34a, _tmp34b); - - // r0_tm[3] = tmp34a + tmp34b; - // r0_tm[4] = tmp34a - tmp34b; - - float32x4_t _tmp56a = vmlaq_n_f32(_tmp06, vmlsq_n_f32(_tmp02, _tmp04, 1.25f), 4.f); - float32x4_t _tmp56b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_tmp01, 2.f), _tmp03, 2.5f), _tmp05, 0.5f); - - // float tmp56a = (tmp0[6] + (tmp0[2] - tmp0[4] * 1.25) * 4); - // float tmp56b = (tmp0[1] * 2 - tmp0[3] * 2.5 + tmp0[5] * 0.5); - - float32x4_t _r0tm5 = vaddq_f32(_tmp56a, _tmp56b); - float32x4_t _r0tm6 = vsubq_f32(_tmp56a, _tmp56b); - - // r0_tm[5] = tmp56a + tmp56b; - // r0_tm[6] = tmp56a - tmp56b; - - vst1q_f32(r0_tm_0, _r0tm0); - vst1q_f32(r0_tm_1, _r0tm1); - vst1q_f32(r0_tm_2, _r0tm2); - vst1q_f32(r0_tm_3, _r0tm3); - vst1q_f32(r0_tm_4, _r0tm4); - vst1q_f32(r0_tm_5, _r0tm5); - vst1q_f32(r0_tm_6, _r0tm6); - vst1q_f32(r0_tm_7, _r0tm7); - - r0_tm_0 += tiles * 32; - r0_tm_1 += tiles * 32; - r0_tm_2 += tiles * 32; - r0_tm_3 += tiles * 32; - r0_tm_4 += tiles * 32; - r0_tm_5 += tiles * 32; - r0_tm_6 += tiles * 32; - r0_tm_7 += tiles * 32; - } - } - } - } + conv3x3s1_winograd64_transform_input_pack4_neon(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -2182,111 +1988,7 @@ static void conv3x3s1_winograd64_pack4to1_neon(const Mat& bottom_blob, Mat& top_ top_blob_bordered.create(outw, outh, outch, 4u, 1, opt.workspace_allocator); } { - // const float otm[6][8] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, - // {0.0f, 1.0f, 
-1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} - // }; - - // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 - // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 - // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 - // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 - // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 - // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) - - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - const int tiles = w_tm / 8 * h_tm / 8; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - const float bias0 = bias ? bias[p] : 0.f; - // float32x2_t _bias0 = vdup_n_f32(bias0); - - float tmp[6][8]; - - // tile - for (int i = 0; i < outh / 6; i++) - { - for (int j = 0; j < outw / 6; j++) - { - // top_blob_tm.create(tiles, 64, outch, 4u, 1, opt.workspace_allocator); - - const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 8 + j) * 1; - const float* output0_tm_1 = output0_tm_0 + tiles * 1; - const float* output0_tm_2 = output0_tm_0 + tiles * 2; - const float* output0_tm_3 = output0_tm_0 + tiles * 3; - const float* output0_tm_4 = output0_tm_0 + tiles * 4; - const float* output0_tm_5 = output0_tm_0 + tiles * 5; - const float* output0_tm_6 = output0_tm_0 + tiles * 6; - const float* output0_tm_7 = output0_tm_0 + tiles * 7; - - // TODO neon optimize - for (int m = 0; m < 8; m++) - { - float tmp024a = output0_tm_1[0] + output0_tm_2[0]; - float tmp135a = output0_tm_1[0] - output0_tm_2[0]; - - float tmp024b = output0_tm_3[0] + output0_tm_4[0]; - float tmp135b = output0_tm_3[0] - output0_tm_4[0]; - - float tmp024c = output0_tm_5[0] + output0_tm_6[0]; - float tmp135c = output0_tm_5[0] - output0_tm_6[0]; - - tmp[0][m] = output0_tm_0[0] + tmp024a + tmp024b + tmp024c * 32; - tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8; - tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16; - tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4; - tmp[5][m] = output0_tm_7[0] + tmp135a + tmp135b * 32 + tmp135c; - - output0_tm_0 += tiles * 8; - output0_tm_1 += tiles * 8; - output0_tm_2 += tiles * 8; - output0_tm_3 += tiles * 8; - output0_tm_4 += tiles * 8; - output0_tm_5 += tiles * 8; - output0_tm_6 += tiles * 8; - output0_tm_7 += tiles * 8; - } - - float* output0 = out0.row(i * 6) + j * 6; - - for (int m = 0; m < 6; m++) - { - const float* tmp0 = tmp[m]; - - float tmp024a = tmp0[1] + tmp0[2]; - float tmp135a = tmp0[1] - tmp0[2]; - - float tmp024b = tmp0[3] + tmp0[4]; - float tmp135b = tmp0[3] - tmp0[4]; - - float tmp024c = tmp0[5] + tmp0[6]; - float tmp135c = tmp0[5] - tmp0[6]; - - output0[0] = bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32; - output0[2] = bias0 + tmp024a + tmp024b * 4 + tmp024c * 8; - output0[4] = bias0 + tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - output0[1] = bias0 + tmp135a + tmp135b + tmp135b + tmp135c * 16; - output0[3] = bias0 + tmp135a + tmp135b * 8 + tmp135c * 4; - output0[5] = bias0 + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c; - - output0 += outw; - } - } - } - } + conv3x3s1_winograd64_transform_output_neon(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output diff --git a/src/layer/arm/convolution_3x3_pack4to1_bf16s.h b/src/layer/arm/convolution_3x3_pack4to1_bf16s.h index 1ee51f1b675..207794158f3 100644 --- 
a/src/layer/arm/convolution_3x3_pack4to1_bf16s.h +++ b/src/layer/arm/convolution_3x3_pack4to1_bf16s.h @@ -12,12 +12,12 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -static void conv3x3s1_winograd64_pack4to1_bf16s_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd64_pack4to1_bf16s_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; int inch = bottom_blob.c; - //size_t elemsize = bottom_blob.elemsize; + size_t elemsize = bottom_blob.elemsize; int elempack = bottom_blob.elempack; int outw = top_blob.w; @@ -34,210 +34,15 @@ static void conv3x3s1_winograd64_pack4to1_bf16s_neon(const Mat& bottom_blob, Mat h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const float* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - - const int tiles = w_tm / 8 * h_tm / 8; - - // bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); - bottom_blob_tm.create(tiles, 64, inch, 4u * elempack, elempack, opt.workspace_allocator); - - // const float itm[8][8] = { - // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, - // - // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, - // - // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, - // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, - // - // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, - // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, - // - // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} - // }; - - // 0 = r00 - r06 + (r04 - r02) * 5.25 - // 7 = r07 - r01 + (r03 - r05) * 5.25 - - // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) - // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) - - // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) - // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) - - // reuse r04 * 1.25 - // reuse r03 * 2.5 - // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) - // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - float tmp[8][8][4]; - - // tile - for (int i = 0; i < h_tm / 8; i++) - { - for (int j = 0; j < w_tm / 8; j++) - { - const unsigned short* r0 = img0.row(i * 6) + (j * 6) * 4; - - for (int m = 0; m < 8; m++) - { - float32x4_t _r00 = vcvt_f32_bf16(vld1_u16(r0)); - float32x4_t _r01 = vcvt_f32_bf16(vld1_u16(r0 + 4)); - float32x4_t _r02 = vcvt_f32_bf16(vld1_u16(r0 + 8)); - float32x4_t _r03 = vcvt_f32_bf16(vld1_u16(r0 + 12)); - float32x4_t _r04 = vcvt_f32_bf16(vld1_u16(r0 + 16)); - float32x4_t _r05 = vcvt_f32_bf16(vld1_u16(r0 + 20)); - float32x4_t _r06 = vcvt_f32_bf16(vld1_u16(r0 + 24)); - float32x4_t _r07 = vcvt_f32_bf16(vld1_u16(r0 + 28)); - - float32x4_t _tmp0m = vmlaq_n_f32(vsubq_f32(_r00, _r06), vsubq_f32(_r04, _r02), 5.25f); - float32x4_t _tmp7m = vmlaq_n_f32(vsubq_f32(_r07, _r01), vsubq_f32(_r03, _r05), 5.25f); - vst1q_f32(tmp[0][m], 
_tmp0m); - vst1q_f32(tmp[7][m], _tmp7m); - - // tmp[0][m] = r0[0] - r0[6] + (r0[4] - r0[2]) * 5.25; - // tmp[7][m] = r0[7] - r0[1] + (r0[3] - r0[5]) * 5.25; - - float32x4_t _tmp12a = vmlsq_n_f32(vaddq_f32(_r02, _r06), _r04, 4.25f); - float32x4_t _tmp12b = vmlsq_n_f32(vaddq_f32(_r01, _r05), _r03, 4.25f); - - // float tmp12a = (r0[2] + r0[6] - r0[4] * 4.25); - // float tmp12b = (r0[1] + r0[5] - r0[3] * 4.25); - - float32x4_t _tmp1m = vaddq_f32(_tmp12a, _tmp12b); - float32x4_t _tmp2m = vsubq_f32(_tmp12a, _tmp12b); - vst1q_f32(tmp[1][m], _tmp1m); - vst1q_f32(tmp[2][m], _tmp2m); - - // tmp[1][m] = tmp12a + tmp12b; - // tmp[2][m] = tmp12a - tmp12b; - - float32x4_t _tmp34a = vmlsq_n_f32(vmlaq_n_f32(_r06, _r02, 0.25f), _r04, 1.25f); - float32x4_t _tmp34b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_r01, 0.5f), _r03, 2.5f), _r05, 2.f); - - // float tmp34a = (r0[6] + r0[2] * 0.25 - r0[4] * 1.25); - // float tmp34b = (r0[1] * 0.5 - r0[3] * 2.5 + r0[5] * 2); - - float32x4_t _tmp3m = vaddq_f32(_tmp34a, _tmp34b); - float32x4_t _tmp4m = vsubq_f32(_tmp34a, _tmp34b); - vst1q_f32(tmp[3][m], _tmp3m); - vst1q_f32(tmp[4][m], _tmp4m); - - // tmp[3][m] = tmp34a + tmp34b; - // tmp[4][m] = tmp34a - tmp34b; - - float32x4_t _tmp56a = vmlaq_n_f32(_r06, vmlsq_n_f32(_r02, _r04, 1.25f), 4.f); - float32x4_t _tmp56b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_r01, 2.f), _r03, 2.5f), _r05, 0.5f); - - // float tmp56a = (r0[6] + (r0[2] - r0[4] * 1.25) * 4); - // float tmp56b = (r0[1] * 2 - r0[3] * 2.5 + r0[5] * 0.5); - - float32x4_t _tmp5m = vaddq_f32(_tmp56a, _tmp56b); - float32x4_t _tmp6m = vsubq_f32(_tmp56a, _tmp56b); - vst1q_f32(tmp[5][m], _tmp5m); - vst1q_f32(tmp[6][m], _tmp6m); - - // tmp[5][m] = tmp56a + tmp56b; - // tmp[6][m] = tmp56a - tmp56b; - - r0 += w * 4; - } - - float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 8 + j) * 4; - float* r0_tm_1 = r0_tm_0 + tiles * 4; - float* r0_tm_2 = r0_tm_0 + tiles * 8; - float* r0_tm_3 = r0_tm_0 + tiles * 12; - float* r0_tm_4 = r0_tm_0 + tiles * 16; - float* r0_tm_5 = r0_tm_0 + tiles * 20; - float* r0_tm_6 = r0_tm_0 + tiles * 24; - float* r0_tm_7 = r0_tm_0 + tiles * 28; + int w_tiles = outw / 6; + int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; - for (int m = 0; m < 8; m++) - { - float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); - float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); - float32x4_t _tmp02 = vld1q_f32(tmp[m][2]); - float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); - float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); - float32x4_t _tmp05 = vld1q_f32(tmp[m][5]); - float32x4_t _tmp06 = vld1q_f32(tmp[m][6]); - float32x4_t _tmp07 = vld1q_f32(tmp[m][7]); - - float32x4_t _r0tm0 = vmlaq_n_f32(vsubq_f32(_tmp00, _tmp06), vsubq_f32(_tmp04, _tmp02), 5.25f); - float32x4_t _r0tm7 = vmlaq_n_f32(vsubq_f32(_tmp07, _tmp01), vsubq_f32(_tmp03, _tmp05), 5.25f); - - // r0_tm[0] = tmp0[0] - tmp0[6] + (tmp0[4] - tmp0[2]) * 5.25; - // r0_tm[7] = tmp0[7] - tmp0[1] + (tmp0[3] - tmp0[5]) * 5.25; - - float32x4_t _tmp12a = vmlsq_n_f32(vaddq_f32(_tmp02, _tmp06), _tmp04, 4.25f); - float32x4_t _tmp12b = vmlsq_n_f32(vaddq_f32(_tmp01, _tmp05), _tmp03, 4.25f); - - // float tmp12a = (tmp0[2] + tmp0[6] - tmp0[4] * 4.25); - // float tmp12b = (tmp0[1] + tmp0[5] - tmp0[3] * 4.25); - - float32x4_t _r0tm1 = vaddq_f32(_tmp12a, _tmp12b); - float32x4_t _r0tm2 = vsubq_f32(_tmp12a, _tmp12b); - - // r0_tm[1] = tmp12a + tmp12b; - // r0_tm[2] = tmp12a - tmp12b; - - float32x4_t _tmp34a = vmlsq_n_f32(vmlaq_n_f32(_tmp06, _tmp02, 0.25f), _tmp04, 1.25f); - float32x4_t _tmp34b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_tmp01, 0.5f), _tmp03, 2.5f), 
_tmp05, 2.f); - - // float tmp34a = (tmp0[6] + tmp0[2] * 0.25 - tmp0[4] * 1.25); - // float tmp34b = (tmp0[1] * 0.5 - tmp0[3] * 2.5 + tmp0[5] * 2); - - float32x4_t _r0tm3 = vaddq_f32(_tmp34a, _tmp34b); - float32x4_t _r0tm4 = vsubq_f32(_tmp34a, _tmp34b); - - // r0_tm[3] = tmp34a + tmp34b; - // r0_tm[4] = tmp34a - tmp34b; - - float32x4_t _tmp56a = vmlaq_n_f32(_tmp06, vmlsq_n_f32(_tmp02, _tmp04, 1.25f), 4.f); - float32x4_t _tmp56b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_tmp01, 2.f), _tmp03, 2.5f), _tmp05, 0.5f); - - // float tmp56a = (tmp0[6] + (tmp0[2] - tmp0[4] * 1.25) * 4); - // float tmp56b = (tmp0[1] * 2 - tmp0[3] * 2.5 + tmp0[5] * 0.5); - - float32x4_t _r0tm5 = vaddq_f32(_tmp56a, _tmp56b); - float32x4_t _r0tm6 = vsubq_f32(_tmp56a, _tmp56b); - - // r0_tm[5] = tmp56a + tmp56b; - // r0_tm[6] = tmp56a - tmp56b; - - vst1q_f32(r0_tm_0, _r0tm0); - vst1q_f32(r0_tm_1, _r0tm1); - vst1q_f32(r0_tm_2, _r0tm2); - vst1q_f32(r0_tm_3, _r0tm3); - vst1q_f32(r0_tm_4, _r0tm4); - vst1q_f32(r0_tm_5, _r0tm5); - vst1q_f32(r0_tm_6, _r0tm6); - vst1q_f32(r0_tm_7, _r0tm7); - - r0_tm_0 += tiles * 32; - r0_tm_1 += tiles * 32; - r0_tm_2 += tiles * 32; - r0_tm_3 += tiles * 32; - r0_tm_4 += tiles * 32; - r0_tm_5 += tiles * 32; - r0_tm_6 += tiles * 32; - r0_tm_7 += tiles * 32; - } - } - } - } + bottom_blob_tm.create(tiles, 64, inch, 16u, elempack, opt.workspace_allocator); + conv3x3s1_winograd64_transform_input_pack4_bf16s_neon(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -1925,111 +1730,7 @@ static void conv3x3s1_winograd64_pack4to1_bf16s_neon(const Mat& bottom_blob, Mat top_blob_bordered.create(outw, outh, outch, 2u, 1, opt.workspace_allocator); } { - // const float otm[6][8] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} - // }; - - // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 - // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 - // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 - // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 - // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 - // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) - - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - const int tiles = w_tm / 8 * h_tm / 8; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - const float bias0 = bias ? 
bias[p] : 0.f; - // float32x2_t _bias0 = vdup_n_f32(bias0); - - float tmp[6][8]; - - // tile - for (int i = 0; i < outh / 6; i++) - { - for (int j = 0; j < outw / 6; j++) - { - // top_blob_tm.create(tiles, 64, outch, 4u, 1, opt.workspace_allocator); - - const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 8 + j) * 1; - const float* output0_tm_1 = output0_tm_0 + tiles * 1; - const float* output0_tm_2 = output0_tm_0 + tiles * 2; - const float* output0_tm_3 = output0_tm_0 + tiles * 3; - const float* output0_tm_4 = output0_tm_0 + tiles * 4; - const float* output0_tm_5 = output0_tm_0 + tiles * 5; - const float* output0_tm_6 = output0_tm_0 + tiles * 6; - const float* output0_tm_7 = output0_tm_0 + tiles * 7; - - // TODO neon optimize - for (int m = 0; m < 8; m++) - { - float tmp024a = output0_tm_1[0] + output0_tm_2[0]; - float tmp135a = output0_tm_1[0] - output0_tm_2[0]; - - float tmp024b = output0_tm_3[0] + output0_tm_4[0]; - float tmp135b = output0_tm_3[0] - output0_tm_4[0]; - - float tmp024c = output0_tm_5[0] + output0_tm_6[0]; - float tmp135c = output0_tm_5[0] - output0_tm_6[0]; - - tmp[0][m] = output0_tm_0[0] + tmp024a + tmp024b + tmp024c * 32; - tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8; - tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16; - tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4; - tmp[5][m] = output0_tm_7[0] + tmp135a + tmp135b * 32 + tmp135c; - - output0_tm_0 += tiles * 8; - output0_tm_1 += tiles * 8; - output0_tm_2 += tiles * 8; - output0_tm_3 += tiles * 8; - output0_tm_4 += tiles * 8; - output0_tm_5 += tiles * 8; - output0_tm_6 += tiles * 8; - output0_tm_7 += tiles * 8; - } - - unsigned short* output0 = out0.row(i * 6) + j * 6; - - for (int m = 0; m < 6; m++) - { - const float* tmp0 = tmp[m]; - - float tmp024a = tmp0[1] + tmp0[2]; - float tmp135a = tmp0[1] - tmp0[2]; - - float tmp024b = tmp0[3] + tmp0[4]; - float tmp135b = tmp0[3] - tmp0[4]; - - float tmp024c = tmp0[5] + tmp0[6]; - float tmp135c = tmp0[5] - tmp0[6]; - - output0[0] = float32_to_bfloat16(bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32); - output0[2] = float32_to_bfloat16(bias0 + tmp024a + tmp024b * 4 + tmp024c * 8); - output0[4] = float32_to_bfloat16(bias0 + tmp024a + tmp024b * 16 + tmp024c + tmp024c); - - output0[1] = float32_to_bfloat16(bias0 + tmp135a + tmp135b + tmp135b + tmp135c * 16); - output0[3] = float32_to_bfloat16(bias0 + tmp135a + tmp135b * 8 + tmp135c * 4); - output0[5] = float32_to_bfloat16(bias0 + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c); - - output0 += outw; - } - } - } - } + conv3x3s1_winograd64_transform_output_bf16s_neon(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output diff --git a/src/layer/arm/convolution_3x3_pack8_fp16s.h b/src/layer/arm/convolution_3x3_pack8_fp16s.h index d67f995cdb3..148e02d2fda 100644 --- a/src/layer/arm/convolution_3x3_pack8_fp16s.h +++ b/src/layer/arm/convolution_3x3_pack8_fp16s.h @@ -116,7 +116,7 @@ static void conv3x3s1_winograd64_transform_kernel_pack8_fp16sa_neon(const Mat& k } } -static void conv3x3s1_winograd64_pack8_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd64_pack8_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -138,210 +138,15 @@ static void conv3x3s1_winograd64_pack8_fp16sa_neon(const Mat& bottom_blob, Mat& h = outh + 2; 
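All of the output-transform blocks this patch deletes evaluate the same otm[6][8] matrix quoted in their comments. For reference, here is a minimal scalar sketch of one tile of that F(6x6,3x3) output transform, the computation the new shared conv3x3s1_winograd64_transform_output_* helpers centralize. The function names and the plain row-major float tile layout are illustrative assumptions, not part of the patch:

// One 1-D pass of the F(6x6,3x3) output transform: y = AT * x, factored
// into the same tmp024*/tmp135* pairs as the deleted code.
static void winograd63_output_pass(const float x[8], float y[6])
{
    const float a024 = x[1] + x[2], a135 = x[1] - x[2];
    const float b024 = x[3] + x[4], b135 = x[3] - x[4];
    const float c024 = x[5] + x[6], c135 = x[5] - x[6];
    y[0] = x[0] + a024 + b024 + c024 * 32.f;
    y[1] = a135 + b135 * 2.f + c135 * 16.f;
    y[2] = a024 + b024 * 4.f + c024 * 8.f;
    y[3] = a135 + b135 * 8.f + c135 * 4.f;
    y[4] = a024 + b024 * 16.f + c024 * 2.f;
    y[5] = x[7] + a135 + b135 * 32.f + c135;
}

// Whole 8x8 -> 6x6 tile: a row pass into a transposed temporary, then a
// column pass, with the per-channel bias added at the end, mirroring the
// two m-loops of every deleted output-transform block.
static void winograd63_output_tile(const float v[8][8], float out[6][6], float bias)
{
    float tmp[6][8];
    float y[6];
    for (int m = 0; m < 8; m++)
    {
        winograd63_output_pass(v[m], y);
        for (int k = 0; k < 6; k++)
            tmp[k][m] = y[k];
    }
    for (int m = 0; m < 6; m++)
    {
        winograd63_output_pass(tmp[m], y);
        for (int k = 0; k < 6; k++)
            out[m][k] = y[k] + bias;
    }
}

Each deleted variant is this same arithmetic, vectorized per element packing (float32x4, float16x4, float16x8, or bf16 converted through float32), with the bias broadcast across the pack.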
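The hunk below makes the same split for the pack8 fp16 input transform: the inlined BT * d * B evaluation of the itm[8][8] matrix moves into the shared conv3x3s1_winograd64_transform_input_* helpers. As a reference for what those helpers compute per 8x8 tile, a scalar sketch under the same illustrative assumptions:

// One 1-D pass of the F(6x6,3x3) input transform: y = BT * x, using the
// factored tmp12/tmp34/tmp56 expressions from the deleted comments.
static void winograd63_input_pass(const float x[8], float y[8])
{
    y[0] = x[0] - x[6] + (x[4] - x[2]) * 5.25f;
    y[7] = x[7] - x[1] + (x[3] - x[5]) * 5.25f;
    float a = x[2] + x[6] - x[4] * 4.25f;
    float b = x[1] + x[5] - x[3] * 4.25f;
    y[1] = a + b;
    y[2] = a - b;
    a = x[6] + x[2] * 0.25f - x[4] * 1.25f;
    b = x[1] * 0.5f - x[3] * 2.5f + x[5] * 2.f;
    y[3] = a + b;
    y[4] = a - b;
    a = x[6] + (x[2] - x[4] * 1.25f) * 4.f;
    b = x[1] * 2.f - x[3] * 2.5f + x[5] * 0.5f;
    y[5] = a + b;
    y[6] = a - b;
}

// Whole 8x8 tile: rows first into a transposed temporary, then columns,
// mirroring the two m-loops of every deleted input-transform block.
static void winograd63_input_tile(const float d[8][8], float v[8][8])
{
    float tmp[8][8];
    float y[8];
    for (int m = 0; m < 8; m++)
    {
        winograd63_input_pass(d[m], y);
        for (int k = 0; k < 8; k++)
            tmp[k][m] = y[k];
    }
    for (int m = 0; m < 8; m++)
    {
        winograd63_input_pass(tmp[m], y);
        for (int k = 0; k < 8; k++)
            v[m][k] = y[k];
    }
}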
copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const __fp16* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - - const int tiles = w_tm / 8 * h_tm / 8; - - // bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); - bottom_blob_tm.create(tiles, 64, inch, 2u * elempack, elempack, opt.workspace_allocator); - - // const float itm[8][8] = { - // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, - // - // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, - // - // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, - // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, - // - // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, - // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, - // - // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} - // }; - - // 0 = r00 - r06 + (r04 - r02) * 5.25 - // 7 = r07 - r01 + (r03 - r05) * 5.25 - - // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) - // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) - - // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) - // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) - - // reuse r04 * 1.25 - // reuse r03 * 2.5 - // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) - // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - __fp16 tmp[8][8][8]; - - // tile - for (int i = 0; i < h_tm / 8; i++) - { - for (int j = 0; j < w_tm / 8; j++) - { - const __fp16* r0 = img0.row(i * 6) + (j * 6) * 8; - - for (int m = 0; m < 8; m++) - { - float16x8_t _r00 = vld1q_f16(r0); - float16x8_t _r01 = vld1q_f16(r0 + 8); - float16x8_t _r02 = vld1q_f16(r0 + 16); - float16x8_t _r03 = vld1q_f16(r0 + 24); - float16x8_t _r04 = vld1q_f16(r0 + 32); - float16x8_t _r05 = vld1q_f16(r0 + 40); - float16x8_t _r06 = vld1q_f16(r0 + 48); - float16x8_t _r07 = vld1q_f16(r0 + 56); - - float16x8_t _tmp0m = vfmaq_n_f16(vsubq_f16(_r00, _r06), vsubq_f16(_r04, _r02), 5.25f); - float16x8_t _tmp7m = vfmaq_n_f16(vsubq_f16(_r07, _r01), vsubq_f16(_r03, _r05), 5.25f); - vst1q_f16(tmp[0][m], _tmp0m); - vst1q_f16(tmp[7][m], _tmp7m); - - // tmp[0][m] = r0[0] - r0[6] + (r0[4] - r0[2]) * 5.25; - // tmp[7][m] = r0[7] - r0[1] + (r0[3] - r0[5]) * 5.25; - - float16x8_t _tmp12a = vfmsq_n_f16(vaddq_f16(_r02, _r06), _r04, 4.25f); - float16x8_t _tmp12b = vfmsq_n_f16(vaddq_f16(_r01, _r05), _r03, 4.25f); - - // float tmp12a = (r0[2] + r0[6] - r0[4] * 4.25); - // float tmp12b = (r0[1] + r0[5] - r0[3] * 4.25); + int w_tiles = outw / 6; + int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; - float16x8_t _tmp1m = vaddq_f16(_tmp12a, _tmp12b); - float16x8_t _tmp2m = vsubq_f16(_tmp12a, _tmp12b); - vst1q_f16(tmp[1][m], _tmp1m); - vst1q_f16(tmp[2][m], _tmp2m); - - // tmp[1][m] = tmp12a + tmp12b; - // tmp[2][m] = tmp12a - tmp12b; - - float16x8_t _tmp34a = vfmsq_n_f16(vfmaq_n_f16(_r06, _r02, 0.25f), _r04, 1.25f); - float16x8_t _tmp34b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_r01, 0.5f), _r03, 2.5f), _r05, 2.f); - - // float tmp34a = (r0[6] + r0[2] * 0.25 - r0[4] * 1.25); - // float tmp34b = (r0[1] * 0.5 - r0[3] * 2.5
+ r0[5] * 2); - - float16x8_t _tmp3m = vaddq_f16(_tmp34a, _tmp34b); - float16x8_t _tmp4m = vsubq_f16(_tmp34a, _tmp34b); - vst1q_f16(tmp[3][m], _tmp3m); - vst1q_f16(tmp[4][m], _tmp4m); - - // tmp[3][m] = tmp34a + tmp34b; - // tmp[4][m] = tmp34a - tmp34b; - - float16x8_t _tmp56a = vfmaq_n_f16(_r06, vfmsq_n_f16(_r02, _r04, 1.25f), 4.f); - float16x8_t _tmp56b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_r01, 2.f), _r03, 2.5f), _r05, 0.5f); - - // float tmp56a = (r0[6] + (r0[2] - r0[4] * 1.25) * 4); - // float tmp56b = (r0[1] * 2 - r0[3] * 2.5 + r0[5] * 0.5); - - float16x8_t _tmp5m = vaddq_f16(_tmp56a, _tmp56b); - float16x8_t _tmp6m = vsubq_f16(_tmp56a, _tmp56b); - vst1q_f16(tmp[5][m], _tmp5m); - vst1q_f16(tmp[6][m], _tmp6m); - - // tmp[5][m] = tmp56a + tmp56b; - // tmp[6][m] = tmp56a - tmp56b; - - r0 += w * 8; - } - - __fp16* r0_tm_0 = (__fp16*)img0_tm + (i * w_tm / 8 + j) * 8; - __fp16* r0_tm_1 = r0_tm_0 + tiles * 8; - __fp16* r0_tm_2 = r0_tm_0 + tiles * 16; - __fp16* r0_tm_3 = r0_tm_0 + tiles * 24; - __fp16* r0_tm_4 = r0_tm_0 + tiles * 32; - __fp16* r0_tm_5 = r0_tm_0 + tiles * 40; - __fp16* r0_tm_6 = r0_tm_0 + tiles * 48; - __fp16* r0_tm_7 = r0_tm_0 + tiles * 56; - - for (int m = 0; m < 8; m++) - { - float16x8_t _tmp00 = vld1q_f16(tmp[m][0]); - float16x8_t _tmp01 = vld1q_f16(tmp[m][1]); - float16x8_t _tmp02 = vld1q_f16(tmp[m][2]); - float16x8_t _tmp03 = vld1q_f16(tmp[m][3]); - float16x8_t _tmp04 = vld1q_f16(tmp[m][4]); - float16x8_t _tmp05 = vld1q_f16(tmp[m][5]); - float16x8_t _tmp06 = vld1q_f16(tmp[m][6]); - float16x8_t _tmp07 = vld1q_f16(tmp[m][7]); - - float16x8_t _r0tm0 = vfmaq_n_f16(vsubq_f16(_tmp00, _tmp06), vsubq_f16(_tmp04, _tmp02), 5.25f); - float16x8_t _r0tm7 = vfmaq_n_f16(vsubq_f16(_tmp07, _tmp01), vsubq_f16(_tmp03, _tmp05), 5.25f); - - // r0_tm[0] = tmp0[0] - tmp0[6] + (tmp0[4] - tmp0[2]) * 5.25; - // r0_tm[7] = tmp0[7] - tmp0[1] + (tmp0[3] - tmp0[5]) * 5.25; - - float16x8_t _tmp12a = vfmsq_n_f16(vaddq_f16(_tmp02, _tmp06), _tmp04, 4.25f); - float16x8_t _tmp12b = vfmsq_n_f16(vaddq_f16(_tmp01, _tmp05), _tmp03, 4.25f); - - // float tmp12a = (tmp0[2] + tmp0[6] - tmp0[4] * 4.25); - // float tmp12b = (tmp0[1] + tmp0[5] - tmp0[3] * 4.25); - - float16x8_t _r0tm1 = vaddq_f16(_tmp12a, _tmp12b); - float16x8_t _r0tm2 = vsubq_f16(_tmp12a, _tmp12b); - - // r0_tm[1] = tmp12a + tmp12b; - // r0_tm[2] = tmp12a - tmp12b; - - float16x8_t _tmp34a = vfmsq_n_f16(vfmaq_n_f16(_tmp06, _tmp02, 0.25f), _tmp04, 1.25f); - float16x8_t _tmp34b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_tmp01, 0.5f), _tmp03, 2.5f), _tmp05, 2.f); - - // float tmp34a = (tmp0[6] + tmp0[2] * 0.25 - tmp0[4] * 1.25); - // float tmp34b = (tmp0[1] * 0.5 - tmp0[3] * 2.5 + tmp0[5] * 2); - - float16x8_t _r0tm3 = vaddq_f16(_tmp34a, _tmp34b); - float16x8_t _r0tm4 = vsubq_f16(_tmp34a, _tmp34b); - - // r0_tm[3] = tmp34a + tmp34b; - // r0_tm[4] = tmp34a - tmp34b; - - float16x8_t _tmp56a = vfmaq_n_f16(_tmp06, vfmsq_n_f16(_tmp02, _tmp04, 1.25f), 4.f); - float16x8_t _tmp56b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_tmp01, 2.f), _tmp03, 2.5f), _tmp05, 0.5f); - - // float tmp56a = (tmp0[6] + (tmp0[2] - tmp0[4] * 1.25) * 4); - // float tmp56b = (tmp0[1] * 2 - tmp0[3] * 2.5 + tmp0[5] * 0.5); - - float16x8_t _r0tm5 = vaddq_f16(_tmp56a, _tmp56b); - float16x8_t _r0tm6 = vsubq_f16(_tmp56a, _tmp56b); - - // r0_tm[5] = tmp56a + tmp56b; - // r0_tm[6] = tmp56a - tmp56b; - - vst1q_f16(r0_tm_0, _r0tm0); - vst1q_f16(r0_tm_1, _r0tm1); - vst1q_f16(r0_tm_2, _r0tm2); - vst1q_f16(r0_tm_3, _r0tm3); - vst1q_f16(r0_tm_4, _r0tm4); - vst1q_f16(r0_tm_5, _r0tm5); - vst1q_f16(r0_tm_6, 
_r0tm6); - vst1q_f16(r0_tm_7, _r0tm7); - - r0_tm_0 += tiles * 64; - r0_tm_1 += tiles * 64; - r0_tm_2 += tiles * 64; - r0_tm_3 += tiles * 64; - r0_tm_4 += tiles * 64; - r0_tm_5 += tiles * 64; - r0_tm_6 += tiles * 64; - r0_tm_7 += tiles * 64; - } - } - } - } + bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); + conv3x3s1_winograd64_transform_input_pack8_fp16sa_neon(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -1033,173 +838,7 @@ static void conv3x3s1_winograd64_pack8_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator); } { - // const float otm[6][8] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} - // }; - - // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 - // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 - // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 - // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 - // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 - // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) - - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - const int tiles = w_tm / 8 * h_tm / 8; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - float16x8_t _bias0 = bias ? vld1q_f16((const __fp16*)bias + p * 8) : vdupq_n_f16(0.f); - - __fp16 tmp[6][8][8]; - - // tile - for (int i = 0; i < outh / 6; i++) - { - for (int j = 0; j < outw / 6; j++) - { - // top_blob_tm.create(tiles, 64, outch, elemsize, elempack); - - const __fp16* output0_tm_0 = (const __fp16*)out0_tm + (i * w_tm / 8 + j) * 8; - const __fp16* output0_tm_1 = output0_tm_0 + tiles * 8; - const __fp16* output0_tm_2 = output0_tm_0 + tiles * 16; - const __fp16* output0_tm_3 = output0_tm_0 + tiles * 24; - const __fp16* output0_tm_4 = output0_tm_0 + tiles * 32; - const __fp16* output0_tm_5 = output0_tm_0 + tiles * 40; - const __fp16* output0_tm_6 = output0_tm_0 + tiles * 48; - const __fp16* output0_tm_7 = output0_tm_0 + tiles * 56; - - __fp16* output0 = out0.row<__fp16>(i * 6) + (j * 6) * 8; - - // TODO neon optimize - for (int m = 0; m < 8; m++) - { - float16x8_t _out0tm0 = vld1q_f16(output0_tm_0); - float16x8_t _out0tm1 = vld1q_f16(output0_tm_1); - float16x8_t _out0tm2 = vld1q_f16(output0_tm_2); - float16x8_t _out0tm3 = vld1q_f16(output0_tm_3); - float16x8_t _out0tm4 = vld1q_f16(output0_tm_4); - float16x8_t _out0tm5 = vld1q_f16(output0_tm_5); - float16x8_t _out0tm6 = vld1q_f16(output0_tm_6); - float16x8_t _out0tm7 = vld1q_f16(output0_tm_7); - - float16x8_t _tmp024a = vaddq_f16(_out0tm1, _out0tm2); - float16x8_t _tmp135a = vsubq_f16(_out0tm1, _out0tm2); - - // float tmp024a = output0_tm[1] + output0_tm[2]; - // float tmp135a = output0_tm[1] - output0_tm[2]; - - float16x8_t _tmp024b = vaddq_f16(_out0tm3, _out0tm4); - float16x8_t _tmp135b = vsubq_f16(_out0tm3, _out0tm4); - - // float tmp024b = output0_tm[3] + output0_tm[4]; - // float tmp135b = output0_tm[3] - output0_tm[4]; - - float16x8_t _tmp024c = vaddq_f16(_out0tm5, _out0tm6); - float16x8_t _tmp135c = vsubq_f16(_out0tm5, 
_out0tm6); - - // float tmp024c = output0_tm[5] + output0_tm[6]; - // float tmp135c = output0_tm[5] - output0_tm[6]; - - float16x8_t _tmp0m = vaddq_f16(vaddq_f16(_out0tm0, _tmp024a), vfmaq_n_f16(_tmp024b, _tmp024c, 32.f)); - float16x8_t _tmp2m = vfmaq_n_f16(vfmaq_n_f16(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f); - float16x8_t _tmp4m = vfmaq_n_f16(vfmaq_n_f16(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f); - vst1q_f16(tmp[0][m], _tmp0m); - vst1q_f16(tmp[2][m], _tmp2m); - vst1q_f16(tmp[4][m], _tmp4m); - - // tmp[0][m] = output0_tm[0] + tmp024a + tmp024b + tmp024c * 32; - // tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8; - // tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - float16x8_t _tmp1m = vfmaq_n_f16(vfmaq_n_f16(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f); - float16x8_t _tmp3m = vfmaq_n_f16(vfmaq_n_f16(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f); - float16x8_t _tmp5m = vaddq_f16(vaddq_f16(_out0tm7, _tmp135a), vfmaq_n_f16(_tmp135c, _tmp135b, 32.f)); - vst1q_f16(tmp[1][m], _tmp1m); - vst1q_f16(tmp[3][m], _tmp3m); - vst1q_f16(tmp[5][m], _tmp5m); - - // tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16; - // tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4; - // tmp[5][m] = output0_tm[7] + tmp135a + tmp135b * 32 + tmp135c; - - output0_tm_0 += tiles * 64; - output0_tm_1 += tiles * 64; - output0_tm_2 += tiles * 64; - output0_tm_3 += tiles * 64; - output0_tm_4 += tiles * 64; - output0_tm_5 += tiles * 64; - output0_tm_6 += tiles * 64; - output0_tm_7 += tiles * 64; - } - - for (int m = 0; m < 6; m++) - { - float16x8_t _tmp00 = vld1q_f16(tmp[m][0]); - float16x8_t _tmp01 = vld1q_f16(tmp[m][1]); - float16x8_t _tmp02 = vld1q_f16(tmp[m][2]); - float16x8_t _tmp03 = vld1q_f16(tmp[m][3]); - float16x8_t _tmp04 = vld1q_f16(tmp[m][4]); - float16x8_t _tmp05 = vld1q_f16(tmp[m][5]); - float16x8_t _tmp06 = vld1q_f16(tmp[m][6]); - float16x8_t _tmp07 = vld1q_f16(tmp[m][7]); - - float16x8_t _tmp024a = vaddq_f16(_tmp01, _tmp02); - float16x8_t _tmp135a = vsubq_f16(_tmp01, _tmp02); - - // float tmp024a = tmp0[1] + tmp0[2]; - // float tmp135a = tmp0[1] - tmp0[2]; - - float16x8_t _tmp024b = vaddq_f16(_tmp03, _tmp04); - float16x8_t _tmp135b = vsubq_f16(_tmp03, _tmp04); - - // float tmp024b = tmp0[3] + tmp0[4]; - // float tmp135b = tmp0[3] - tmp0[4]; - - float16x8_t _tmp024c = vaddq_f16(_tmp05, _tmp06); - float16x8_t _tmp135c = vsubq_f16(_tmp05, _tmp06); - - // float tmp024c = tmp0[5] + tmp0[6]; - // float tmp135c = tmp0[5] - tmp0[6]; - - float16x8_t _out00 = vaddq_f16(_bias0, vaddq_f16(vaddq_f16(_tmp00, _tmp024a), vfmaq_n_f16(_tmp024b, _tmp024c, 32.f))); - float16x8_t _out02 = vaddq_f16(_bias0, vfmaq_n_f16(vfmaq_n_f16(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f)); - float16x8_t _out04 = vaddq_f16(_bias0, vfmaq_n_f16(vfmaq_n_f16(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f)); - vst1q_f16(output0, _out00); - vst1q_f16(output0 + 16, _out02); - vst1q_f16(output0 + 32, _out04); - - // output0[0] = bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32; - // output0[2] = bias0 + tmp024a + tmp024b * 4 + tmp024c * 8; - // output0[4] = bias0 + tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - float16x8_t _out01 = vaddq_f16(_bias0, vfmaq_n_f16(vfmaq_n_f16(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f)); - float16x8_t _out03 = vaddq_f16(_bias0, vfmaq_n_f16(vfmaq_n_f16(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f)); - float16x8_t _out05 = vaddq_f16(_bias0, vaddq_f16(vaddq_f16(_tmp07, _tmp135a), vfmaq_n_f16(_tmp135c, _tmp135b, 32.f))); - vst1q_f16(output0 + 8, _out01); - vst1q_f16(output0 + 24, _out03); - vst1q_f16(output0 + 40, _out05); - - // 
output0[1] = bias0 + tmp135a + tmp135b + tmp135b + tmp135c * 16; - // output0[3] = bias0 + tmp135a + tmp135b * 8 + tmp135c * 4; - // output0[5] = bias0 + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c; - - output0 += outw * 8; - } - } - } - } + conv3x3s1_winograd64_transform_output_pack8_fp16sa_neon(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output @@ -1308,7 +947,7 @@ static void conv3x3s1_winograd42_transform_kernel_pack8_fp16sa_neon(const Mat& k } } -static void conv3x3s1_winograd42_pack8_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd42_pack8_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -1330,115 +969,15 @@ static void conv3x3s1_winograd42_pack8_fp16sa_neon(const Mat& bottom_blob, Mat& h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const __fp16* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = w_tm / 6 * h_tm / 6; - - bottom_blob_tm.create(tiles, 36, inch, 2u * elempack, elempack, opt.workspace_allocator); + int w_tiles = outw / 4; + int h_tiles = outh / 4; + const int tiles = w_tiles * h_tiles; - // const float itm[4][4] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r04 + r03 - // 2 = 4 * (r01 - r02) + r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - __fp16 tmp[6][6][8]; - - // tile - for (int i = 0; i < h_tm / 6; i++) - { - for (int j = 0; j < w_tm / 6; j++) - { - const __fp16* r0 = img0.row(i * 4) + (j * 4) * 8; - - for (int m = 0; m < 6; m++) - { - float16x8_t _r00 = vld1q_f16(r0); - float16x8_t _r01 = vld1q_f16(r0 + 8); - float16x8_t _r02 = vld1q_f16(r0 + 16); - float16x8_t _r03 = vld1q_f16(r0 + 24); - float16x8_t _r04 = vld1q_f16(r0 + 32); - float16x8_t _r05 = vld1q_f16(r0 + 40); - - float16x8_t _tmp0m = vfmsq_n_f16(vfmaq_n_f16(_r04, _r00, 4.f), _r02, 5.f); - float16x8_t _tmp1m = vfmsq_n_f16(vaddq_f16(_r04, _r03), vaddq_f16(_r01, _r02), 4.f); - float16x8_t _tmp2m = vfmaq_n_f16(vsubq_f16(_r04, _r03), vsubq_f16(_r01, _r02), 4.f); - float16x8_t _tmp3m = vfmsq_n_f16(vsubq_f16(_r04, _r02), vsubq_f16(_r01, _r03), 2.f); - float16x8_t _tmp4m = vfmaq_n_f16(vsubq_f16(_r04, _r02), vsubq_f16(_r01, _r03), 2.f); - float16x8_t _tmp5m = vfmsq_n_f16(vfmaq_n_f16(_r05, _r01, 4.f), _r03, 5.f); - - vst1q_f16(tmp[0][m], _tmp0m); - vst1q_f16(tmp[1][m], _tmp1m); - vst1q_f16(tmp[2][m], _tmp2m); - vst1q_f16(tmp[3][m], _tmp3m); - vst1q_f16(tmp[4][m], _tmp4m); - vst1q_f16(tmp[5][m], _tmp5m); - - r0 += w * 8; - } - - __fp16* r0_tm_0 = (__fp16*)img0_tm + (i * w_tm / 6 + j) * 8; - __fp16* r0_tm_1 = r0_tm_0 + tiles * 8; - __fp16* r0_tm_2 = r0_tm_0 + tiles * 16; - __fp16* r0_tm_3 = r0_tm_0 + tiles * 24; - __fp16* r0_tm_4 = r0_tm_0 + tiles * 32; - __fp16* r0_tm_5 = 
r0_tm_0 + tiles * 40; - - for (int m = 0; m < 6; m++) - { - float16x8_t _tmp00 = vld1q_f16(tmp[m][0]); - float16x8_t _tmp01 = vld1q_f16(tmp[m][1]); - float16x8_t _tmp02 = vld1q_f16(tmp[m][2]); - float16x8_t _tmp03 = vld1q_f16(tmp[m][3]); - float16x8_t _tmp04 = vld1q_f16(tmp[m][4]); - float16x8_t _tmp05 = vld1q_f16(tmp[m][5]); - - float16x8_t _r0tm0 = vfmsq_n_f16(vfmaq_n_f16(_tmp04, _tmp00, 4.f), _tmp02, 5.f); - float16x8_t _r0tm1 = vfmsq_n_f16(vaddq_f16(_tmp04, _tmp03), vaddq_f16(_tmp01, _tmp02), 4.f); - float16x8_t _r0tm2 = vfmaq_n_f16(vsubq_f16(_tmp04, _tmp03), vsubq_f16(_tmp01, _tmp02), 4.f); - float16x8_t _r0tm3 = vfmsq_n_f16(vsubq_f16(_tmp04, _tmp02), vsubq_f16(_tmp01, _tmp03), 2.f); - float16x8_t _r0tm4 = vfmaq_n_f16(vsubq_f16(_tmp04, _tmp02), vsubq_f16(_tmp01, _tmp03), 2.f); - float16x8_t _r0tm5 = vfmsq_n_f16(vfmaq_n_f16(_tmp05, _tmp01, 4.f), _tmp03, 5.f); - - vst1q_f16(r0_tm_0, _r0tm0); - vst1q_f16(r0_tm_1, _r0tm1); - vst1q_f16(r0_tm_2, _r0tm2); - vst1q_f16(r0_tm_3, _r0tm3); - vst1q_f16(r0_tm_4, _r0tm4); - vst1q_f16(r0_tm_5, _r0tm5); - - r0_tm_0 += tiles * 48; - r0_tm_1 += tiles * 48; - r0_tm_2 += tiles * 48; - r0_tm_3 += tiles * 48; - r0_tm_4 += tiles * 48; - r0_tm_5 += tiles * 48; - } - } - } - } + bottom_blob_tm.create(tiles, 36, inch, elemsize, elempack, opt.workspace_allocator); + conv3x3s1_winograd42_transform_input_pack8_fp16sa_neon(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -2130,113 +1669,7 @@ static void conv3x3s1_winograd42_pack8_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator); } { - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - const int tiles = w_tm / 6 * h_tm / 6; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - float16x8_t _bias0 = bias ? 
vld1q_f16((const __fp16*)bias + p * 8) : vdupq_n_f16(0.f); - - __fp16 tmp[4][6][8]; - - // tile - for (int i = 0; i < outh / 4; i++) - { - for (int j = 0; j < outw / 4; j++) - { - // top_blob_tm.create(tiles, 36, outch, elemsize, elempack); - - const __fp16* output0_tm_0 = (const __fp16*)out0_tm + (i * w_tm / 6 + j) * 8; - const __fp16* output0_tm_1 = output0_tm_0 + tiles * 8; - const __fp16* output0_tm_2 = output0_tm_0 + tiles * 16; - const __fp16* output0_tm_3 = output0_tm_0 + tiles * 24; - const __fp16* output0_tm_4 = output0_tm_0 + tiles * 32; - const __fp16* output0_tm_5 = output0_tm_0 + tiles * 40; - - __fp16* output0 = out0.row<__fp16>(i * 4) + (j * 4) * 8; - - // TODO neon optimize - for (int m = 0; m < 6; m++) - { - float16x8_t _out0tm0 = vld1q_f16(output0_tm_0); - float16x8_t _out0tm1 = vld1q_f16(output0_tm_1); - float16x8_t _out0tm2 = vld1q_f16(output0_tm_2); - float16x8_t _out0tm3 = vld1q_f16(output0_tm_3); - float16x8_t _out0tm4 = vld1q_f16(output0_tm_4); - float16x8_t _out0tm5 = vld1q_f16(output0_tm_5); - - float16x8_t _tmp02a = vaddq_f16(_out0tm1, _out0tm2); - float16x8_t _tmp13a = vsubq_f16(_out0tm1, _out0tm2); - - float16x8_t _tmp02b = vaddq_f16(_out0tm3, _out0tm4); - float16x8_t _tmp13b = vsubq_f16(_out0tm3, _out0tm4); - - float16x8_t _tmp0m = vaddq_f16(vaddq_f16(_out0tm0, _tmp02a), _tmp02b); - float16x8_t _tmp1m = vfmaq_n_f16(_tmp13a, _tmp13b, 2.f); - float16x8_t _tmp2m = vfmaq_n_f16(_tmp02a, _tmp02b, 4.f); - float16x8_t _tmp3m = vfmaq_n_f16(vaddq_f16(_out0tm5, _tmp13a), _tmp13b, 8.f); - - vst1q_f16(tmp[0][m], _tmp0m); - vst1q_f16(tmp[1][m], _tmp1m); - vst1q_f16(tmp[2][m], _tmp2m); - vst1q_f16(tmp[3][m], _tmp3m); - - output0_tm_0 += tiles * 48; - output0_tm_1 += tiles * 48; - output0_tm_2 += tiles * 48; - output0_tm_3 += tiles * 48; - output0_tm_4 += tiles * 48; - output0_tm_5 += tiles * 48; - } - - for (int m = 0; m < 4; m++) - { - float16x8_t _tmp00 = vld1q_f16(tmp[m][0]); - float16x8_t _tmp01 = vld1q_f16(tmp[m][1]); - float16x8_t _tmp02 = vld1q_f16(tmp[m][2]); - float16x8_t _tmp03 = vld1q_f16(tmp[m][3]); - float16x8_t _tmp04 = vld1q_f16(tmp[m][4]); - float16x8_t _tmp05 = vld1q_f16(tmp[m][5]); - - float16x8_t _tmp02a = vaddq_f16(_tmp01, _tmp02); - float16x8_t _tmp13a = vsubq_f16(_tmp01, _tmp02); - - float16x8_t _tmp02b = vaddq_f16(_tmp03, _tmp04); - float16x8_t _tmp13b = vsubq_f16(_tmp03, _tmp04); - - float16x8_t _out00 = vaddq_f16(_bias0, vaddq_f16(vaddq_f16(_tmp00, _tmp02a), _tmp02b)); - float16x8_t _out01 = vaddq_f16(_bias0, vfmaq_n_f16(_tmp13a, _tmp13b, 2.f)); - float16x8_t _out02 = vaddq_f16(_bias0, vfmaq_n_f16(_tmp02a, _tmp02b, 4.f)); - float16x8_t _out03 = vaddq_f16(_bias0, vfmaq_n_f16(vaddq_f16(_tmp05, _tmp13a), _tmp13b, 8.f)); - - vst1q_f16(output0, _out00); - vst1q_f16(output0 + 8, _out01); - vst1q_f16(output0 + 16, _out02); - vst1q_f16(output0 + 24, _out03); - - output0 += outw * 8; - } - } - } - } + conv3x3s1_winograd42_transform_output_pack8_fp16sa_neon(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output diff --git a/src/layer/arm/convolution_3x3_pack8to1_fp16s.h b/src/layer/arm/convolution_3x3_pack8to1_fp16s.h index 6b0a25a4ca0..3e2a8c9405f 100644 --- a/src/layer/arm/convolution_3x3_pack8to1_fp16s.h +++ b/src/layer/arm/convolution_3x3_pack8to1_fp16s.h @@ -128,12 +128,12 @@ static void conv3x3s1_winograd64_transform_kernel_pack8to1_fp16sa_neon(const Mat } } -static void conv3x3s1_winograd64_pack8to1_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void 
conv3x3s1_winograd64_pack8to1_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; int inch = bottom_blob.c; - //size_t elemsize = bottom_blob.elemsize; + size_t elemsize = bottom_blob.elemsize; int elempack = bottom_blob.elempack; int outw = top_blob.w; @@ -150,210 +150,15 @@ static void conv3x3s1_winograd64_pack8to1_fp16sa_neon(const Mat& bottom_blob, Ma h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const __fp16* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - - const int tiles = w_tm / 8 * h_tm / 8; - - // bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); - bottom_blob_tm.create(tiles, 64, inch, 2u * elempack, elempack, opt.workspace_allocator); - - // const float itm[8][8] = { - // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, - // - // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, - // - // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, - // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, - // - // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, - // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, - // - // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} - // }; - - // 0 = r00 - r06 + (r04 - r02) * 5.25 - // 7 = r07 - r01 + (r03 - r05) * 5.25 - - // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) - // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) - - // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) - // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) - - // reuse r04 * 1.25 - // reuse r03 * 2.5 - // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) - // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - __fp16 tmp[8][8][8]; - - // tile - for (int i = 0; i < h_tm / 8; i++) - { - for (int j = 0; j < w_tm / 8; j++) - { - const __fp16* r0 = img0.row(i * 6) + (j * 6) * 8; - - for (int m = 0; m < 8; m++) - { - float16x8_t _r00 = vld1q_f16(r0); - float16x8_t _r01 = vld1q_f16(r0 + 8); - float16x8_t _r02 = vld1q_f16(r0 + 16); - float16x8_t _r03 = vld1q_f16(r0 + 24); - float16x8_t _r04 = vld1q_f16(r0 + 32); - float16x8_t _r05 = vld1q_f16(r0 + 40); - float16x8_t _r06 = vld1q_f16(r0 + 48); - float16x8_t _r07 = vld1q_f16(r0 + 56); - - float16x8_t _tmp0m = vfmaq_n_f16(vsubq_f16(_r00, _r06), vsubq_f16(_r04, _r02), 5.25f); - float16x8_t _tmp7m = vfmaq_n_f16(vsubq_f16(_r07, _r01), vsubq_f16(_r03, _r05), 5.25f); - vst1q_f16(tmp[0][m], _tmp0m); - vst1q_f16(tmp[7][m], _tmp7m); - - // tmp[0][m] = r0[0] - r0[6] + (r0[4] - r0[2]) * 5.25; - // tmp[7][m] = r0[7] - r0[1] + (r0[3] - r0[5]) * 5.25; - - float16x8_t _tmp12a = vfmsq_n_f16(vaddq_f16(_r02, _r06), _r04, 4.25f); - float16x8_t _tmp12b = vfmsq_n_f16(vaddq_f16(_r01, _r05), _r03, 4.25f); - - // float tmp12a = (r0[2] + r0[6] - r0[4] * 4.25); - // float tmp12b = (r0[1] + r0[5] - r0[3] * 4.25); - - float16x8_t _tmp1m = vaddq_f16(_tmp12a, _tmp12b); - float16x8_t _tmp2m = vsubq_f16(_tmp12a, _tmp12b); - vst1q_f16(tmp[1][m], _tmp1m); - 
vst1q_f16(tmp[2][m], _tmp2m); - - // tmp[1][m] = tmp12a + tmp12b; - // tmp[2][m] = tmp12a - tmp12b; - - float16x8_t _tmp34a = vfmsq_n_f16(vfmaq_n_f16(_r06, _r02, 0.25f), _r04, 1.25f); - float16x8_t _tmp34b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_r01, 0.5f), _r03, 2.5f), _r05, 2.f); - - // float tmp34a = (r0[6] + r0[2] * 0.25 - r0[4] * 1.25); - // float tmp34b = (r0[1] * 0.5 - r0[3] * 2.5 + r0[5] * 2); - - float16x8_t _tmp3m = vaddq_f16(_tmp34a, _tmp34b); - float16x8_t _tmp4m = vsubq_f16(_tmp34a, _tmp34b); - vst1q_f16(tmp[3][m], _tmp3m); - vst1q_f16(tmp[4][m], _tmp4m); - - // tmp[3][m] = tmp34a + tmp34b; - // tmp[4][m] = tmp34a - tmp34b; - - float16x8_t _tmp56a = vfmaq_n_f16(_r06, vfmsq_n_f16(_r02, _r04, 1.25f), 4.f); - float16x8_t _tmp56b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_r01, 2.f), _r03, 2.5f), _r05, 0.5f); - - // float tmp56a = (r0[6] + (r0[2] - r0[4] * 1.25) * 4); - // float tmp56b = (r0[1] * 2 - r0[3] * 2.5 + r0[5] * 0.5); - - float16x8_t _tmp5m = vaddq_f16(_tmp56a, _tmp56b); - float16x8_t _tmp6m = vsubq_f16(_tmp56a, _tmp56b); - vst1q_f16(tmp[5][m], _tmp5m); - vst1q_f16(tmp[6][m], _tmp6m); - - // tmp[5][m] = tmp56a + tmp56b; - // tmp[6][m] = tmp56a - tmp56b; - - r0 += w * 8; - } + int w_tiles = outw / 6; + int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; - __fp16* r0_tm_0 = (__fp16*)img0_tm + (i * w_tm / 8 + j) * 8; - __fp16* r0_tm_1 = r0_tm_0 + tiles * 8; - __fp16* r0_tm_2 = r0_tm_0 + tiles * 16; - __fp16* r0_tm_3 = r0_tm_0 + tiles * 24; - __fp16* r0_tm_4 = r0_tm_0 + tiles * 32; - __fp16* r0_tm_5 = r0_tm_0 + tiles * 40; - __fp16* r0_tm_6 = r0_tm_0 + tiles * 48; - __fp16* r0_tm_7 = r0_tm_0 + tiles * 56; - - for (int m = 0; m < 8; m++) - { - float16x8_t _tmp00 = vld1q_f16(tmp[m][0]); - float16x8_t _tmp01 = vld1q_f16(tmp[m][1]); - float16x8_t _tmp02 = vld1q_f16(tmp[m][2]); - float16x8_t _tmp03 = vld1q_f16(tmp[m][3]); - float16x8_t _tmp04 = vld1q_f16(tmp[m][4]); - float16x8_t _tmp05 = vld1q_f16(tmp[m][5]); - float16x8_t _tmp06 = vld1q_f16(tmp[m][6]); - float16x8_t _tmp07 = vld1q_f16(tmp[m][7]); - - float16x8_t _r0tm0 = vfmaq_n_f16(vsubq_f16(_tmp00, _tmp06), vsubq_f16(_tmp04, _tmp02), 5.25f); - float16x8_t _r0tm7 = vfmaq_n_f16(vsubq_f16(_tmp07, _tmp01), vsubq_f16(_tmp03, _tmp05), 5.25f); - - // r0_tm[0] = tmp0[0] - tmp0[6] + (tmp0[4] - tmp0[2]) * 5.25; - // r0_tm[7] = tmp0[7] - tmp0[1] + (tmp0[3] - tmp0[5]) * 5.25; - - float16x8_t _tmp12a = vfmsq_n_f16(vaddq_f16(_tmp02, _tmp06), _tmp04, 4.25f); - float16x8_t _tmp12b = vfmsq_n_f16(vaddq_f16(_tmp01, _tmp05), _tmp03, 4.25f); - - // float tmp12a = (tmp0[2] + tmp0[6] - tmp0[4] * 4.25); - // float tmp12b = (tmp0[1] + tmp0[5] - tmp0[3] * 4.25); - - float16x8_t _r0tm1 = vaddq_f16(_tmp12a, _tmp12b); - float16x8_t _r0tm2 = vsubq_f16(_tmp12a, _tmp12b); - - // r0_tm[1] = tmp12a + tmp12b; - // r0_tm[2] = tmp12a - tmp12b; - - float16x8_t _tmp34a = vfmsq_n_f16(vfmaq_n_f16(_tmp06, _tmp02, 0.25f), _tmp04, 1.25f); - float16x8_t _tmp34b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_tmp01, 0.5f), _tmp03, 2.5f), _tmp05, 2.f); - - // float tmp34a = (tmp0[6] + tmp0[2] * 0.25 - tmp0[4] * 1.25); - // float tmp34b = (tmp0[1] * 0.5 - tmp0[3] * 2.5 + tmp0[5] * 2); - - float16x8_t _r0tm3 = vaddq_f16(_tmp34a, _tmp34b); - float16x8_t _r0tm4 = vsubq_f16(_tmp34a, _tmp34b); - - // r0_tm[3] = tmp34a + tmp34b; - // r0_tm[4] = tmp34a - tmp34b; - - float16x8_t _tmp56a = vfmaq_n_f16(_tmp06, vfmsq_n_f16(_tmp02, _tmp04, 1.25f), 4.f); - float16x8_t _tmp56b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_tmp01, 2.f), _tmp03, 2.5f), _tmp05, 0.5f); - - // float tmp56a = (tmp0[6] + 
(tmp0[2] - tmp0[4] * 1.25) * 4); - // float tmp56b = (tmp0[1] * 2 - tmp0[3] * 2.5 + tmp0[5] * 0.5); - - float16x8_t _r0tm5 = vaddq_f16(_tmp56a, _tmp56b); - float16x8_t _r0tm6 = vsubq_f16(_tmp56a, _tmp56b); - - // r0_tm[5] = tmp56a + tmp56b; - // r0_tm[6] = tmp56a - tmp56b; - - vst1q_f16(r0_tm_0, _r0tm0); - vst1q_f16(r0_tm_1, _r0tm1); - vst1q_f16(r0_tm_2, _r0tm2); - vst1q_f16(r0_tm_3, _r0tm3); - vst1q_f16(r0_tm_4, _r0tm4); - vst1q_f16(r0_tm_5, _r0tm5); - vst1q_f16(r0_tm_6, _r0tm6); - vst1q_f16(r0_tm_7, _r0tm7); - - r0_tm_0 += tiles * 64; - r0_tm_1 += tiles * 64; - r0_tm_2 += tiles * 64; - r0_tm_3 += tiles * 64; - r0_tm_4 += tiles * 64; - r0_tm_5 += tiles * 64; - r0_tm_6 += tiles * 64; - r0_tm_7 += tiles * 64; - } - } - } - } + bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); + conv3x3s1_winograd64_transform_input_pack8_fp16sa_neon(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -1010,111 +815,7 @@ static void conv3x3s1_winograd64_pack8to1_fp16sa_neon(const Mat& bottom_blob, Ma top_blob_bordered.create(outw, outh, outch, 2u, 1, opt.workspace_allocator); } { - // const float otm[6][8] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} - // }; - - // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 - // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 - // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 - // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 - // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 - // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) - - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - const int tiles = w_tm / 8 * h_tm / 8; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - const __fp16 bias0 = bias ? 
bias[p] : 0.f; - // float32x2_t _bias0 = vdup_n_f32(bias0); - - __fp16 tmp[6][8]; - - // tile - for (int i = 0; i < outh / 6; i++) - { - for (int j = 0; j < outw / 6; j++) - { - // top_blob_tm.create(tiles, 64, outch, 4u, 1, opt.workspace_allocator); - - const __fp16* output0_tm_0 = (const __fp16*)out0_tm + (i * w_tm / 8 + j) * 1; - const __fp16* output0_tm_1 = output0_tm_0 + tiles * 1; - const __fp16* output0_tm_2 = output0_tm_0 + tiles * 2; - const __fp16* output0_tm_3 = output0_tm_0 + tiles * 3; - const __fp16* output0_tm_4 = output0_tm_0 + tiles * 4; - const __fp16* output0_tm_5 = output0_tm_0 + tiles * 5; - const __fp16* output0_tm_6 = output0_tm_0 + tiles * 6; - const __fp16* output0_tm_7 = output0_tm_0 + tiles * 7; - - // TODO neon optimize - for (int m = 0; m < 8; m++) - { - __fp16 tmp024a = output0_tm_1[0] + output0_tm_2[0]; - __fp16 tmp135a = output0_tm_1[0] - output0_tm_2[0]; - - __fp16 tmp024b = output0_tm_3[0] + output0_tm_4[0]; - __fp16 tmp135b = output0_tm_3[0] - output0_tm_4[0]; - - __fp16 tmp024c = output0_tm_5[0] + output0_tm_6[0]; - __fp16 tmp135c = output0_tm_5[0] - output0_tm_6[0]; - - tmp[0][m] = output0_tm_0[0] + tmp024a + tmp024b + tmp024c * 32; - tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8; - tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16; - tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4; - tmp[5][m] = output0_tm_7[0] + tmp135a + tmp135b * 32 + tmp135c; - - output0_tm_0 += tiles * 8; - output0_tm_1 += tiles * 8; - output0_tm_2 += tiles * 8; - output0_tm_3 += tiles * 8; - output0_tm_4 += tiles * 8; - output0_tm_5 += tiles * 8; - output0_tm_6 += tiles * 8; - output0_tm_7 += tiles * 8; - } - - __fp16* output0 = out0.row<__fp16>(i * 6) + j * 6; - - for (int m = 0; m < 6; m++) - { - const __fp16* tmp0 = tmp[m]; - - __fp16 tmp024a = tmp0[1] + tmp0[2]; - __fp16 tmp135a = tmp0[1] - tmp0[2]; - - __fp16 tmp024b = tmp0[3] + tmp0[4]; - __fp16 tmp135b = tmp0[3] - tmp0[4]; - - __fp16 tmp024c = tmp0[5] + tmp0[6]; - __fp16 tmp135c = tmp0[5] - tmp0[6]; - - output0[0] = bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32; - output0[2] = bias0 + tmp024a + tmp024b * 4 + tmp024c * 8; - output0[4] = bias0 + tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - output0[1] = bias0 + tmp135a + tmp135b + tmp135b + tmp135c * 16; - output0[3] = bias0 + tmp135a + tmp135b * 8 + tmp135c * 4; - output0[5] = bias0 + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c; - - output0 += outw; - } - } - } - } + conv3x3s1_winograd64_transform_output_fp16sa_neon(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output diff --git a/src/layer/arm/convolution_3x3_pack8to4_fp16s.h b/src/layer/arm/convolution_3x3_pack8to4_fp16s.h index 0382038cf91..a59bfcf58b4 100644 --- a/src/layer/arm/convolution_3x3_pack8to4_fp16s.h +++ b/src/layer/arm/convolution_3x3_pack8to4_fp16s.h @@ -134,12 +134,12 @@ static void conv3x3s1_winograd64_transform_kernel_pack8to4_fp16sa_neon(const Mat } } -static void conv3x3s1_winograd64_pack8to4_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd64_pack8to4_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; int inch = bottom_blob.c; - //size_t elemsize = bottom_blob.elemsize; + size_t elemsize = bottom_blob.elemsize; int elempack = bottom_blob.elempack; int outw = top_blob.w; @@ -156,210 +156,15 @@ static void 
conv3x3s1_winograd64_pack8to4_fp16sa_neon(const Mat& bottom_blob, Ma h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const __fp16* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - - const int tiles = w_tm / 8 * h_tm / 8; - - // bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); - bottom_blob_tm.create(tiles, 64, inch, 2u * elempack, elempack, opt.workspace_allocator); - - // const float itm[8][8] = { - // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, - // - // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, - // - // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, - // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, - // - // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, - // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, - // - // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} - // }; - - // 0 = r00 - r06 + (r04 - r02) * 5.25 - // 7 = r07 - r01 + (r03 - r05) * 5.25 - - // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) - // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) - - // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) - // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) - - // reuse r04 * 1.25 - // reuse r03 * 2.5 - // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) - // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - __fp16 tmp[8][8][8]; - - // tile - for (int i = 0; i < h_tm / 8; i++) - { - for (int j = 0; j < w_tm / 8; j++) - { - const __fp16* r0 = img0.row(i * 6) + (j * 6) * 8; - - for (int m = 0; m < 8; m++) - { - float16x8_t _r00 = vld1q_f16(r0); - float16x8_t _r01 = vld1q_f16(r0 + 8); - float16x8_t _r02 = vld1q_f16(r0 + 16); - float16x8_t _r03 = vld1q_f16(r0 + 24); - float16x8_t _r04 = vld1q_f16(r0 + 32); - float16x8_t _r05 = vld1q_f16(r0 + 40); - float16x8_t _r06 = vld1q_f16(r0 + 48); - float16x8_t _r07 = vld1q_f16(r0 + 56); - - float16x8_t _tmp0m = vfmaq_n_f16(vsubq_f16(_r00, _r06), vsubq_f16(_r04, _r02), 5.25f); - float16x8_t _tmp7m = vfmaq_n_f16(vsubq_f16(_r07, _r01), vsubq_f16(_r03, _r05), 5.25f); - vst1q_f16(tmp[0][m], _tmp0m); - vst1q_f16(tmp[7][m], _tmp7m); - - // tmp[0][m] = r0[0] - r0[6] + (r0[4] - r0[2]) * 5.25; - // tmp[7][m] = r0[7] - r0[1] + (r0[3] - r0[5]) * 5.25; - - float16x8_t _tmp12a = vfmsq_n_f16(vaddq_f16(_r02, _r06), _r04, 4.25f); - float16x8_t _tmp12b = vfmsq_n_f16(vaddq_f16(_r01, _r05), _r03, 4.25f); - - // float tmp12a = (r0[2] + r0[6] - r0[4] * 4.25); - // float tmp12b = (r0[1] + r0[5] - r0[3] * 4.25); - - float16x8_t _tmp1m = vaddq_f16(_tmp12a, _tmp12b); - float16x8_t _tmp2m = vsubq_f16(_tmp12a, _tmp12b); - vst1q_f16(tmp[1][m], _tmp1m); - vst1q_f16(tmp[2][m], _tmp2m); - - // tmp[1][m] = tmp12a + tmp12b; - // tmp[2][m] = tmp12a - tmp12b; - - float16x8_t _tmp34a = vfmsq_n_f16(vfmaq_n_f16(_r06, _r02, 0.25f), _r04, 1.25f); - float16x8_t _tmp34b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_r01, 0.5f), _r03, 2.5f), _r05, 2.f); - - // float tmp34a = (r0[6] + r0[2] * 0.25 - r0[4] * 1.25); - // float tmp34b = (r0[1] * 0.5 - r0[3] * 2.5 + r0[5] * 2); 
-
-                    float16x8_t _tmp3m = vaddq_f16(_tmp34a, _tmp34b);
-                    float16x8_t _tmp4m = vsubq_f16(_tmp34a, _tmp34b);
-                    vst1q_f16(tmp[3][m], _tmp3m);
-                    vst1q_f16(tmp[4][m], _tmp4m);
-
-                    // tmp[3][m] = tmp34a + tmp34b;
-                    // tmp[4][m] = tmp34a - tmp34b;
-
-                    float16x8_t _tmp56a = vfmaq_n_f16(_r06, vfmsq_n_f16(_r02, _r04, 1.25f), 4.f);
-                    float16x8_t _tmp56b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_r01, 2.f), _r03, 2.5f), _r05, 0.5f);
-
-                    // float tmp56a = (r0[6] + (r0[2] - r0[4] * 1.25) * 4);
-                    // float tmp56b = (r0[1] * 2 - r0[3] * 2.5 + r0[5] * 0.5);
-
-                    float16x8_t _tmp5m = vaddq_f16(_tmp56a, _tmp56b);
-                    float16x8_t _tmp6m = vsubq_f16(_tmp56a, _tmp56b);
-                    vst1q_f16(tmp[5][m], _tmp5m);
-                    vst1q_f16(tmp[6][m], _tmp6m);
-
-                    // tmp[5][m] = tmp56a + tmp56b;
-                    // tmp[6][m] = tmp56a - tmp56b;
+        int w_tiles = outw / 6;
+        int h_tiles = outh / 6;
+        const int tiles = w_tiles * h_tiles;
-                    r0 += w * 8;
-                }
-
-                __fp16* r0_tm_0 = (__fp16*)img0_tm + (i * w_tm / 8 + j) * 8;
-                __fp16* r0_tm_1 = r0_tm_0 + tiles * 8;
-                __fp16* r0_tm_2 = r0_tm_0 + tiles * 16;
-                __fp16* r0_tm_3 = r0_tm_0 + tiles * 24;
-                __fp16* r0_tm_4 = r0_tm_0 + tiles * 32;
-                __fp16* r0_tm_5 = r0_tm_0 + tiles * 40;
-                __fp16* r0_tm_6 = r0_tm_0 + tiles * 48;
-                __fp16* r0_tm_7 = r0_tm_0 + tiles * 56;
-
-                for (int m = 0; m < 8; m++)
-                {
-                    float16x8_t _tmp00 = vld1q_f16(tmp[m][0]);
-                    float16x8_t _tmp01 = vld1q_f16(tmp[m][1]);
-                    float16x8_t _tmp02 = vld1q_f16(tmp[m][2]);
-                    float16x8_t _tmp03 = vld1q_f16(tmp[m][3]);
-                    float16x8_t _tmp04 = vld1q_f16(tmp[m][4]);
-                    float16x8_t _tmp05 = vld1q_f16(tmp[m][5]);
-                    float16x8_t _tmp06 = vld1q_f16(tmp[m][6]);
-                    float16x8_t _tmp07 = vld1q_f16(tmp[m][7]);
-
-                    float16x8_t _r0tm0 = vfmaq_n_f16(vsubq_f16(_tmp00, _tmp06), vsubq_f16(_tmp04, _tmp02), 5.25f);
-                    float16x8_t _r0tm7 = vfmaq_n_f16(vsubq_f16(_tmp07, _tmp01), vsubq_f16(_tmp03, _tmp05), 5.25f);
-
-                    // r0_tm[0] = tmp0[0] - tmp0[6] + (tmp0[4] - tmp0[2]) * 5.25;
-                    // r0_tm[7] = tmp0[7] - tmp0[1] + (tmp0[3] - tmp0[5]) * 5.25;
-
-                    float16x8_t _tmp12a = vfmsq_n_f16(vaddq_f16(_tmp02, _tmp06), _tmp04, 4.25f);
-                    float16x8_t _tmp12b = vfmsq_n_f16(vaddq_f16(_tmp01, _tmp05), _tmp03, 4.25f);
-
-                    // float tmp12a = (tmp0[2] + tmp0[6] - tmp0[4] * 4.25);
-                    // float tmp12b = (tmp0[1] + tmp0[5] - tmp0[3] * 4.25);
-
-                    float16x8_t _r0tm1 = vaddq_f16(_tmp12a, _tmp12b);
-                    float16x8_t _r0tm2 = vsubq_f16(_tmp12a, _tmp12b);
-
-                    // r0_tm[1] = tmp12a + tmp12b;
-                    // r0_tm[2] = tmp12a - tmp12b;
-
-                    float16x8_t _tmp34a = vfmsq_n_f16(vfmaq_n_f16(_tmp06, _tmp02, 0.25f), _tmp04, 1.25f);
-                    float16x8_t _tmp34b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_tmp01, 0.5f), _tmp03, 2.5f), _tmp05, 2.f);
-
-                    // float tmp34a = (tmp0[6] + tmp0[2] * 0.25 - tmp0[4] * 1.25);
-                    // float tmp34b = (tmp0[1] * 0.5 - tmp0[3] * 2.5 + tmp0[5] * 2);
-
-                    float16x8_t _r0tm3 = vaddq_f16(_tmp34a, _tmp34b);
-                    float16x8_t _r0tm4 = vsubq_f16(_tmp34a, _tmp34b);
-
-                    // r0_tm[3] = tmp34a + tmp34b;
-                    // r0_tm[4] = tmp34a - tmp34b;
-
-                    float16x8_t _tmp56a = vfmaq_n_f16(_tmp06, vfmsq_n_f16(_tmp02, _tmp04, 1.25f), 4.f);
-                    float16x8_t _tmp56b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_tmp01, 2.f), _tmp03, 2.5f), _tmp05, 0.5f);
-
-                    // float tmp56a = (tmp0[6] + (tmp0[2] - tmp0[4] * 1.25) * 4);
-                    // float tmp56b = (tmp0[1] * 2 - tmp0[3] * 2.5 + tmp0[5] * 0.5);
-
-                    float16x8_t _r0tm5 = vaddq_f16(_tmp56a, _tmp56b);
-                    float16x8_t _r0tm6 = vsubq_f16(_tmp56a, _tmp56b);
-
-                    // r0_tm[5] = tmp56a + tmp56b;
-                    // r0_tm[6] = tmp56a - tmp56b;
-
-                    vst1q_f16(r0_tm_0, _r0tm0);
-                    vst1q_f16(r0_tm_1, _r0tm1);
-                    vst1q_f16(r0_tm_2, _r0tm2);
-                    vst1q_f16(r0_tm_3, _r0tm3);
- vst1q_f16(r0_tm_4, _r0tm4); - vst1q_f16(r0_tm_5, _r0tm5); - vst1q_f16(r0_tm_6, _r0tm6); - vst1q_f16(r0_tm_7, _r0tm7); - - r0_tm_0 += tiles * 64; - r0_tm_1 += tiles * 64; - r0_tm_2 += tiles * 64; - r0_tm_3 += tiles * 64; - r0_tm_4 += tiles * 64; - r0_tm_5 += tiles * 64; - r0_tm_6 += tiles * 64; - r0_tm_7 += tiles * 64; - } - } - } - } + bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); + conv3x3s1_winograd64_transform_input_pack8_fp16sa_neon(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -1035,173 +840,7 @@ static void conv3x3s1_winograd64_pack8to4_fp16sa_neon(const Mat& bottom_blob, Ma top_blob_bordered.create(outw, outh, outch, 2u * 4, 4, opt.workspace_allocator); } { - // const float otm[6][8] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} - // }; - - // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 - // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 - // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 - // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 - // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 - // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) - - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - const int tiles = w_tm / 8 * h_tm / 8; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - float16x4_t _bias0 = bias ? 
vld1_f16((const __fp16*)bias + p * 4) : vdup_n_f16(0.f); - - __fp16 tmp[6][8][4]; - - // tile - for (int i = 0; i < outh / 6; i++) - { - for (int j = 0; j < outw / 6; j++) - { - // top_blob_tm.create(tiles, 64, outch, elemsize, elempack); - - const __fp16* output0_tm_0 = (const __fp16*)out0_tm + (i * w_tm / 8 + j) * 4; - const __fp16* output0_tm_1 = output0_tm_0 + tiles * 4; - const __fp16* output0_tm_2 = output0_tm_0 + tiles * 8; - const __fp16* output0_tm_3 = output0_tm_0 + tiles * 12; - const __fp16* output0_tm_4 = output0_tm_0 + tiles * 16; - const __fp16* output0_tm_5 = output0_tm_0 + tiles * 20; - const __fp16* output0_tm_6 = output0_tm_0 + tiles * 24; - const __fp16* output0_tm_7 = output0_tm_0 + tiles * 28; - - __fp16* output0 = out0.row<__fp16>(i * 6) + (j * 6) * 4; - - // TODO neon optimize - for (int m = 0; m < 8; m++) - { - float16x4_t _out0tm0 = vld1_f16(output0_tm_0); - float16x4_t _out0tm1 = vld1_f16(output0_tm_1); - float16x4_t _out0tm2 = vld1_f16(output0_tm_2); - float16x4_t _out0tm3 = vld1_f16(output0_tm_3); - float16x4_t _out0tm4 = vld1_f16(output0_tm_4); - float16x4_t _out0tm5 = vld1_f16(output0_tm_5); - float16x4_t _out0tm6 = vld1_f16(output0_tm_6); - float16x4_t _out0tm7 = vld1_f16(output0_tm_7); - - float16x4_t _tmp024a = vadd_f16(_out0tm1, _out0tm2); - float16x4_t _tmp135a = vsub_f16(_out0tm1, _out0tm2); - - // float tmp024a = output0_tm[1] + output0_tm[2]; - // float tmp135a = output0_tm[1] - output0_tm[2]; - - float16x4_t _tmp024b = vadd_f16(_out0tm3, _out0tm4); - float16x4_t _tmp135b = vsub_f16(_out0tm3, _out0tm4); - - // float tmp024b = output0_tm[3] + output0_tm[4]; - // float tmp135b = output0_tm[3] - output0_tm[4]; - - float16x4_t _tmp024c = vadd_f16(_out0tm5, _out0tm6); - float16x4_t _tmp135c = vsub_f16(_out0tm5, _out0tm6); - - // float tmp024c = output0_tm[5] + output0_tm[6]; - // float tmp135c = output0_tm[5] - output0_tm[6]; - - float16x4_t _tmp0m = vadd_f16(vadd_f16(_out0tm0, _tmp024a), vfma_n_f16(_tmp024b, _tmp024c, 32.f)); - float16x4_t _tmp2m = vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f); - float16x4_t _tmp4m = vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f); - vst1_f16(tmp[0][m], _tmp0m); - vst1_f16(tmp[2][m], _tmp2m); - vst1_f16(tmp[4][m], _tmp4m); - - // tmp[0][m] = output0_tm[0] + tmp024a + tmp024b + tmp024c * 32; - // tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8; - // tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - float16x4_t _tmp1m = vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f); - float16x4_t _tmp3m = vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f); - float16x4_t _tmp5m = vadd_f16(vadd_f16(_out0tm7, _tmp135a), vfma_n_f16(_tmp135c, _tmp135b, 32.f)); - vst1_f16(tmp[1][m], _tmp1m); - vst1_f16(tmp[3][m], _tmp3m); - vst1_f16(tmp[5][m], _tmp5m); - - // tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16; - // tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4; - // tmp[5][m] = output0_tm[7] + tmp135a + tmp135b * 32 + tmp135c; - - output0_tm_0 += tiles * 32; - output0_tm_1 += tiles * 32; - output0_tm_2 += tiles * 32; - output0_tm_3 += tiles * 32; - output0_tm_4 += tiles * 32; - output0_tm_5 += tiles * 32; - output0_tm_6 += tiles * 32; - output0_tm_7 += tiles * 32; - } - - for (int m = 0; m < 6; m++) - { - float16x4_t _tmp00 = vld1_f16(tmp[m][0]); - float16x4_t _tmp01 = vld1_f16(tmp[m][1]); - float16x4_t _tmp02 = vld1_f16(tmp[m][2]); - float16x4_t _tmp03 = vld1_f16(tmp[m][3]); - float16x4_t _tmp04 = vld1_f16(tmp[m][4]); - float16x4_t _tmp05 = vld1_f16(tmp[m][5]); - 
float16x4_t _tmp06 = vld1_f16(tmp[m][6]); - float16x4_t _tmp07 = vld1_f16(tmp[m][7]); - - float16x4_t _tmp024a = vadd_f16(_tmp01, _tmp02); - float16x4_t _tmp135a = vsub_f16(_tmp01, _tmp02); - - // float tmp024a = tmp0[1] + tmp0[2]; - // float tmp135a = tmp0[1] - tmp0[2]; - - float16x4_t _tmp024b = vadd_f16(_tmp03, _tmp04); - float16x4_t _tmp135b = vsub_f16(_tmp03, _tmp04); - - // float tmp024b = tmp0[3] + tmp0[4]; - // float tmp135b = tmp0[3] - tmp0[4]; - - float16x4_t _tmp024c = vadd_f16(_tmp05, _tmp06); - float16x4_t _tmp135c = vsub_f16(_tmp05, _tmp06); - - // float tmp024c = tmp0[5] + tmp0[6]; - // float tmp135c = tmp0[5] - tmp0[6]; - - float16x4_t _out00 = vadd_f16(_bias0, vadd_f16(vadd_f16(_tmp00, _tmp024a), vfma_n_f16(_tmp024b, _tmp024c, 32.f))); - float16x4_t _out02 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f)); - float16x4_t _out04 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f)); - vst1_f16(output0, _out00); - vst1_f16(output0 + 8, _out02); - vst1_f16(output0 + 16, _out04); - - // output0[0] = bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32; - // output0[2] = bias0 + tmp024a + tmp024b * 4 + tmp024c * 8; - // output0[4] = bias0 + tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - float16x4_t _out01 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f)); - float16x4_t _out03 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f)); - float16x4_t _out05 = vadd_f16(_bias0, vadd_f16(vadd_f16(_tmp07, _tmp135a), vfma_n_f16(_tmp135c, _tmp135b, 32.f))); - vst1_f16(output0 + 4, _out01); - vst1_f16(output0 + 12, _out03); - vst1_f16(output0 + 20, _out05); - - // output0[1] = bias0 + tmp135a + tmp135b + tmp135b + tmp135c * 16; - // output0[3] = bias0 + tmp135a + tmp135b * 8 + tmp135c * 4; - // output0[5] = bias0 + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c; - - output0 += outw * 4; - } - } - } - } + conv3x3s1_winograd64_transform_output_pack4_fp16sa_neon(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output diff --git a/src/layer/arm/convolution_arm.cpp b/src/layer/arm/convolution_arm.cpp index 972f453236b..876d9e55da8 100644 --- a/src/layer/arm/convolution_arm.cpp +++ b/src/layer/arm/convolution_arm.cpp @@ -28,6 +28,7 @@ namespace ncnn { #include "convolution_sgemm.h" +#include "convolution_winograd_transform.h" #include "convolution_1x1.h" #include "convolution_2x2.h" @@ -39,6 +40,7 @@ namespace ncnn { #if NCNN_BF16 #include "convolution_bf16s.h" #include "convolution_sgemm_bf16s.h" +#include "convolution_winograd_transform_bf16s.h" #include "convolution_1x1_bf16s.h" #endif // NCNN_BF16 @@ -56,6 +58,7 @@ namespace ncnn { #include "convolution_sgemm_pack4.h" #include "convolution_sgemm_pack1to4.h" #include "convolution_sgemm_pack4to1.h" +#include "convolution_winograd_transform_pack4.h" #include "convolution_1x1_pack4.h" #include "convolution_1x1_pack1to4.h" #include "convolution_1x1_pack4to1.h" @@ -72,6 +75,7 @@ namespace ncnn { #include "convolution_sgemm_pack4_bf16s.h" #include "convolution_sgemm_pack1to4_bf16s.h" #include "convolution_sgemm_pack4to1_bf16s.h" +#include "convolution_winograd_transform_pack4_bf16s.h" #include "convolution_1x1_pack4_bf16s.h" #include "convolution_1x1_pack1to4_bf16s.h" #include "convolution_1x1_pack4to1_bf16s.h" @@ -115,6 +119,9 @@ namespace ncnn { #include "convolution_sgemm_pack8_fp16s.h" #include "convolution_sgemm_pack8to4_fp16s.h" #include "convolution_sgemm_pack8to1_fp16s.h" +#include 
"convolution_winograd_transform_fp16s.h" +#include "convolution_winograd_transform_pack4_fp16s.h" +#include "convolution_winograd_transform_pack8_fp16s.h" #include "convolution_1x1_fp16s.h" #include "convolution_1x1_pack4_fp16s.h" #include "convolution_1x1_pack1to4_fp16s.h" diff --git a/src/layer/arm/convolution_winograd_transform.h b/src/layer/arm/convolution_winograd_transform.h new file mode 100644 index 00000000000..5b0d0c18c20 --- /dev/null +++ b/src/layer/arm/convolution_winograd_transform.h @@ -0,0 +1,125 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void conv3x3s1_winograd64_transform_output_neon(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 6; + const int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; + + const float* biasptr = bias; + + // const float otm[6][8] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} + // }; + + // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 + // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 + // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 + // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 + // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 + // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + const float bias0 = biasptr ? 
biasptr[p] : 0.f; + + float tmp[6][8]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* output0_tm_0 = (const float*)out0_tm + (i * w_tiles + j); + const float* output0_tm_1 = output0_tm_0 + tiles * 1; + const float* output0_tm_2 = output0_tm_0 + tiles * 2; + const float* output0_tm_3 = output0_tm_0 + tiles * 3; + const float* output0_tm_4 = output0_tm_0 + tiles * 4; + const float* output0_tm_5 = output0_tm_0 + tiles * 5; + const float* output0_tm_6 = output0_tm_0 + tiles * 6; + const float* output0_tm_7 = output0_tm_0 + tiles * 7; + + // TODO neon optimize + for (int m = 0; m < 8; m++) + { + float tmp024a = output0_tm_1[0] + output0_tm_2[0]; + float tmp135a = output0_tm_1[0] - output0_tm_2[0]; + + float tmp024b = output0_tm_3[0] + output0_tm_4[0]; + float tmp135b = output0_tm_3[0] - output0_tm_4[0]; + + float tmp024c = output0_tm_5[0] + output0_tm_6[0]; + float tmp135c = output0_tm_5[0] - output0_tm_6[0]; + + tmp[0][m] = output0_tm_0[0] + tmp024a + tmp024b + tmp024c * 32; + tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8; + tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c; + + tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16; + tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4; + tmp[5][m] = output0_tm_7[0] + tmp135a + tmp135b * 32 + tmp135c; + + output0_tm_0 += tiles * 8; + output0_tm_1 += tiles * 8; + output0_tm_2 += tiles * 8; + output0_tm_3 += tiles * 8; + output0_tm_4 += tiles * 8; + output0_tm_5 += tiles * 8; + output0_tm_6 += tiles * 8; + output0_tm_7 += tiles * 8; + } + + float* output0 = out0.row(i * 6) + j * 6; + + for (int m = 0; m < 6; m++) + { + const float* tmp0 = tmp[m]; + + float tmp024a = tmp0[1] + tmp0[2]; + float tmp135a = tmp0[1] - tmp0[2]; + + float tmp024b = tmp0[3] + tmp0[4]; + float tmp135b = tmp0[3] - tmp0[4]; + + float tmp024c = tmp0[5] + tmp0[6]; + float tmp135c = tmp0[5] - tmp0[6]; + + output0[0] = bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32; + output0[2] = bias0 + tmp024a + tmp024b * 4 + tmp024c * 8; + output0[4] = bias0 + tmp024a + tmp024b * 16 + tmp024c + tmp024c; + + output0[1] = bias0 + tmp135a + tmp135b + tmp135b + tmp135c * 16; + output0[3] = bias0 + tmp135a + tmp135b * 8 + tmp135c * 4; + output0[5] = bias0 + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c; + + output0 += outw; + } + } + } + } +} diff --git a/src/layer/arm/convolution_winograd_transform_bf16s.h b/src/layer/arm/convolution_winograd_transform_bf16s.h new file mode 100644 index 00000000000..5595f2c52e0 --- /dev/null +++ b/src/layer/arm/convolution_winograd_transform_bf16s.h @@ -0,0 +1,125 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
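+
+// A minimal scalar sketch of the F(6x6, 3x3) output transform implemented
+// below, kept as a commented reference only -- the routine that follows is the
+// code that actually runs. It assumes one 8x8 transformed tile v[8][8] for a
+// single channel and writes a 6x6 output block, truncating to bf16 on store:
+//
+// static void winograd64_output_tile_ref(const float v[8][8], unsigned short out[6][6], float bias)
+// {
+//     float tmp[6][8];
+//     for (int m = 0; m < 8; m++) // fold the 8 rows down to 6, column by column
+//     {
+//         float a = v[1][m] + v[2][m], sa = v[1][m] - v[2][m];
+//         float b = v[3][m] + v[4][m], sb = v[3][m] - v[4][m];
+//         float c = v[5][m] + v[6][m], sc = v[5][m] - v[6][m];
+//         tmp[0][m] = v[0][m] + a + b + c * 32;
+//         tmp[1][m] = sa + sb * 2 + sc * 16;
+//         tmp[2][m] = a + b * 4 + c * 8;
+//         tmp[3][m] = sa + sb * 8 + sc * 4;
+//         tmp[4][m] = a + b * 16 + c * 2;
+//         tmp[5][m] = v[7][m] + sa + sb * 32 + sc;
+//     }
+//     for (int m = 0; m < 6; m++) // then fold the 8 columns down to 6
+//     {
+//         const float* t = tmp[m];
+//         float a = t[1] + t[2], sa = t[1] - t[2];
+//         float b = t[3] + t[4], sb = t[3] - t[4];
+//         float c = t[5] + t[6], sc = t[5] - t[6];
+//         out[m][0] = float32_to_bfloat16(bias + t[0] + a + b + c * 32);
+//         out[m][1] = float32_to_bfloat16(bias + sa + sb * 2 + sc * 16);
+//         out[m][2] = float32_to_bfloat16(bias + a + b * 4 + c * 8);
+//         out[m][3] = float32_to_bfloat16(bias + sa + sb * 8 + sc * 4);
+//         out[m][4] = float32_to_bfloat16(bias + a + b * 16 + c * 2);
+//         out[m][5] = float32_to_bfloat16(bias + t[7] + sa + sb * 32 + sc);
+//     }
+// }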
+
+static void conv3x3s1_winograd64_transform_output_bf16s_neon(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt)
+{
+    const int outw = top_blob.w;
+    const int outh = top_blob.h;
+    const int outch = top_blob.c;
+
+    const int w_tiles = outw / 6;
+    const int h_tiles = outh / 6;
+    const int tiles = w_tiles * h_tiles;
+
+    const float* biasptr = bias;
+
+    // const float otm[6][8] = {
+    //     {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f},
+    //     {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f},
+    //     {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f},
+    //     {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f},
+    //     {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f},
+    //     {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f}
+    // };
+
+    // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32
+    // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16
+    // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8
+    // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4
+    // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2
+    // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6)
+
+    #pragma omp parallel for num_threads(opt.num_threads)
+    for (int p = 0; p < outch; p++)
+    {
+        const Mat out0_tm = top_blob_tm.channel(p);
+        Mat out0 = top_blob.channel(p);
+
+        const float bias0 = biasptr ? biasptr[p] : 0.f;
+
+        float tmp[6][8];
+
+        // tile
+        for (int i = 0; i < h_tiles; i++)
+        {
+            for (int j = 0; j < w_tiles; j++)
+            {
+                const float* output0_tm_0 = (const float*)out0_tm + (i * w_tiles + j);
+                const float* output0_tm_1 = output0_tm_0 + tiles * 1;
+                const float* output0_tm_2 = output0_tm_0 + tiles * 2;
+                const float* output0_tm_3 = output0_tm_0 + tiles * 3;
+                const float* output0_tm_4 = output0_tm_0 + tiles * 4;
+                const float* output0_tm_5 = output0_tm_0 + tiles * 5;
+                const float* output0_tm_6 = output0_tm_0 + tiles * 6;
+                const float* output0_tm_7 = output0_tm_0 + tiles * 7;
+
+                // TODO neon optimize
+                for (int m = 0; m < 8; m++)
+                {
+                    float tmp024a = output0_tm_1[0] + output0_tm_2[0];
+                    float tmp135a = output0_tm_1[0] - output0_tm_2[0];
+
+                    float tmp024b = output0_tm_3[0] + output0_tm_4[0];
+                    float tmp135b = output0_tm_3[0] - output0_tm_4[0];
+
+                    float tmp024c = output0_tm_5[0] + output0_tm_6[0];
+                    float tmp135c = output0_tm_5[0] - output0_tm_6[0];
+
+                    tmp[0][m] = output0_tm_0[0] + tmp024a + tmp024b + tmp024c * 32;
+                    tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8;
+                    tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c;
+
+                    tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16;
+                    tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4;
+                    tmp[5][m] = output0_tm_7[0] + tmp135a + tmp135b * 32 + tmp135c;
+
+                    output0_tm_0 += tiles * 8;
+                    output0_tm_1 += tiles * 8;
+                    output0_tm_2 += tiles * 8;
+                    output0_tm_3 += tiles * 8;
+                    output0_tm_4 += tiles * 8;
+                    output0_tm_5 += tiles * 8;
+                    output0_tm_6 += tiles * 8;
+                    output0_tm_7 += tiles * 8;
+                }
+
+                unsigned short* output0 = out0.row<unsigned short>(i * 6) + j * 6;
+
+                for (int m = 0; m < 6; m++)
+                {
+                    const float* tmp0 = tmp[m];
+
+                    float tmp024a = tmp0[1] + tmp0[2];
+                    float tmp135a = tmp0[1] - tmp0[2];
+
+                    float tmp024b = tmp0[3] + tmp0[4];
+                    float tmp135b = tmp0[3] - tmp0[4];
+
+                    float tmp024c = tmp0[5] + tmp0[6];
+                    float tmp135c = tmp0[5] - tmp0[6];
+
+                    output0[0] = float32_to_bfloat16(bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32);
+                    output0[2] = float32_to_bfloat16(bias0 + tmp024a + tmp024b * 4 + tmp024c * 8);
+                    output0[4] = float32_to_bfloat16(bias0 + tmp024a + tmp024b * 16 + tmp024c + tmp024c);
+
+                    output0[1] = float32_to_bfloat16(bias0 + tmp135a + tmp135b + tmp135b + tmp135c * 16);
+                    
output0[3] = float32_to_bfloat16(bias0 + tmp135a + tmp135b * 8 + tmp135c * 4); + output0[5] = float32_to_bfloat16(bias0 + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c); + + output0 += outw; + } + } + } + } +} diff --git a/src/layer/arm/convolution_winograd_transform_fp16s.h b/src/layer/arm/convolution_winograd_transform_fp16s.h new file mode 100644 index 00000000000..0a05bc26704 --- /dev/null +++ b/src/layer/arm/convolution_winograd_transform_fp16s.h @@ -0,0 +1,125 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void conv3x3s1_winograd64_transform_output_fp16sa_neon(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 6; + const int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; + + const __fp16* biasptr = bias; + + // const float otm[6][8] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} + // }; + + // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 + // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 + // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 + // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 + // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 + // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + const __fp16 bias0 = biasptr ? 
biasptr[p] : 0.f; + + __fp16 tmp[6][8]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const __fp16* output0_tm_0 = (const __fp16*)out0_tm + (i * w_tiles + j) * 1; + const __fp16* output0_tm_1 = output0_tm_0 + tiles * 1; + const __fp16* output0_tm_2 = output0_tm_0 + tiles * 2; + const __fp16* output0_tm_3 = output0_tm_0 + tiles * 3; + const __fp16* output0_tm_4 = output0_tm_0 + tiles * 4; + const __fp16* output0_tm_5 = output0_tm_0 + tiles * 5; + const __fp16* output0_tm_6 = output0_tm_0 + tiles * 6; + const __fp16* output0_tm_7 = output0_tm_0 + tiles * 7; + + // TODO neon optimize + for (int m = 0; m < 8; m++) + { + __fp16 tmp024a = output0_tm_1[0] + output0_tm_2[0]; + __fp16 tmp135a = output0_tm_1[0] - output0_tm_2[0]; + + __fp16 tmp024b = output0_tm_3[0] + output0_tm_4[0]; + __fp16 tmp135b = output0_tm_3[0] - output0_tm_4[0]; + + __fp16 tmp024c = output0_tm_5[0] + output0_tm_6[0]; + __fp16 tmp135c = output0_tm_5[0] - output0_tm_6[0]; + + tmp[0][m] = output0_tm_0[0] + tmp024a + tmp024b + tmp024c * 32; + tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8; + tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c; + + tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16; + tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4; + tmp[5][m] = output0_tm_7[0] + tmp135a + tmp135b * 32 + tmp135c; + + output0_tm_0 += tiles * 8; + output0_tm_1 += tiles * 8; + output0_tm_2 += tiles * 8; + output0_tm_3 += tiles * 8; + output0_tm_4 += tiles * 8; + output0_tm_5 += tiles * 8; + output0_tm_6 += tiles * 8; + output0_tm_7 += tiles * 8; + } + + __fp16* output0 = out0.row<__fp16>(i * 6) + j * 6; + + for (int m = 0; m < 6; m++) + { + const __fp16* tmp0 = tmp[m]; + + __fp16 tmp024a = tmp0[1] + tmp0[2]; + __fp16 tmp135a = tmp0[1] - tmp0[2]; + + __fp16 tmp024b = tmp0[3] + tmp0[4]; + __fp16 tmp135b = tmp0[3] - tmp0[4]; + + __fp16 tmp024c = tmp0[5] + tmp0[6]; + __fp16 tmp135c = tmp0[5] - tmp0[6]; + + output0[0] = bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32; + output0[2] = bias0 + tmp024a + tmp024b * 4 + tmp024c * 8; + output0[4] = bias0 + tmp024a + tmp024b * 16 + tmp024c + tmp024c; + + output0[1] = bias0 + tmp135a + tmp135b + tmp135b + tmp135c * 16; + output0[3] = bias0 + tmp135a + tmp135b * 8 + tmp135c * 4; + output0[5] = bias0 + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c; + + output0 += outw; + } + } + } + } +} diff --git a/src/layer/arm/convolution_winograd_transform_pack4.h b/src/layer/arm/convolution_winograd_transform_pack4.h new file mode 100644 index 00000000000..a3bd7640eec --- /dev/null +++ b/src/layer/arm/convolution_winograd_transform_pack4.h @@ -0,0 +1,535 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
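+
+// Layout note (a sketch for orientation, not a normative spec): the
+// transformed blob handled below is created as (w = tiles, h = 64, c = inch)
+// with elempack 4, so transformed row r (0..63), tile t and SIMD lane l (0..3)
+// of channel q live at
+//
+//     const float* ptr = (const float*)bottom_blob_tm.channel(q);
+//     float v = ptr[(r * tiles + t) * 4 + l];
+//
+// This is why the writers below keep eight cursors r0_tm_0..r0_tm_7 spaced one
+// transform row apart (tiles * 4 floats) and advance each of them by eight
+// rows (tiles * 32 floats) per iteration of m.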
+ +static void conv3x3s1_winograd64_transform_input_pack4_neon(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) +{ + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + + const int w_tiles = (w - 2) / 6; + const int h_tiles = (h - 2) / 6; + const int tiles = w_tiles * h_tiles; + + // const float itm[8][8] = { + // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, + // + // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, + // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, + // + // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, + // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, + // + // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, + // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, + // + // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} + // }; + + // 0 = r00 - r06 + (r04 - r02) * 5.25 + // 7 = r07 - r01 + (r03 - r05) * 5.25 + + // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) + // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) + + // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) + // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) + + // reuse r04 * 1.25 + // reuse r03 * 2.5 + // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) + // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < inch; q++) + { + const Mat img0 = bottom_blob.channel(q); + Mat img0_tm = bottom_blob_tm.channel(q); + + float tmp[8][8][4]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* r0 = img0.row(i * 6) + (j * 6) * 4; + + for (int m = 0; m < 8; m++) + { + float32x4_t _r00 = vld1q_f32(r0); + float32x4_t _r01 = vld1q_f32(r0 + 4); + float32x4_t _r02 = vld1q_f32(r0 + 8); + float32x4_t _r03 = vld1q_f32(r0 + 12); + float32x4_t _r04 = vld1q_f32(r0 + 16); + float32x4_t _r05 = vld1q_f32(r0 + 20); + float32x4_t _r06 = vld1q_f32(r0 + 24); + float32x4_t _r07 = vld1q_f32(r0 + 28); + + float32x4_t _tmp0m = vmlaq_n_f32(vsubq_f32(_r00, _r06), vsubq_f32(_r04, _r02), 5.25f); + float32x4_t _tmp7m = vmlaq_n_f32(vsubq_f32(_r07, _r01), vsubq_f32(_r03, _r05), 5.25f); + vst1q_f32(tmp[0][m], _tmp0m); + vst1q_f32(tmp[7][m], _tmp7m); + + float32x4_t _tmp12a = vmlsq_n_f32(vaddq_f32(_r02, _r06), _r04, 4.25f); + float32x4_t _tmp12b = vmlsq_n_f32(vaddq_f32(_r01, _r05), _r03, 4.25f); + + float32x4_t _tmp1m = vaddq_f32(_tmp12a, _tmp12b); + float32x4_t _tmp2m = vsubq_f32(_tmp12a, _tmp12b); + vst1q_f32(tmp[1][m], _tmp1m); + vst1q_f32(tmp[2][m], _tmp2m); + + float32x4_t _tmp34a = vmlsq_n_f32(vmlaq_n_f32(_r06, _r02, 0.25f), _r04, 1.25f); + float32x4_t _tmp34b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_r01, 0.5f), _r03, 2.5f), _r05, 2.f); + + float32x4_t _tmp3m = vaddq_f32(_tmp34a, _tmp34b); + float32x4_t _tmp4m = vsubq_f32(_tmp34a, _tmp34b); + vst1q_f32(tmp[3][m], _tmp3m); + vst1q_f32(tmp[4][m], _tmp4m); + + float32x4_t _tmp56a = vmlaq_n_f32(_r06, vmlsq_n_f32(_r02, _r04, 1.25f), 4.f); + float32x4_t _tmp56b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_r01, 2.f), _r03, 2.5f), _r05, 0.5f); + + float32x4_t _tmp5m = vaddq_f32(_tmp56a, _tmp56b); + float32x4_t _tmp6m = vsubq_f32(_tmp56a, _tmp56b); + vst1q_f32(tmp[5][m], _tmp5m); + vst1q_f32(tmp[6][m], _tmp6m); + + r0 += w * 4; + } + + float* r0_tm_0 = (float*)img0_tm + (i * w_tiles + j) * 4; + float* r0_tm_1 = r0_tm_0 + tiles * 4; + float* 
r0_tm_2 = r0_tm_0 + tiles * 8; + float* r0_tm_3 = r0_tm_0 + tiles * 12; + float* r0_tm_4 = r0_tm_0 + tiles * 16; + float* r0_tm_5 = r0_tm_0 + tiles * 20; + float* r0_tm_6 = r0_tm_0 + tiles * 24; + float* r0_tm_7 = r0_tm_0 + tiles * 28; + + for (int m = 0; m < 8; m++) + { + float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); + float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); + float32x4_t _tmp02 = vld1q_f32(tmp[m][2]); + float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); + float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); + float32x4_t _tmp05 = vld1q_f32(tmp[m][5]); + float32x4_t _tmp06 = vld1q_f32(tmp[m][6]); + float32x4_t _tmp07 = vld1q_f32(tmp[m][7]); + + float32x4_t _r0tm0 = vmlaq_n_f32(vsubq_f32(_tmp00, _tmp06), vsubq_f32(_tmp04, _tmp02), 5.25f); + float32x4_t _r0tm7 = vmlaq_n_f32(vsubq_f32(_tmp07, _tmp01), vsubq_f32(_tmp03, _tmp05), 5.25f); + + float32x4_t _tmp12a = vmlsq_n_f32(vaddq_f32(_tmp02, _tmp06), _tmp04, 4.25f); + float32x4_t _tmp12b = vmlsq_n_f32(vaddq_f32(_tmp01, _tmp05), _tmp03, 4.25f); + + float32x4_t _r0tm1 = vaddq_f32(_tmp12a, _tmp12b); + float32x4_t _r0tm2 = vsubq_f32(_tmp12a, _tmp12b); + + float32x4_t _tmp34a = vmlsq_n_f32(vmlaq_n_f32(_tmp06, _tmp02, 0.25f), _tmp04, 1.25f); + float32x4_t _tmp34b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_tmp01, 0.5f), _tmp03, 2.5f), _tmp05, 2.f); + + float32x4_t _r0tm3 = vaddq_f32(_tmp34a, _tmp34b); + float32x4_t _r0tm4 = vsubq_f32(_tmp34a, _tmp34b); + + float32x4_t _tmp56a = vmlaq_n_f32(_tmp06, vmlsq_n_f32(_tmp02, _tmp04, 1.25f), 4.f); + float32x4_t _tmp56b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_tmp01, 2.f), _tmp03, 2.5f), _tmp05, 0.5f); + + float32x4_t _r0tm5 = vaddq_f32(_tmp56a, _tmp56b); + float32x4_t _r0tm6 = vsubq_f32(_tmp56a, _tmp56b); + + vst1q_f32(r0_tm_0, _r0tm0); + vst1q_f32(r0_tm_1, _r0tm1); + vst1q_f32(r0_tm_2, _r0tm2); + vst1q_f32(r0_tm_3, _r0tm3); + vst1q_f32(r0_tm_4, _r0tm4); + vst1q_f32(r0_tm_5, _r0tm5); + vst1q_f32(r0_tm_6, _r0tm6); + vst1q_f32(r0_tm_7, _r0tm7); + + r0_tm_0 += tiles * 32; + r0_tm_1 += tiles * 32; + r0_tm_2 += tiles * 32; + r0_tm_3 += tiles * 32; + r0_tm_4 += tiles * 32; + r0_tm_5 += tiles * 32; + r0_tm_6 += tiles * 32; + r0_tm_7 += tiles * 32; + } + } + } + } +} + +static void conv3x3s1_winograd64_transform_output_pack4_neon(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 6; + const int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; + + const float* biasptr = bias; + + // const float otm[6][8] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} + // }; + + // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 + // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 + // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 + // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 + // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 + // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + float32x4_t _bias0 = biasptr ? 
vld1q_f32(biasptr + p * 4) : vdupq_n_f32(0.f); + + float tmp[6][8][4]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* output0_tm_0 = (const float*)out0_tm + (i * w_tiles + j) * 4; + const float* output0_tm_1 = output0_tm_0 + tiles * 4; + const float* output0_tm_2 = output0_tm_0 + tiles * 8; + const float* output0_tm_3 = output0_tm_0 + tiles * 12; + const float* output0_tm_4 = output0_tm_0 + tiles * 16; + const float* output0_tm_5 = output0_tm_0 + tiles * 20; + const float* output0_tm_6 = output0_tm_0 + tiles * 24; + const float* output0_tm_7 = output0_tm_0 + tiles * 28; + + float* output0 = out0.row(i * 6) + (j * 6) * 4; + + for (int m = 0; m < 8; m++) + { + float32x4_t _out0tm0 = vld1q_f32(output0_tm_0); + float32x4_t _out0tm1 = vld1q_f32(output0_tm_1); + float32x4_t _out0tm2 = vld1q_f32(output0_tm_2); + float32x4_t _out0tm3 = vld1q_f32(output0_tm_3); + float32x4_t _out0tm4 = vld1q_f32(output0_tm_4); + float32x4_t _out0tm5 = vld1q_f32(output0_tm_5); + float32x4_t _out0tm6 = vld1q_f32(output0_tm_6); + float32x4_t _out0tm7 = vld1q_f32(output0_tm_7); + + float32x4_t _tmp024a = vaddq_f32(_out0tm1, _out0tm2); + float32x4_t _tmp135a = vsubq_f32(_out0tm1, _out0tm2); + + float32x4_t _tmp024b = vaddq_f32(_out0tm3, _out0tm4); + float32x4_t _tmp135b = vsubq_f32(_out0tm3, _out0tm4); + + float32x4_t _tmp024c = vaddq_f32(_out0tm5, _out0tm6); + float32x4_t _tmp135c = vsubq_f32(_out0tm5, _out0tm6); + + float32x4_t _tmp0m = vaddq_f32(vaddq_f32(_out0tm0, _tmp024a), vmlaq_n_f32(_tmp024b, _tmp024c, 32.f)); + float32x4_t _tmp2m = vmlaq_n_f32(vmlaq_n_f32(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f); + float32x4_t _tmp4m = vmlaq_n_f32(vmlaq_n_f32(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f); + vst1q_f32(tmp[0][m], _tmp0m); + vst1q_f32(tmp[2][m], _tmp2m); + vst1q_f32(tmp[4][m], _tmp4m); + + float32x4_t _tmp1m = vmlaq_n_f32(vmlaq_n_f32(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f); + float32x4_t _tmp3m = vmlaq_n_f32(vmlaq_n_f32(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f); + float32x4_t _tmp5m = vaddq_f32(vaddq_f32(_out0tm7, _tmp135a), vmlaq_n_f32(_tmp135c, _tmp135b, 32.f)); + vst1q_f32(tmp[1][m], _tmp1m); + vst1q_f32(tmp[3][m], _tmp3m); + vst1q_f32(tmp[5][m], _tmp5m); + + output0_tm_0 += tiles * 32; + output0_tm_1 += tiles * 32; + output0_tm_2 += tiles * 32; + output0_tm_3 += tiles * 32; + output0_tm_4 += tiles * 32; + output0_tm_5 += tiles * 32; + output0_tm_6 += tiles * 32; + output0_tm_7 += tiles * 32; + } + + for (int m = 0; m < 6; m++) + { + float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); + float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); + float32x4_t _tmp02 = vld1q_f32(tmp[m][2]); + float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); + float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); + float32x4_t _tmp05 = vld1q_f32(tmp[m][5]); + float32x4_t _tmp06 = vld1q_f32(tmp[m][6]); + float32x4_t _tmp07 = vld1q_f32(tmp[m][7]); + + float32x4_t _tmp024a = vaddq_f32(_tmp01, _tmp02); + float32x4_t _tmp135a = vsubq_f32(_tmp01, _tmp02); + + float32x4_t _tmp024b = vaddq_f32(_tmp03, _tmp04); + float32x4_t _tmp135b = vsubq_f32(_tmp03, _tmp04); + + float32x4_t _tmp024c = vaddq_f32(_tmp05, _tmp06); + float32x4_t _tmp135c = vsubq_f32(_tmp05, _tmp06); + + float32x4_t _out00 = vaddq_f32(_bias0, vaddq_f32(vaddq_f32(_tmp00, _tmp024a), vmlaq_n_f32(_tmp024b, _tmp024c, 32.f))); + float32x4_t _out02 = vaddq_f32(_bias0, vmlaq_n_f32(vmlaq_n_f32(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f)); + float32x4_t _out04 = vaddq_f32(_bias0, vmlaq_n_f32(vmlaq_n_f32(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f)); + vst1q_f32(output0, 
_out00); + vst1q_f32(output0 + 8, _out02); + vst1q_f32(output0 + 16, _out04); + + float32x4_t _out01 = vaddq_f32(_bias0, vmlaq_n_f32(vmlaq_n_f32(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f)); + float32x4_t _out03 = vaddq_f32(_bias0, vmlaq_n_f32(vmlaq_n_f32(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f)); + float32x4_t _out05 = vaddq_f32(_bias0, vaddq_f32(vaddq_f32(_tmp07, _tmp135a), vmlaq_n_f32(_tmp135c, _tmp135b, 32.f))); + vst1q_f32(output0 + 4, _out01); + vst1q_f32(output0 + 12, _out03); + vst1q_f32(output0 + 20, _out05); + + output0 += outw * 4; + } + } + } + } +} + +static void conv3x3s1_winograd42_transform_input_pack4_neon(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) +{ + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + + const int w_tiles = (w - 2) / 4; + const int h_tiles = (h - 2) / 4; + const int tiles = w_tiles * h_tiles; + + // const float itm[6][6] = { + // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, + // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, + // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, + // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, + // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, + // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} + // }; + + // 0 = 4 * r00 - 5 * r02 + r04 + // 1 = -4 * (r01 + r02) + r04 + r03 + // 2 = 4 * (r01 - r02) + r04 - r03 + // 3 = -2 * (r01 - r03) + r04 - r02 + // 4 = 2 * (r01 - r03) + r04 - r02 + // 5 = 4 * r01 - 5 * r03 + r05 + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < inch; q++) + { + const Mat img0 = bottom_blob.channel(q); + Mat img0_tm = bottom_blob_tm.channel(q); + + float tmp[6][6][4]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* r0 = img0.row(i * 4) + (j * 4) * 4; + + for (int m = 0; m < 6; m++) + { + float32x4_t _r00 = vld1q_f32(r0); + float32x4_t _r01 = vld1q_f32(r0 + 4); + float32x4_t _r02 = vld1q_f32(r0 + 8); + float32x4_t _r03 = vld1q_f32(r0 + 12); + float32x4_t _r04 = vld1q_f32(r0 + 16); + float32x4_t _r05 = vld1q_f32(r0 + 20); + + float32x4_t _tmp0m = vmlsq_n_f32(vmlaq_n_f32(_r04, _r00, 4.f), _r02, 5.f); + float32x4_t _tmp1m = vmlsq_n_f32(vaddq_f32(_r04, _r03), vaddq_f32(_r01, _r02), 4.f); + float32x4_t _tmp2m = vmlaq_n_f32(vsubq_f32(_r04, _r03), vsubq_f32(_r01, _r02), 4.f); + float32x4_t _tmp3m = vmlsq_n_f32(vsubq_f32(_r04, _r02), vsubq_f32(_r01, _r03), 2.f); + float32x4_t _tmp4m = vmlaq_n_f32(vsubq_f32(_r04, _r02), vsubq_f32(_r01, _r03), 2.f); + float32x4_t _tmp5m = vmlsq_n_f32(vmlaq_n_f32(_r05, _r01, 4.f), _r03, 5.f); + + vst1q_f32(tmp[0][m], _tmp0m); + vst1q_f32(tmp[1][m], _tmp1m); + vst1q_f32(tmp[2][m], _tmp2m); + vst1q_f32(tmp[3][m], _tmp3m); + vst1q_f32(tmp[4][m], _tmp4m); + vst1q_f32(tmp[5][m], _tmp5m); + + r0 += w * 4; + } + + float* r0_tm_0 = (float*)img0_tm + (i * w_tiles + j) * 4; + float* r0_tm_1 = r0_tm_0 + tiles * 4; + float* r0_tm_2 = r0_tm_0 + tiles * 8; + float* r0_tm_3 = r0_tm_0 + tiles * 12; + float* r0_tm_4 = r0_tm_0 + tiles * 16; + float* r0_tm_5 = r0_tm_0 + tiles * 20; + + for (int m = 0; m < 6; m++) + { + float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); + float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); + float32x4_t _tmp02 = vld1q_f32(tmp[m][2]); + float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); + float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); + float32x4_t _tmp05 = vld1q_f32(tmp[m][5]); + + float32x4_t _r0tm0 = vmlsq_n_f32(vmlaq_n_f32(_tmp04, _tmp00, 4.f), _tmp02, 5.f); + float32x4_t _r0tm1 = vmlsq_n_f32(vaddq_f32(_tmp04, _tmp03), vaddq_f32(_tmp01, _tmp02), 4.f); + float32x4_t _r0tm2 = 
vmlaq_n_f32(vsubq_f32(_tmp04, _tmp03), vsubq_f32(_tmp01, _tmp02), 4.f); + float32x4_t _r0tm3 = vmlsq_n_f32(vsubq_f32(_tmp04, _tmp02), vsubq_f32(_tmp01, _tmp03), 2.f); + float32x4_t _r0tm4 = vmlaq_n_f32(vsubq_f32(_tmp04, _tmp02), vsubq_f32(_tmp01, _tmp03), 2.f); + float32x4_t _r0tm5 = vmlsq_n_f32(vmlaq_n_f32(_tmp05, _tmp01, 4.f), _tmp03, 5.f); + + vst1q_f32(r0_tm_0, _r0tm0); + vst1q_f32(r0_tm_1, _r0tm1); + vst1q_f32(r0_tm_2, _r0tm2); + vst1q_f32(r0_tm_3, _r0tm3); + vst1q_f32(r0_tm_4, _r0tm4); + vst1q_f32(r0_tm_5, _r0tm5); + + r0_tm_0 += tiles * 24; + r0_tm_1 += tiles * 24; + r0_tm_2 += tiles * 24; + r0_tm_3 += tiles * 24; + r0_tm_4 += tiles * 24; + r0_tm_5 += tiles * 24; + } + } + } + } +} + +static void conv3x3s1_winograd42_transform_output_pack4_neon(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 4; + const int h_tiles = outh / 4; + const int tiles = w_tiles * h_tiles; + + const float* biasptr = bias; + + // const float otm[4][6] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} + // }; + + // 0 = r00 + (r01 + r02) + (r03 + r04) + // 1 = (r01 - r02) + (r03 - r04) * 2 + // 2 = (r01 + r02) + (r03 + r04) * 4 + // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + float32x4_t _bias0 = biasptr ? vld1q_f32(biasptr + p * 4) : vdupq_n_f32(0.f); + + float tmp[4][6][4]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* output0_tm_0 = (const float*)out0_tm + (i * w_tiles + j) * 4; + const float* output0_tm_1 = output0_tm_0 + tiles * 4; + const float* output0_tm_2 = output0_tm_0 + tiles * 8; + const float* output0_tm_3 = output0_tm_0 + tiles * 12; + const float* output0_tm_4 = output0_tm_0 + tiles * 16; + const float* output0_tm_5 = output0_tm_0 + tiles * 20; + + float* output0 = out0.row(i * 4) + (j * 4) * 4; + + for (int m = 0; m < 6; m++) + { + float32x4_t _out0tm0 = vld1q_f32(output0_tm_0); + float32x4_t _out0tm1 = vld1q_f32(output0_tm_1); + float32x4_t _out0tm2 = vld1q_f32(output0_tm_2); + float32x4_t _out0tm3 = vld1q_f32(output0_tm_3); + float32x4_t _out0tm4 = vld1q_f32(output0_tm_4); + float32x4_t _out0tm5 = vld1q_f32(output0_tm_5); + + float32x4_t _tmp02a = vaddq_f32(_out0tm1, _out0tm2); + float32x4_t _tmp13a = vsubq_f32(_out0tm1, _out0tm2); + + float32x4_t _tmp02b = vaddq_f32(_out0tm3, _out0tm4); + float32x4_t _tmp13b = vsubq_f32(_out0tm3, _out0tm4); + + float32x4_t _tmp0m = vaddq_f32(vaddq_f32(_out0tm0, _tmp02a), _tmp02b); + float32x4_t _tmp1m = vmlaq_n_f32(_tmp13a, _tmp13b, 2.f); + float32x4_t _tmp2m = vmlaq_n_f32(_tmp02a, _tmp02b, 4.f); + float32x4_t _tmp3m = vmlaq_n_f32(vaddq_f32(_out0tm5, _tmp13a), _tmp13b, 8.f); + + vst1q_f32(tmp[0][m], _tmp0m); + vst1q_f32(tmp[1][m], _tmp1m); + vst1q_f32(tmp[2][m], _tmp2m); + vst1q_f32(tmp[3][m], _tmp3m); + + output0_tm_0 += tiles * 24; + output0_tm_1 += tiles * 24; + output0_tm_2 += tiles * 24; + output0_tm_3 += tiles * 24; + output0_tm_4 += tiles * 24; + output0_tm_5 += tiles * 24; + } + + for (int m = 0; m < 4; m++) + { + float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); + float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); + float32x4_t _tmp02 = 
vld1q_f32(tmp[m][2]); + float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); + float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); + float32x4_t _tmp05 = vld1q_f32(tmp[m][5]); + + float32x4_t _tmp02a = vaddq_f32(_tmp01, _tmp02); + float32x4_t _tmp13a = vsubq_f32(_tmp01, _tmp02); + + float32x4_t _tmp02b = vaddq_f32(_tmp03, _tmp04); + float32x4_t _tmp13b = vsubq_f32(_tmp03, _tmp04); + + float32x4_t _out00 = vaddq_f32(_bias0, vaddq_f32(vaddq_f32(_tmp00, _tmp02a), _tmp02b)); + float32x4_t _out01 = vaddq_f32(_bias0, vmlaq_n_f32(_tmp13a, _tmp13b, 2.f)); + float32x4_t _out02 = vaddq_f32(_bias0, vmlaq_n_f32(_tmp02a, _tmp02b, 4.f)); + float32x4_t _out03 = vaddq_f32(_bias0, vmlaq_n_f32(vaddq_f32(_tmp05, _tmp13a), _tmp13b, 8.f)); + + vst1q_f32(output0, _out00); + vst1q_f32(output0 + 4, _out01); + vst1q_f32(output0 + 8, _out02); + vst1q_f32(output0 + 12, _out03); + + output0 += outw * 4; + } + } + } + } +} diff --git a/src/layer/arm/convolution_winograd_transform_pack4_bf16s.h b/src/layer/arm/convolution_winograd_transform_pack4_bf16s.h new file mode 100644 index 00000000000..50b5836c8a3 --- /dev/null +++ b/src/layer/arm/convolution_winograd_transform_pack4_bf16s.h @@ -0,0 +1,535 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
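+// Note: the pack4 bf16 routines below follow the same 8x8 (and 6x6)
+// row/column two-pass structure as the fp32 pack4 transforms above, except
+// that rows are loaded as bf16, widened to fp32 for the arithmetic, and
+// narrowed back to bf16 on the final store. bfloat16 is just the upper half
+// of an IEEE fp32, so a scalar sketch of the conversion pair that the
+// vcvt_f32_bf16 / vcvt_bf16_f32 helpers used below roughly correspond to
+// (truncating, no rounding; helper names here are illustrative only) is:
+//
+//     static inline float bf16_to_f32(unsigned short v)
+//     {
+//         // place the bf16 bits in the high half of a 32-bit word
+//         union { unsigned int u; float f; } t = {(unsigned int)v << 16};
+//         return t.f;
+//     }
+//
+//     static inline unsigned short f32_to_bf16(float f)
+//     {
+//         union { float f; unsigned int u; } t = {f};
+//         return (unsigned short)(t.u >> 16); // keep the high 16 bits
+//     }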
+ +static void conv3x3s1_winograd64_transform_input_pack4_bf16s_neon(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) +{ + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + + const int w_tiles = (w - 2) / 6; + const int h_tiles = (h - 2) / 6; + const int tiles = w_tiles * h_tiles; + + // const float itm[8][8] = { + // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, + // + // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, + // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, + // + // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, + // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, + // + // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, + // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, + // + // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} + // }; + + // 0 = r00 - r06 + (r04 - r02) * 5.25 + // 7 = r07 - r01 + (r03 - r05) * 5.25 + + // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) + // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) + + // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) + // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) + + // reuse r04 * 1.25 + // reuse r03 * 2.5 + // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) + // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < inch; q++) + { + const Mat img0 = bottom_blob.channel(q); + Mat img0_tm = bottom_blob_tm.channel(q); + + float tmp[8][8][4]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const unsigned short* r0 = img0.row(i * 6) + (j * 6) * 4; + + for (int m = 0; m < 8; m++) + { + float32x4_t _r00 = vcvt_f32_bf16(vld1_u16(r0)); + float32x4_t _r01 = vcvt_f32_bf16(vld1_u16(r0 + 4)); + float32x4_t _r02 = vcvt_f32_bf16(vld1_u16(r0 + 8)); + float32x4_t _r03 = vcvt_f32_bf16(vld1_u16(r0 + 12)); + float32x4_t _r04 = vcvt_f32_bf16(vld1_u16(r0 + 16)); + float32x4_t _r05 = vcvt_f32_bf16(vld1_u16(r0 + 20)); + float32x4_t _r06 = vcvt_f32_bf16(vld1_u16(r0 + 24)); + float32x4_t _r07 = vcvt_f32_bf16(vld1_u16(r0 + 28)); + + float32x4_t _tmp0m = vmlaq_n_f32(vsubq_f32(_r00, _r06), vsubq_f32(_r04, _r02), 5.25f); + float32x4_t _tmp7m = vmlaq_n_f32(vsubq_f32(_r07, _r01), vsubq_f32(_r03, _r05), 5.25f); + vst1q_f32(tmp[0][m], _tmp0m); + vst1q_f32(tmp[7][m], _tmp7m); + + float32x4_t _tmp12a = vmlsq_n_f32(vaddq_f32(_r02, _r06), _r04, 4.25f); + float32x4_t _tmp12b = vmlsq_n_f32(vaddq_f32(_r01, _r05), _r03, 4.25f); + + float32x4_t _tmp1m = vaddq_f32(_tmp12a, _tmp12b); + float32x4_t _tmp2m = vsubq_f32(_tmp12a, _tmp12b); + vst1q_f32(tmp[1][m], _tmp1m); + vst1q_f32(tmp[2][m], _tmp2m); + + float32x4_t _tmp34a = vmlsq_n_f32(vmlaq_n_f32(_r06, _r02, 0.25f), _r04, 1.25f); + float32x4_t _tmp34b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_r01, 0.5f), _r03, 2.5f), _r05, 2.f); + + float32x4_t _tmp3m = vaddq_f32(_tmp34a, _tmp34b); + float32x4_t _tmp4m = vsubq_f32(_tmp34a, _tmp34b); + vst1q_f32(tmp[3][m], _tmp3m); + vst1q_f32(tmp[4][m], _tmp4m); + + float32x4_t _tmp56a = vmlaq_n_f32(_r06, vmlsq_n_f32(_r02, _r04, 1.25f), 4.f); + float32x4_t _tmp56b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_r01, 2.f), _r03, 2.5f), _r05, 0.5f); + + float32x4_t _tmp5m = vaddq_f32(_tmp56a, _tmp56b); + float32x4_t _tmp6m = vsubq_f32(_tmp56a, _tmp56b); + vst1q_f32(tmp[5][m], _tmp5m); + vst1q_f32(tmp[6][m], _tmp6m); + + r0 += w 
* 4; + } + + float* r0_tm_0 = (float*)img0_tm + (i * w_tiles + j) * 4; + float* r0_tm_1 = r0_tm_0 + tiles * 4; + float* r0_tm_2 = r0_tm_0 + tiles * 8; + float* r0_tm_3 = r0_tm_0 + tiles * 12; + float* r0_tm_4 = r0_tm_0 + tiles * 16; + float* r0_tm_5 = r0_tm_0 + tiles * 20; + float* r0_tm_6 = r0_tm_0 + tiles * 24; + float* r0_tm_7 = r0_tm_0 + tiles * 28; + + for (int m = 0; m < 8; m++) + { + float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); + float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); + float32x4_t _tmp02 = vld1q_f32(tmp[m][2]); + float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); + float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); + float32x4_t _tmp05 = vld1q_f32(tmp[m][5]); + float32x4_t _tmp06 = vld1q_f32(tmp[m][6]); + float32x4_t _tmp07 = vld1q_f32(tmp[m][7]); + + float32x4_t _r0tm0 = vmlaq_n_f32(vsubq_f32(_tmp00, _tmp06), vsubq_f32(_tmp04, _tmp02), 5.25f); + float32x4_t _r0tm7 = vmlaq_n_f32(vsubq_f32(_tmp07, _tmp01), vsubq_f32(_tmp03, _tmp05), 5.25f); + + float32x4_t _tmp12a = vmlsq_n_f32(vaddq_f32(_tmp02, _tmp06), _tmp04, 4.25f); + float32x4_t _tmp12b = vmlsq_n_f32(vaddq_f32(_tmp01, _tmp05), _tmp03, 4.25f); + + float32x4_t _r0tm1 = vaddq_f32(_tmp12a, _tmp12b); + float32x4_t _r0tm2 = vsubq_f32(_tmp12a, _tmp12b); + + float32x4_t _tmp34a = vmlsq_n_f32(vmlaq_n_f32(_tmp06, _tmp02, 0.25f), _tmp04, 1.25f); + float32x4_t _tmp34b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_tmp01, 0.5f), _tmp03, 2.5f), _tmp05, 2.f); + + float32x4_t _r0tm3 = vaddq_f32(_tmp34a, _tmp34b); + float32x4_t _r0tm4 = vsubq_f32(_tmp34a, _tmp34b); + + float32x4_t _tmp56a = vmlaq_n_f32(_tmp06, vmlsq_n_f32(_tmp02, _tmp04, 1.25f), 4.f); + float32x4_t _tmp56b = vmlaq_n_f32(vmlsq_n_f32(vmulq_n_f32(_tmp01, 2.f), _tmp03, 2.5f), _tmp05, 0.5f); + + float32x4_t _r0tm5 = vaddq_f32(_tmp56a, _tmp56b); + float32x4_t _r0tm6 = vsubq_f32(_tmp56a, _tmp56b); + + vst1q_f32(r0_tm_0, _r0tm0); + vst1q_f32(r0_tm_1, _r0tm1); + vst1q_f32(r0_tm_2, _r0tm2); + vst1q_f32(r0_tm_3, _r0tm3); + vst1q_f32(r0_tm_4, _r0tm4); + vst1q_f32(r0_tm_5, _r0tm5); + vst1q_f32(r0_tm_6, _r0tm6); + vst1q_f32(r0_tm_7, _r0tm7); + + r0_tm_0 += tiles * 32; + r0_tm_1 += tiles * 32; + r0_tm_2 += tiles * 32; + r0_tm_3 += tiles * 32; + r0_tm_4 += tiles * 32; + r0_tm_5 += tiles * 32; + r0_tm_6 += tiles * 32; + r0_tm_7 += tiles * 32; + } + } + } + } +} + +static void conv3x3s1_winograd64_transform_output_pack4_bf16s_neon(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 6; + const int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; + + const float* biasptr = bias; + + // const float otm[6][8] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} + // }; + + // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 + // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 + // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 + // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 + // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 + // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + float32x4_t _bias0 = 
biasptr ? vld1q_f32(biasptr + p * 4) : vdupq_n_f32(0.f); + + float tmp[6][8][4]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* output0_tm_0 = (const float*)out0_tm + (i * w_tiles + j) * 4; + const float* output0_tm_1 = output0_tm_0 + tiles * 4; + const float* output0_tm_2 = output0_tm_0 + tiles * 8; + const float* output0_tm_3 = output0_tm_0 + tiles * 12; + const float* output0_tm_4 = output0_tm_0 + tiles * 16; + const float* output0_tm_5 = output0_tm_0 + tiles * 20; + const float* output0_tm_6 = output0_tm_0 + tiles * 24; + const float* output0_tm_7 = output0_tm_0 + tiles * 28; + + unsigned short* output0 = out0.row(i * 6) + (j * 6) * 4; + + for (int m = 0; m < 8; m++) + { + float32x4_t _out0tm0 = vld1q_f32(output0_tm_0); + float32x4_t _out0tm1 = vld1q_f32(output0_tm_1); + float32x4_t _out0tm2 = vld1q_f32(output0_tm_2); + float32x4_t _out0tm3 = vld1q_f32(output0_tm_3); + float32x4_t _out0tm4 = vld1q_f32(output0_tm_4); + float32x4_t _out0tm5 = vld1q_f32(output0_tm_5); + float32x4_t _out0tm6 = vld1q_f32(output0_tm_6); + float32x4_t _out0tm7 = vld1q_f32(output0_tm_7); + + float32x4_t _tmp024a = vaddq_f32(_out0tm1, _out0tm2); + float32x4_t _tmp135a = vsubq_f32(_out0tm1, _out0tm2); + + float32x4_t _tmp024b = vaddq_f32(_out0tm3, _out0tm4); + float32x4_t _tmp135b = vsubq_f32(_out0tm3, _out0tm4); + + float32x4_t _tmp024c = vaddq_f32(_out0tm5, _out0tm6); + float32x4_t _tmp135c = vsubq_f32(_out0tm5, _out0tm6); + + float32x4_t _tmp0m = vaddq_f32(vaddq_f32(_out0tm0, _tmp024a), vmlaq_n_f32(_tmp024b, _tmp024c, 32.f)); + float32x4_t _tmp2m = vmlaq_n_f32(vmlaq_n_f32(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f); + float32x4_t _tmp4m = vmlaq_n_f32(vmlaq_n_f32(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f); + vst1q_f32(tmp[0][m], _tmp0m); + vst1q_f32(tmp[2][m], _tmp2m); + vst1q_f32(tmp[4][m], _tmp4m); + + float32x4_t _tmp1m = vmlaq_n_f32(vmlaq_n_f32(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f); + float32x4_t _tmp3m = vmlaq_n_f32(vmlaq_n_f32(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f); + float32x4_t _tmp5m = vaddq_f32(vaddq_f32(_out0tm7, _tmp135a), vmlaq_n_f32(_tmp135c, _tmp135b, 32.f)); + vst1q_f32(tmp[1][m], _tmp1m); + vst1q_f32(tmp[3][m], _tmp3m); + vst1q_f32(tmp[5][m], _tmp5m); + + output0_tm_0 += tiles * 32; + output0_tm_1 += tiles * 32; + output0_tm_2 += tiles * 32; + output0_tm_3 += tiles * 32; + output0_tm_4 += tiles * 32; + output0_tm_5 += tiles * 32; + output0_tm_6 += tiles * 32; + output0_tm_7 += tiles * 32; + } + + for (int m = 0; m < 6; m++) + { + float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); + float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); + float32x4_t _tmp02 = vld1q_f32(tmp[m][2]); + float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); + float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); + float32x4_t _tmp05 = vld1q_f32(tmp[m][5]); + float32x4_t _tmp06 = vld1q_f32(tmp[m][6]); + float32x4_t _tmp07 = vld1q_f32(tmp[m][7]); + + float32x4_t _tmp024a = vaddq_f32(_tmp01, _tmp02); + float32x4_t _tmp135a = vsubq_f32(_tmp01, _tmp02); + + float32x4_t _tmp024b = vaddq_f32(_tmp03, _tmp04); + float32x4_t _tmp135b = vsubq_f32(_tmp03, _tmp04); + + float32x4_t _tmp024c = vaddq_f32(_tmp05, _tmp06); + float32x4_t _tmp135c = vsubq_f32(_tmp05, _tmp06); + + float32x4_t _out00 = vaddq_f32(_bias0, vaddq_f32(vaddq_f32(_tmp00, _tmp024a), vmlaq_n_f32(_tmp024b, _tmp024c, 32.f))); + float32x4_t _out02 = vaddq_f32(_bias0, vmlaq_n_f32(vmlaq_n_f32(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f)); + float32x4_t _out04 = vaddq_f32(_bias0, vmlaq_n_f32(vmlaq_n_f32(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f)); + 
vst1_u16(output0, vcvt_bf16_f32(_out00)); + vst1_u16(output0 + 8, vcvt_bf16_f32(_out02)); + vst1_u16(output0 + 16, vcvt_bf16_f32(_out04)); + + float32x4_t _out01 = vaddq_f32(_bias0, vmlaq_n_f32(vmlaq_n_f32(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f)); + float32x4_t _out03 = vaddq_f32(_bias0, vmlaq_n_f32(vmlaq_n_f32(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f)); + float32x4_t _out05 = vaddq_f32(_bias0, vaddq_f32(vaddq_f32(_tmp07, _tmp135a), vmlaq_n_f32(_tmp135c, _tmp135b, 32.f))); + vst1_u16(output0 + 4, vcvt_bf16_f32(_out01)); + vst1_u16(output0 + 12, vcvt_bf16_f32(_out03)); + vst1_u16(output0 + 20, vcvt_bf16_f32(_out05)); + + output0 += outw * 4; + } + } + } + } +} + +static void conv3x3s1_winograd42_transform_input_pack4_bf16s_neon(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) +{ + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + + const int w_tiles = (w - 2) / 4; + const int h_tiles = (h - 2) / 4; + const int tiles = w_tiles * h_tiles; + + // const float itm[6][6] = { + // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, + // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, + // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, + // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, + // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, + // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} + // }; + + // 0 = 4 * r00 - 5 * r02 + r04 + // 1 = -4 * (r01 + r02) + r04 + r03 + // 2 = 4 * (r01 - r02) + r04 - r03 + // 3 = -2 * (r01 - r03) + r04 - r02 + // 4 = 2 * (r01 - r03) + r04 - r02 + // 5 = 4 * r01 - 5 * r03 + r05 + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < inch; q++) + { + const Mat img0 = bottom_blob.channel(q); + Mat img0_tm = bottom_blob_tm.channel(q); + + float tmp[6][6][4]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const unsigned short* r0 = img0.row(i * 4) + (j * 4) * 4; + + for (int m = 0; m < 6; m++) + { + float32x4_t _r00 = vcvt_f32_bf16(vld1_u16(r0)); + float32x4_t _r01 = vcvt_f32_bf16(vld1_u16(r0 + 4)); + float32x4_t _r02 = vcvt_f32_bf16(vld1_u16(r0 + 8)); + float32x4_t _r03 = vcvt_f32_bf16(vld1_u16(r0 + 12)); + float32x4_t _r04 = vcvt_f32_bf16(vld1_u16(r0 + 16)); + float32x4_t _r05 = vcvt_f32_bf16(vld1_u16(r0 + 20)); + + float32x4_t _tmp0m = vmlsq_n_f32(vmlaq_n_f32(_r04, _r00, 4.f), _r02, 5.f); + float32x4_t _tmp1m = vmlsq_n_f32(vaddq_f32(_r04, _r03), vaddq_f32(_r01, _r02), 4.f); + float32x4_t _tmp2m = vmlaq_n_f32(vsubq_f32(_r04, _r03), vsubq_f32(_r01, _r02), 4.f); + float32x4_t _tmp3m = vmlsq_n_f32(vsubq_f32(_r04, _r02), vsubq_f32(_r01, _r03), 2.f); + float32x4_t _tmp4m = vmlaq_n_f32(vsubq_f32(_r04, _r02), vsubq_f32(_r01, _r03), 2.f); + float32x4_t _tmp5m = vmlsq_n_f32(vmlaq_n_f32(_r05, _r01, 4.f), _r03, 5.f); + + vst1q_f32(tmp[0][m], _tmp0m); + vst1q_f32(tmp[1][m], _tmp1m); + vst1q_f32(tmp[2][m], _tmp2m); + vst1q_f32(tmp[3][m], _tmp3m); + vst1q_f32(tmp[4][m], _tmp4m); + vst1q_f32(tmp[5][m], _tmp5m); + + r0 += w * 4; + } + + float* r0_tm_0 = (float*)img0_tm + (i * w_tiles + j) * 4; + float* r0_tm_1 = r0_tm_0 + tiles * 4; + float* r0_tm_2 = r0_tm_0 + tiles * 8; + float* r0_tm_3 = r0_tm_0 + tiles * 12; + float* r0_tm_4 = r0_tm_0 + tiles * 16; + float* r0_tm_5 = r0_tm_0 + tiles * 20; + + for (int m = 0; m < 6; m++) + { + float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); + float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); + float32x4_t _tmp02 = vld1q_f32(tmp[m][2]); + float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); + float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); + float32x4_t _tmp05 = vld1q_f32(tmp[m][5]); + + float32x4_t 
_r0tm0 = vmlsq_n_f32(vmlaq_n_f32(_tmp04, _tmp00, 4.f), _tmp02, 5.f); + float32x4_t _r0tm1 = vmlsq_n_f32(vaddq_f32(_tmp04, _tmp03), vaddq_f32(_tmp01, _tmp02), 4.f); + float32x4_t _r0tm2 = vmlaq_n_f32(vsubq_f32(_tmp04, _tmp03), vsubq_f32(_tmp01, _tmp02), 4.f); + float32x4_t _r0tm3 = vmlsq_n_f32(vsubq_f32(_tmp04, _tmp02), vsubq_f32(_tmp01, _tmp03), 2.f); + float32x4_t _r0tm4 = vmlaq_n_f32(vsubq_f32(_tmp04, _tmp02), vsubq_f32(_tmp01, _tmp03), 2.f); + float32x4_t _r0tm5 = vmlsq_n_f32(vmlaq_n_f32(_tmp05, _tmp01, 4.f), _tmp03, 5.f); + + vst1q_f32(r0_tm_0, _r0tm0); + vst1q_f32(r0_tm_1, _r0tm1); + vst1q_f32(r0_tm_2, _r0tm2); + vst1q_f32(r0_tm_3, _r0tm3); + vst1q_f32(r0_tm_4, _r0tm4); + vst1q_f32(r0_tm_5, _r0tm5); + + r0_tm_0 += tiles * 24; + r0_tm_1 += tiles * 24; + r0_tm_2 += tiles * 24; + r0_tm_3 += tiles * 24; + r0_tm_4 += tiles * 24; + r0_tm_5 += tiles * 24; + } + } + } + } +} + +static void conv3x3s1_winograd42_transform_output_pack4_bf16s_neon(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 4; + const int h_tiles = outh / 4; + const int tiles = w_tiles * h_tiles; + + const float* biasptr = bias; + + // const float otm[4][6] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} + // }; + + // 0 = r00 + (r01 + r02) + (r03 + r04) + // 1 = (r01 - r02) + (r03 - r04) * 2 + // 2 = (r01 + r02) + (r03 + r04) * 4 + // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + float32x4_t _bias0 = biasptr ? 
vld1q_f32(biasptr + p * 4) : vdupq_n_f32(0.f); + + float tmp[4][6][4]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* output0_tm_0 = (const float*)out0_tm + (i * w_tiles + j) * 4; + const float* output0_tm_1 = output0_tm_0 + tiles * 4; + const float* output0_tm_2 = output0_tm_0 + tiles * 8; + const float* output0_tm_3 = output0_tm_0 + tiles * 12; + const float* output0_tm_4 = output0_tm_0 + tiles * 16; + const float* output0_tm_5 = output0_tm_0 + tiles * 20; + + unsigned short* output0 = out0.row(i * 4) + (j * 4) * 4; + + for (int m = 0; m < 6; m++) + { + float32x4_t _out0tm0 = vld1q_f32(output0_tm_0); + float32x4_t _out0tm1 = vld1q_f32(output0_tm_1); + float32x4_t _out0tm2 = vld1q_f32(output0_tm_2); + float32x4_t _out0tm3 = vld1q_f32(output0_tm_3); + float32x4_t _out0tm4 = vld1q_f32(output0_tm_4); + float32x4_t _out0tm5 = vld1q_f32(output0_tm_5); + + float32x4_t _tmp02a = vaddq_f32(_out0tm1, _out0tm2); + float32x4_t _tmp13a = vsubq_f32(_out0tm1, _out0tm2); + + float32x4_t _tmp02b = vaddq_f32(_out0tm3, _out0tm4); + float32x4_t _tmp13b = vsubq_f32(_out0tm3, _out0tm4); + + float32x4_t _tmp0m = vaddq_f32(vaddq_f32(_out0tm0, _tmp02a), _tmp02b); + float32x4_t _tmp1m = vmlaq_n_f32(_tmp13a, _tmp13b, 2.f); + float32x4_t _tmp2m = vmlaq_n_f32(_tmp02a, _tmp02b, 4.f); + float32x4_t _tmp3m = vmlaq_n_f32(vaddq_f32(_out0tm5, _tmp13a), _tmp13b, 8.f); + + vst1q_f32(tmp[0][m], _tmp0m); + vst1q_f32(tmp[1][m], _tmp1m); + vst1q_f32(tmp[2][m], _tmp2m); + vst1q_f32(tmp[3][m], _tmp3m); + + output0_tm_0 += tiles * 24; + output0_tm_1 += tiles * 24; + output0_tm_2 += tiles * 24; + output0_tm_3 += tiles * 24; + output0_tm_4 += tiles * 24; + output0_tm_5 += tiles * 24; + } + + for (int m = 0; m < 4; m++) + { + float32x4_t _tmp00 = vld1q_f32(tmp[m][0]); + float32x4_t _tmp01 = vld1q_f32(tmp[m][1]); + float32x4_t _tmp02 = vld1q_f32(tmp[m][2]); + float32x4_t _tmp03 = vld1q_f32(tmp[m][3]); + float32x4_t _tmp04 = vld1q_f32(tmp[m][4]); + float32x4_t _tmp05 = vld1q_f32(tmp[m][5]); + + float32x4_t _tmp02a = vaddq_f32(_tmp01, _tmp02); + float32x4_t _tmp13a = vsubq_f32(_tmp01, _tmp02); + + float32x4_t _tmp02b = vaddq_f32(_tmp03, _tmp04); + float32x4_t _tmp13b = vsubq_f32(_tmp03, _tmp04); + + float32x4_t _out00 = vaddq_f32(_bias0, vaddq_f32(vaddq_f32(_tmp00, _tmp02a), _tmp02b)); + float32x4_t _out01 = vaddq_f32(_bias0, vmlaq_n_f32(_tmp13a, _tmp13b, 2.f)); + float32x4_t _out02 = vaddq_f32(_bias0, vmlaq_n_f32(_tmp02a, _tmp02b, 4.f)); + float32x4_t _out03 = vaddq_f32(_bias0, vmlaq_n_f32(vaddq_f32(_tmp05, _tmp13a), _tmp13b, 8.f)); + + vst1_u16(output0, vcvt_bf16_f32(_out00)); + vst1_u16(output0 + 4, vcvt_bf16_f32(_out01)); + vst1_u16(output0 + 8, vcvt_bf16_f32(_out02)); + vst1_u16(output0 + 12, vcvt_bf16_f32(_out03)); + + output0 += outw * 4; + } + } + } + } +} diff --git a/src/layer/arm/convolution_winograd_transform_pack4_fp16s.h b/src/layer/arm/convolution_winograd_transform_pack4_fp16s.h new file mode 100644 index 00000000000..3d5c1bf7ad5 --- /dev/null +++ b/src/layer/arm/convolution_winograd_transform_pack4_fp16s.h @@ -0,0 +1,313 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void conv3x3s1_winograd64_transform_input_pack4_fp16sa_neon(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) +{ + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + + const int w_tiles = (w - 2) / 6; + const int h_tiles = (h - 2) / 6; + const int tiles = w_tiles * h_tiles; + + // const float itm[8][8] = { + // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, + // + // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, + // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, + // + // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, + // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, + // + // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, + // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, + // + // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} + // }; + + // 0 = r00 - r06 + (r04 - r02) * 5.25 + // 7 = r07 - r01 + (r03 - r05) * 5.25 + + // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) + // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) + + // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) + // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) + + // reuse r04 * 1.25 + // reuse r03 * 2.5 + // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) + // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < inch; q++) + { + const Mat img0 = bottom_blob.channel(q); + Mat img0_tm = bottom_blob_tm.channel(q); + + __fp16 tmp[8][8][4]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const __fp16* r0 = img0.row(i * 6) + (j * 6) * 4; + + for (int m = 0; m < 8; m++) + { + float16x4_t _r00 = vld1_f16(r0); + float16x4_t _r01 = vld1_f16(r0 + 4); + float16x4_t _r02 = vld1_f16(r0 + 8); + float16x4_t _r03 = vld1_f16(r0 + 12); + float16x4_t _r04 = vld1_f16(r0 + 16); + float16x4_t _r05 = vld1_f16(r0 + 20); + float16x4_t _r06 = vld1_f16(r0 + 24); + float16x4_t _r07 = vld1_f16(r0 + 28); + + float16x4_t _tmp0m = vfma_n_f16(vsub_f16(_r00, _r06), vsub_f16(_r04, _r02), 5.25f); + float16x4_t _tmp7m = vfma_n_f16(vsub_f16(_r07, _r01), vsub_f16(_r03, _r05), 5.25f); + vst1_f16(tmp[0][m], _tmp0m); + vst1_f16(tmp[7][m], _tmp7m); + + float16x4_t _tmp12a = vfms_n_f16(vadd_f16(_r02, _r06), _r04, 4.25f); + float16x4_t _tmp12b = vfms_n_f16(vadd_f16(_r01, _r05), _r03, 4.25f); + + float16x4_t _tmp1m = vadd_f16(_tmp12a, _tmp12b); + float16x4_t _tmp2m = vsub_f16(_tmp12a, _tmp12b); + vst1_f16(tmp[1][m], _tmp1m); + vst1_f16(tmp[2][m], _tmp2m); + + float16x4_t _tmp34a = vfms_n_f16(vfma_n_f16(_r06, _r02, 0.25f), _r04, 1.25f); + float16x4_t _tmp34b = vfma_n_f16(vfms_n_f16(vmul_n_f16(_r01, 0.5f), _r03, 2.5f), _r05, 2.f); + + float16x4_t _tmp3m = vadd_f16(_tmp34a, _tmp34b); + float16x4_t _tmp4m = vsub_f16(_tmp34a, _tmp34b); + vst1_f16(tmp[3][m], _tmp3m); + vst1_f16(tmp[4][m], _tmp4m); + + float16x4_t _tmp56a = vfma_n_f16(_r06, vfms_n_f16(_r02, _r04, 1.25f), 4.f); + float16x4_t 
_tmp56b = vfma_n_f16(vfms_n_f16(vmul_n_f16(_r01, 2.f), _r03, 2.5f), _r05, 0.5f); + + float16x4_t _tmp5m = vadd_f16(_tmp56a, _tmp56b); + float16x4_t _tmp6m = vsub_f16(_tmp56a, _tmp56b); + vst1_f16(tmp[5][m], _tmp5m); + vst1_f16(tmp[6][m], _tmp6m); + + r0 += w * 4; + } + + __fp16* r0_tm_0 = (__fp16*)img0_tm + (i * w_tiles + j) * 4; + __fp16* r0_tm_1 = r0_tm_0 + tiles * 4; + __fp16* r0_tm_2 = r0_tm_0 + tiles * 8; + __fp16* r0_tm_3 = r0_tm_0 + tiles * 12; + __fp16* r0_tm_4 = r0_tm_0 + tiles * 16; + __fp16* r0_tm_5 = r0_tm_0 + tiles * 20; + __fp16* r0_tm_6 = r0_tm_0 + tiles * 24; + __fp16* r0_tm_7 = r0_tm_0 + tiles * 28; + + for (int m = 0; m < 8; m++) + { + float16x4_t _tmp00 = vld1_f16(tmp[m][0]); + float16x4_t _tmp01 = vld1_f16(tmp[m][1]); + float16x4_t _tmp02 = vld1_f16(tmp[m][2]); + float16x4_t _tmp03 = vld1_f16(tmp[m][3]); + float16x4_t _tmp04 = vld1_f16(tmp[m][4]); + float16x4_t _tmp05 = vld1_f16(tmp[m][5]); + float16x4_t _tmp06 = vld1_f16(tmp[m][6]); + float16x4_t _tmp07 = vld1_f16(tmp[m][7]); + + float16x4_t _r0tm0 = vfma_n_f16(vsub_f16(_tmp00, _tmp06), vsub_f16(_tmp04, _tmp02), 5.25f); + float16x4_t _r0tm7 = vfma_n_f16(vsub_f16(_tmp07, _tmp01), vsub_f16(_tmp03, _tmp05), 5.25f); + + float16x4_t _tmp12a = vfms_n_f16(vadd_f16(_tmp02, _tmp06), _tmp04, 4.25f); + float16x4_t _tmp12b = vfms_n_f16(vadd_f16(_tmp01, _tmp05), _tmp03, 4.25f); + + float16x4_t _r0tm1 = vadd_f16(_tmp12a, _tmp12b); + float16x4_t _r0tm2 = vsub_f16(_tmp12a, _tmp12b); + + float16x4_t _tmp34a = vfms_n_f16(vfma_n_f16(_tmp06, _tmp02, 0.25f), _tmp04, 1.25f); + float16x4_t _tmp34b = vfma_n_f16(vfms_n_f16(vmul_n_f16(_tmp01, 0.5f), _tmp03, 2.5f), _tmp05, 2.f); + + float16x4_t _r0tm3 = vadd_f16(_tmp34a, _tmp34b); + float16x4_t _r0tm4 = vsub_f16(_tmp34a, _tmp34b); + + float16x4_t _tmp56a = vfma_n_f16(_tmp06, vfms_n_f16(_tmp02, _tmp04, 1.25f), 4.f); + float16x4_t _tmp56b = vfma_n_f16(vfms_n_f16(vmul_n_f16(_tmp01, 2.f), _tmp03, 2.5f), _tmp05, 0.5f); + + float16x4_t _r0tm5 = vadd_f16(_tmp56a, _tmp56b); + float16x4_t _r0tm6 = vsub_f16(_tmp56a, _tmp56b); + + vst1_f16(r0_tm_0, _r0tm0); + vst1_f16(r0_tm_1, _r0tm1); + vst1_f16(r0_tm_2, _r0tm2); + vst1_f16(r0_tm_3, _r0tm3); + vst1_f16(r0_tm_4, _r0tm4); + vst1_f16(r0_tm_5, _r0tm5); + vst1_f16(r0_tm_6, _r0tm6); + vst1_f16(r0_tm_7, _r0tm7); + + r0_tm_0 += tiles * 32; + r0_tm_1 += tiles * 32; + r0_tm_2 += tiles * 32; + r0_tm_3 += tiles * 32; + r0_tm_4 += tiles * 32; + r0_tm_5 += tiles * 32; + r0_tm_6 += tiles * 32; + r0_tm_7 += tiles * 32; + } + } + } + } +} + +static void conv3x3s1_winograd64_transform_output_pack4_fp16sa_neon(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 6; + const int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; + + const __fp16* biasptr = bias; + + // const float otm[6][8] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} + // }; + + // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 + // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 + // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 + // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 + // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 + // 5 = r7 + (r1 - r2) + 
(r3 - r4) * 32+ (r5 - r6) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + float16x4_t _bias0 = biasptr ? vld1_f16(biasptr + p * 4) : vdup_n_f16(0.f); + + __fp16 tmp[6][8][4]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const __fp16* output0_tm_0 = (const __fp16*)out0_tm + (i * w_tiles + j) * 4; + const __fp16* output0_tm_1 = output0_tm_0 + tiles * 4; + const __fp16* output0_tm_2 = output0_tm_0 + tiles * 8; + const __fp16* output0_tm_3 = output0_tm_0 + tiles * 12; + const __fp16* output0_tm_4 = output0_tm_0 + tiles * 16; + const __fp16* output0_tm_5 = output0_tm_0 + tiles * 20; + const __fp16* output0_tm_6 = output0_tm_0 + tiles * 24; + const __fp16* output0_tm_7 = output0_tm_0 + tiles * 28; + + __fp16* output0 = out0.row<__fp16>(i * 6) + (j * 6) * 4; + + for (int m = 0; m < 8; m++) + { + float16x4_t _out0tm0 = vld1_f16(output0_tm_0); + float16x4_t _out0tm1 = vld1_f16(output0_tm_1); + float16x4_t _out0tm2 = vld1_f16(output0_tm_2); + float16x4_t _out0tm3 = vld1_f16(output0_tm_3); + float16x4_t _out0tm4 = vld1_f16(output0_tm_4); + float16x4_t _out0tm5 = vld1_f16(output0_tm_5); + float16x4_t _out0tm6 = vld1_f16(output0_tm_6); + float16x4_t _out0tm7 = vld1_f16(output0_tm_7); + + float16x4_t _tmp024a = vadd_f16(_out0tm1, _out0tm2); + float16x4_t _tmp135a = vsub_f16(_out0tm1, _out0tm2); + + float16x4_t _tmp024b = vadd_f16(_out0tm3, _out0tm4); + float16x4_t _tmp135b = vsub_f16(_out0tm3, _out0tm4); + + float16x4_t _tmp024c = vadd_f16(_out0tm5, _out0tm6); + float16x4_t _tmp135c = vsub_f16(_out0tm5, _out0tm6); + + float16x4_t _tmp0m = vadd_f16(vadd_f16(_out0tm0, _tmp024a), vfma_n_f16(_tmp024b, _tmp024c, 32.f)); + float16x4_t _tmp2m = vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f); + float16x4_t _tmp4m = vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f); + vst1_f16(tmp[0][m], _tmp0m); + vst1_f16(tmp[2][m], _tmp2m); + vst1_f16(tmp[4][m], _tmp4m); + + float16x4_t _tmp1m = vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f); + float16x4_t _tmp3m = vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f); + float16x4_t _tmp5m = vadd_f16(vadd_f16(_out0tm7, _tmp135a), vfma_n_f16(_tmp135c, _tmp135b, 32.f)); + vst1_f16(tmp[1][m], _tmp1m); + vst1_f16(tmp[3][m], _tmp3m); + vst1_f16(tmp[5][m], _tmp5m); + + output0_tm_0 += tiles * 32; + output0_tm_1 += tiles * 32; + output0_tm_2 += tiles * 32; + output0_tm_3 += tiles * 32; + output0_tm_4 += tiles * 32; + output0_tm_5 += tiles * 32; + output0_tm_6 += tiles * 32; + output0_tm_7 += tiles * 32; + } + + for (int m = 0; m < 6; m++) + { + float16x4_t _tmp00 = vld1_f16(tmp[m][0]); + float16x4_t _tmp01 = vld1_f16(tmp[m][1]); + float16x4_t _tmp02 = vld1_f16(tmp[m][2]); + float16x4_t _tmp03 = vld1_f16(tmp[m][3]); + float16x4_t _tmp04 = vld1_f16(tmp[m][4]); + float16x4_t _tmp05 = vld1_f16(tmp[m][5]); + float16x4_t _tmp06 = vld1_f16(tmp[m][6]); + float16x4_t _tmp07 = vld1_f16(tmp[m][7]); + + float16x4_t _tmp024a = vadd_f16(_tmp01, _tmp02); + float16x4_t _tmp135a = vsub_f16(_tmp01, _tmp02); + + float16x4_t _tmp024b = vadd_f16(_tmp03, _tmp04); + float16x4_t _tmp135b = vsub_f16(_tmp03, _tmp04); + + float16x4_t _tmp024c = vadd_f16(_tmp05, _tmp06); + float16x4_t _tmp135c = vsub_f16(_tmp05, _tmp06); + + float16x4_t _out00 = vadd_f16(_bias0, vadd_f16(vadd_f16(_tmp00, _tmp024a), vfma_n_f16(_tmp024b, _tmp024c, 32.f))); + float16x4_t _out02 = vadd_f16(_bias0, 
vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f)); + float16x4_t _out04 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f)); + vst1_f16(output0, _out00); + vst1_f16(output0 + 8, _out02); + vst1_f16(output0 + 16, _out04); + + float16x4_t _out01 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f)); + float16x4_t _out03 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f)); + float16x4_t _out05 = vadd_f16(_bias0, vadd_f16(vadd_f16(_tmp07, _tmp135a), vfma_n_f16(_tmp135c, _tmp135b, 32.f))); + vst1_f16(output0 + 4, _out01); + vst1_f16(output0 + 12, _out03); + vst1_f16(output0 + 20, _out05); + + output0 += outw * 4; + } + } + } + } +} diff --git a/src/layer/arm/convolution_winograd_transform_pack8_fp16s.h b/src/layer/arm/convolution_winograd_transform_pack8_fp16s.h new file mode 100644 index 00000000000..eb2754971c0 --- /dev/null +++ b/src/layer/arm/convolution_winograd_transform_pack8_fp16s.h @@ -0,0 +1,535 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void conv3x3s1_winograd64_transform_input_pack8_fp16sa_neon(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) +{ + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + + const int w_tiles = (w - 2) / 6; + const int h_tiles = (h - 2) / 6; + const int tiles = w_tiles * h_tiles; + + // const float itm[8][8] = { + // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, + // + // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, + // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, + // + // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, + // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, + // + // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, + // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, + // + // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} + // }; + + // 0 = r00 - r06 + (r04 - r02) * 5.25 + // 7 = r07 - r01 + (r03 - r05) * 5.25 + + // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) + // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) + + // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) + // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) + + // reuse r04 * 1.25 + // reuse r03 * 2.5 + // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) + // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < inch; q++) + { + const Mat img0 = bottom_blob.channel(q); + Mat img0_tm = bottom_blob_tm.channel(q); + + __fp16 tmp[8][8][8]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const __fp16* r0 = img0.row(i * 6) + (j * 6) * 8; + + 
for (int m = 0; m < 8; m++) + { + float16x8_t _r00 = vld1q_f16(r0); + float16x8_t _r01 = vld1q_f16(r0 + 8); + float16x8_t _r02 = vld1q_f16(r0 + 16); + float16x8_t _r03 = vld1q_f16(r0 + 24); + float16x8_t _r04 = vld1q_f16(r0 + 32); + float16x8_t _r05 = vld1q_f16(r0 + 40); + float16x8_t _r06 = vld1q_f16(r0 + 48); + float16x8_t _r07 = vld1q_f16(r0 + 56); + + float16x8_t _tmp0m = vfmaq_n_f16(vsubq_f16(_r00, _r06), vsubq_f16(_r04, _r02), 5.25f); + float16x8_t _tmp7m = vfmaq_n_f16(vsubq_f16(_r07, _r01), vsubq_f16(_r03, _r05), 5.25f); + vst1q_f16(tmp[0][m], _tmp0m); + vst1q_f16(tmp[7][m], _tmp7m); + + float16x8_t _tmp12a = vfmsq_n_f16(vaddq_f16(_r02, _r06), _r04, 4.25f); + float16x8_t _tmp12b = vfmsq_n_f16(vaddq_f16(_r01, _r05), _r03, 4.25f); + + float16x8_t _tmp1m = vaddq_f16(_tmp12a, _tmp12b); + float16x8_t _tmp2m = vsubq_f16(_tmp12a, _tmp12b); + vst1q_f16(tmp[1][m], _tmp1m); + vst1q_f16(tmp[2][m], _tmp2m); + + float16x8_t _tmp34a = vfmsq_n_f16(vfmaq_n_f16(_r06, _r02, 0.25f), _r04, 1.25f); + float16x8_t _tmp34b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_r01, 0.5f), _r03, 2.5f), _r05, 2.f); + + float16x8_t _tmp3m = vaddq_f16(_tmp34a, _tmp34b); + float16x8_t _tmp4m = vsubq_f16(_tmp34a, _tmp34b); + vst1q_f16(tmp[3][m], _tmp3m); + vst1q_f16(tmp[4][m], _tmp4m); + + float16x8_t _tmp56a = vfmaq_n_f16(_r06, vfmsq_n_f16(_r02, _r04, 1.25f), 4.f); + float16x8_t _tmp56b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_r01, 2.f), _r03, 2.5f), _r05, 0.5f); + + float16x8_t _tmp5m = vaddq_f16(_tmp56a, _tmp56b); + float16x8_t _tmp6m = vsubq_f16(_tmp56a, _tmp56b); + vst1q_f16(tmp[5][m], _tmp5m); + vst1q_f16(tmp[6][m], _tmp6m); + + r0 += w * 8; + } + + __fp16* r0_tm_0 = (__fp16*)img0_tm + (i * w_tiles + j) * 8; + __fp16* r0_tm_1 = r0_tm_0 + tiles * 8; + __fp16* r0_tm_2 = r0_tm_0 + tiles * 16; + __fp16* r0_tm_3 = r0_tm_0 + tiles * 24; + __fp16* r0_tm_4 = r0_tm_0 + tiles * 32; + __fp16* r0_tm_5 = r0_tm_0 + tiles * 40; + __fp16* r0_tm_6 = r0_tm_0 + tiles * 48; + __fp16* r0_tm_7 = r0_tm_0 + tiles * 56; + + for (int m = 0; m < 8; m++) + { + float16x8_t _tmp00 = vld1q_f16(tmp[m][0]); + float16x8_t _tmp01 = vld1q_f16(tmp[m][1]); + float16x8_t _tmp02 = vld1q_f16(tmp[m][2]); + float16x8_t _tmp03 = vld1q_f16(tmp[m][3]); + float16x8_t _tmp04 = vld1q_f16(tmp[m][4]); + float16x8_t _tmp05 = vld1q_f16(tmp[m][5]); + float16x8_t _tmp06 = vld1q_f16(tmp[m][6]); + float16x8_t _tmp07 = vld1q_f16(tmp[m][7]); + + float16x8_t _r0tm0 = vfmaq_n_f16(vsubq_f16(_tmp00, _tmp06), vsubq_f16(_tmp04, _tmp02), 5.25f); + float16x8_t _r0tm7 = vfmaq_n_f16(vsubq_f16(_tmp07, _tmp01), vsubq_f16(_tmp03, _tmp05), 5.25f); + + float16x8_t _tmp12a = vfmsq_n_f16(vaddq_f16(_tmp02, _tmp06), _tmp04, 4.25f); + float16x8_t _tmp12b = vfmsq_n_f16(vaddq_f16(_tmp01, _tmp05), _tmp03, 4.25f); + + float16x8_t _r0tm1 = vaddq_f16(_tmp12a, _tmp12b); + float16x8_t _r0tm2 = vsubq_f16(_tmp12a, _tmp12b); + + float16x8_t _tmp34a = vfmsq_n_f16(vfmaq_n_f16(_tmp06, _tmp02, 0.25f), _tmp04, 1.25f); + float16x8_t _tmp34b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_tmp01, 0.5f), _tmp03, 2.5f), _tmp05, 2.f); + + float16x8_t _r0tm3 = vaddq_f16(_tmp34a, _tmp34b); + float16x8_t _r0tm4 = vsubq_f16(_tmp34a, _tmp34b); + + float16x8_t _tmp56a = vfmaq_n_f16(_tmp06, vfmsq_n_f16(_tmp02, _tmp04, 1.25f), 4.f); + float16x8_t _tmp56b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_tmp01, 2.f), _tmp03, 2.5f), _tmp05, 0.5f); + + float16x8_t _r0tm5 = vaddq_f16(_tmp56a, _tmp56b); + float16x8_t _r0tm6 = vsubq_f16(_tmp56a, _tmp56b); + + vst1q_f16(r0_tm_0, _r0tm0); + vst1q_f16(r0_tm_1, _r0tm1); + vst1q_f16(r0_tm_2, _r0tm2); + 
vst1q_f16(r0_tm_3, _r0tm3); + vst1q_f16(r0_tm_4, _r0tm4); + vst1q_f16(r0_tm_5, _r0tm5); + vst1q_f16(r0_tm_6, _r0tm6); + vst1q_f16(r0_tm_7, _r0tm7); + + r0_tm_0 += tiles * 64; + r0_tm_1 += tiles * 64; + r0_tm_2 += tiles * 64; + r0_tm_3 += tiles * 64; + r0_tm_4 += tiles * 64; + r0_tm_5 += tiles * 64; + r0_tm_6 += tiles * 64; + r0_tm_7 += tiles * 64; + } + } + } + } +} + +static void conv3x3s1_winograd64_transform_output_pack8_fp16sa_neon(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 6; + const int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; + + const __fp16* biasptr = bias; + + // const float otm[6][8] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} + // }; + + // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 + // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 + // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 + // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 + // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 + // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + float16x8_t _bias0 = biasptr ? vld1q_f16(biasptr + p * 8) : vdupq_n_f16(0.f); + + __fp16 tmp[6][8][8]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const __fp16* output0_tm_0 = (const __fp16*)out0_tm + (i * w_tiles + j) * 8; + const __fp16* output0_tm_1 = output0_tm_0 + tiles * 8; + const __fp16* output0_tm_2 = output0_tm_0 + tiles * 16; + const __fp16* output0_tm_3 = output0_tm_0 + tiles * 24; + const __fp16* output0_tm_4 = output0_tm_0 + tiles * 32; + const __fp16* output0_tm_5 = output0_tm_0 + tiles * 40; + const __fp16* output0_tm_6 = output0_tm_0 + tiles * 48; + const __fp16* output0_tm_7 = output0_tm_0 + tiles * 56; + + __fp16* output0 = out0.row<__fp16>(i * 6) + (j * 6) * 8; + + for (int m = 0; m < 8; m++) + { + float16x8_t _out0tm0 = vld1q_f16(output0_tm_0); + float16x8_t _out0tm1 = vld1q_f16(output0_tm_1); + float16x8_t _out0tm2 = vld1q_f16(output0_tm_2); + float16x8_t _out0tm3 = vld1q_f16(output0_tm_3); + float16x8_t _out0tm4 = vld1q_f16(output0_tm_4); + float16x8_t _out0tm5 = vld1q_f16(output0_tm_5); + float16x8_t _out0tm6 = vld1q_f16(output0_tm_6); + float16x8_t _out0tm7 = vld1q_f16(output0_tm_7); + + float16x8_t _tmp024a = vaddq_f16(_out0tm1, _out0tm2); + float16x8_t _tmp135a = vsubq_f16(_out0tm1, _out0tm2); + + float16x8_t _tmp024b = vaddq_f16(_out0tm3, _out0tm4); + float16x8_t _tmp135b = vsubq_f16(_out0tm3, _out0tm4); + + float16x8_t _tmp024c = vaddq_f16(_out0tm5, _out0tm6); + float16x8_t _tmp135c = vsubq_f16(_out0tm5, _out0tm6); + + float16x8_t _tmp0m = vaddq_f16(vaddq_f16(_out0tm0, _tmp024a), vfmaq_n_f16(_tmp024b, _tmp024c, 32.f)); + float16x8_t _tmp2m = vfmaq_n_f16(vfmaq_n_f16(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f); + float16x8_t _tmp4m = vfmaq_n_f16(vfmaq_n_f16(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f); + vst1q_f16(tmp[0][m], _tmp0m); + vst1q_f16(tmp[2][m], _tmp2m); + vst1q_f16(tmp[4][m], _tmp4m); + + float16x8_t _tmp1m = 
vfmaq_n_f16(vfmaq_n_f16(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f); + float16x8_t _tmp3m = vfmaq_n_f16(vfmaq_n_f16(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f); + float16x8_t _tmp5m = vaddq_f16(vaddq_f16(_out0tm7, _tmp135a), vfmaq_n_f16(_tmp135c, _tmp135b, 32.f)); + vst1q_f16(tmp[1][m], _tmp1m); + vst1q_f16(tmp[3][m], _tmp3m); + vst1q_f16(tmp[5][m], _tmp5m); + + output0_tm_0 += tiles * 64; + output0_tm_1 += tiles * 64; + output0_tm_2 += tiles * 64; + output0_tm_3 += tiles * 64; + output0_tm_4 += tiles * 64; + output0_tm_5 += tiles * 64; + output0_tm_6 += tiles * 64; + output0_tm_7 += tiles * 64; + } + + for (int m = 0; m < 6; m++) + { + float16x8_t _tmp00 = vld1q_f16(tmp[m][0]); + float16x8_t _tmp01 = vld1q_f16(tmp[m][1]); + float16x8_t _tmp02 = vld1q_f16(tmp[m][2]); + float16x8_t _tmp03 = vld1q_f16(tmp[m][3]); + float16x8_t _tmp04 = vld1q_f16(tmp[m][4]); + float16x8_t _tmp05 = vld1q_f16(tmp[m][5]); + float16x8_t _tmp06 = vld1q_f16(tmp[m][6]); + float16x8_t _tmp07 = vld1q_f16(tmp[m][7]); + + float16x8_t _tmp024a = vaddq_f16(_tmp01, _tmp02); + float16x8_t _tmp135a = vsubq_f16(_tmp01, _tmp02); + + float16x8_t _tmp024b = vaddq_f16(_tmp03, _tmp04); + float16x8_t _tmp135b = vsubq_f16(_tmp03, _tmp04); + + float16x8_t _tmp024c = vaddq_f16(_tmp05, _tmp06); + float16x8_t _tmp135c = vsubq_f16(_tmp05, _tmp06); + + float16x8_t _out00 = vaddq_f16(_bias0, vaddq_f16(vaddq_f16(_tmp00, _tmp024a), vfmaq_n_f16(_tmp024b, _tmp024c, 32.f))); + float16x8_t _out02 = vaddq_f16(_bias0, vfmaq_n_f16(vfmaq_n_f16(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f)); + float16x8_t _out04 = vaddq_f16(_bias0, vfmaq_n_f16(vfmaq_n_f16(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f)); + vst1q_f16(output0, _out00); + vst1q_f16(output0 + 16, _out02); + vst1q_f16(output0 + 32, _out04); + + float16x8_t _out01 = vaddq_f16(_bias0, vfmaq_n_f16(vfmaq_n_f16(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f)); + float16x8_t _out03 = vaddq_f16(_bias0, vfmaq_n_f16(vfmaq_n_f16(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f)); + float16x8_t _out05 = vaddq_f16(_bias0, vaddq_f16(vaddq_f16(_tmp07, _tmp135a), vfmaq_n_f16(_tmp135c, _tmp135b, 32.f))); + vst1q_f16(output0 + 8, _out01); + vst1q_f16(output0 + 24, _out03); + vst1q_f16(output0 + 40, _out05); + + output0 += outw * 8; + } + } + } + } +} + +static void conv3x3s1_winograd42_transform_input_pack8_fp16sa_neon(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) +{ + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + + const int w_tiles = (w - 2) / 4; + const int h_tiles = (h - 2) / 4; + const int tiles = w_tiles * h_tiles; + + // const float itm[6][6] = { + // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, + // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, + // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, + // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, + // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, + // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} + // }; + + // 0 = 4 * r00 - 5 * r02 + r04 + // 1 = -4 * (r01 + r02) + r04 + r03 + // 2 = 4 * (r01 - r02) + r04 - r03 + // 3 = -2 * (r01 - r03) + r04 - r02 + // 4 = 2 * (r01 - r03) + r04 - r02 + // 5 = 4 * r01 - 5 * r03 + r05 + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < inch; q++) + { + const Mat img0 = bottom_blob.channel(q); + Mat img0_tm = bottom_blob_tm.channel(q); + + __fp16 tmp[6][6][8]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const __fp16* r0 = img0.row(i * 4) + (j * 4) * 8; + + for (int m = 0; m < 6; m++) + { + float16x8_t _r00 = vld1q_f16(r0); + float16x8_t _r01 = 
vld1q_f16(r0 + 8); + float16x8_t _r02 = vld1q_f16(r0 + 16); + float16x8_t _r03 = vld1q_f16(r0 + 24); + float16x8_t _r04 = vld1q_f16(r0 + 32); + float16x8_t _r05 = vld1q_f16(r0 + 40); + + float16x8_t _tmp0m = vfmsq_n_f16(vfmaq_n_f16(_r04, _r00, 4.f), _r02, 5.f); + float16x8_t _tmp1m = vfmsq_n_f16(vaddq_f16(_r04, _r03), vaddq_f16(_r01, _r02), 4.f); + float16x8_t _tmp2m = vfmaq_n_f16(vsubq_f16(_r04, _r03), vsubq_f16(_r01, _r02), 4.f); + float16x8_t _tmp3m = vfmsq_n_f16(vsubq_f16(_r04, _r02), vsubq_f16(_r01, _r03), 2.f); + float16x8_t _tmp4m = vfmaq_n_f16(vsubq_f16(_r04, _r02), vsubq_f16(_r01, _r03), 2.f); + float16x8_t _tmp5m = vfmsq_n_f16(vfmaq_n_f16(_r05, _r01, 4.f), _r03, 5.f); + + vst1q_f16(tmp[0][m], _tmp0m); + vst1q_f16(tmp[1][m], _tmp1m); + vst1q_f16(tmp[2][m], _tmp2m); + vst1q_f16(tmp[3][m], _tmp3m); + vst1q_f16(tmp[4][m], _tmp4m); + vst1q_f16(tmp[5][m], _tmp5m); + + r0 += w * 8; + } + + __fp16* r0_tm_0 = (__fp16*)img0_tm + (i * w_tiles + j) * 8; + __fp16* r0_tm_1 = r0_tm_0 + tiles * 8; + __fp16* r0_tm_2 = r0_tm_0 + tiles * 16; + __fp16* r0_tm_3 = r0_tm_0 + tiles * 24; + __fp16* r0_tm_4 = r0_tm_0 + tiles * 32; + __fp16* r0_tm_5 = r0_tm_0 + tiles * 40; + + for (int m = 0; m < 6; m++) + { + float16x8_t _tmp00 = vld1q_f16(tmp[m][0]); + float16x8_t _tmp01 = vld1q_f16(tmp[m][1]); + float16x8_t _tmp02 = vld1q_f16(tmp[m][2]); + float16x8_t _tmp03 = vld1q_f16(tmp[m][3]); + float16x8_t _tmp04 = vld1q_f16(tmp[m][4]); + float16x8_t _tmp05 = vld1q_f16(tmp[m][5]); + + float16x8_t _r0tm0 = vfmsq_n_f16(vfmaq_n_f16(_tmp04, _tmp00, 4.f), _tmp02, 5.f); + float16x8_t _r0tm1 = vfmsq_n_f16(vaddq_f16(_tmp04, _tmp03), vaddq_f16(_tmp01, _tmp02), 4.f); + float16x8_t _r0tm2 = vfmaq_n_f16(vsubq_f16(_tmp04, _tmp03), vsubq_f16(_tmp01, _tmp02), 4.f); + float16x8_t _r0tm3 = vfmsq_n_f16(vsubq_f16(_tmp04, _tmp02), vsubq_f16(_tmp01, _tmp03), 2.f); + float16x8_t _r0tm4 = vfmaq_n_f16(vsubq_f16(_tmp04, _tmp02), vsubq_f16(_tmp01, _tmp03), 2.f); + float16x8_t _r0tm5 = vfmsq_n_f16(vfmaq_n_f16(_tmp05, _tmp01, 4.f), _tmp03, 5.f); + + vst1q_f16(r0_tm_0, _r0tm0); + vst1q_f16(r0_tm_1, _r0tm1); + vst1q_f16(r0_tm_2, _r0tm2); + vst1q_f16(r0_tm_3, _r0tm3); + vst1q_f16(r0_tm_4, _r0tm4); + vst1q_f16(r0_tm_5, _r0tm5); + + r0_tm_0 += tiles * 48; + r0_tm_1 += tiles * 48; + r0_tm_2 += tiles * 48; + r0_tm_3 += tiles * 48; + r0_tm_4 += tiles * 48; + r0_tm_5 += tiles * 48; + } + } + } + } +} + +static void conv3x3s1_winograd42_transform_output_pack8_fp16sa_neon(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 4; + const int h_tiles = outh / 4; + const int tiles = w_tiles * h_tiles; + + const __fp16* biasptr = bias; + + // const float otm[4][6] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} + // }; + + // 0 = r00 + (r01 + r02) + (r03 + r04) + // 1 = (r01 - r02) + (r03 - r04) * 2 + // 2 = (r01 + r02) + (r03 + r04) * 4 + // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + float16x8_t _bias0 = biasptr ? 
vld1q_f16(biasptr + p * 8) : vdupq_n_f16(0.f); + + __fp16 tmp[4][6][8]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const __fp16* output0_tm_0 = (const __fp16*)out0_tm + (i * w_tiles + j) * 8; + const __fp16* output0_tm_1 = output0_tm_0 + tiles * 8; + const __fp16* output0_tm_2 = output0_tm_0 + tiles * 16; + const __fp16* output0_tm_3 = output0_tm_0 + tiles * 24; + const __fp16* output0_tm_4 = output0_tm_0 + tiles * 32; + const __fp16* output0_tm_5 = output0_tm_0 + tiles * 40; + + __fp16* output0 = out0.row<__fp16>(i * 4) + (j * 4) * 8; + + for (int m = 0; m < 6; m++) + { + float16x8_t _out0tm0 = vld1q_f16(output0_tm_0); + float16x8_t _out0tm1 = vld1q_f16(output0_tm_1); + float16x8_t _out0tm2 = vld1q_f16(output0_tm_2); + float16x8_t _out0tm3 = vld1q_f16(output0_tm_3); + float16x8_t _out0tm4 = vld1q_f16(output0_tm_4); + float16x8_t _out0tm5 = vld1q_f16(output0_tm_5); + + float16x8_t _tmp02a = vaddq_f16(_out0tm1, _out0tm2); + float16x8_t _tmp13a = vsubq_f16(_out0tm1, _out0tm2); + + float16x8_t _tmp02b = vaddq_f16(_out0tm3, _out0tm4); + float16x8_t _tmp13b = vsubq_f16(_out0tm3, _out0tm4); + + float16x8_t _tmp0m = vaddq_f16(vaddq_f16(_out0tm0, _tmp02a), _tmp02b); + float16x8_t _tmp1m = vfmaq_n_f16(_tmp13a, _tmp13b, 2.f); + float16x8_t _tmp2m = vfmaq_n_f16(_tmp02a, _tmp02b, 4.f); + float16x8_t _tmp3m = vfmaq_n_f16(vaddq_f16(_out0tm5, _tmp13a), _tmp13b, 8.f); + + vst1q_f16(tmp[0][m], _tmp0m); + vst1q_f16(tmp[1][m], _tmp1m); + vst1q_f16(tmp[2][m], _tmp2m); + vst1q_f16(tmp[3][m], _tmp3m); + + output0_tm_0 += tiles * 48; + output0_tm_1 += tiles * 48; + output0_tm_2 += tiles * 48; + output0_tm_3 += tiles * 48; + output0_tm_4 += tiles * 48; + output0_tm_5 += tiles * 48; + } + + for (int m = 0; m < 4; m++) + { + float16x8_t _tmp00 = vld1q_f16(tmp[m][0]); + float16x8_t _tmp01 = vld1q_f16(tmp[m][1]); + float16x8_t _tmp02 = vld1q_f16(tmp[m][2]); + float16x8_t _tmp03 = vld1q_f16(tmp[m][3]); + float16x8_t _tmp04 = vld1q_f16(tmp[m][4]); + float16x8_t _tmp05 = vld1q_f16(tmp[m][5]); + + float16x8_t _tmp02a = vaddq_f16(_tmp01, _tmp02); + float16x8_t _tmp13a = vsubq_f16(_tmp01, _tmp02); + + float16x8_t _tmp02b = vaddq_f16(_tmp03, _tmp04); + float16x8_t _tmp13b = vsubq_f16(_tmp03, _tmp04); + + float16x8_t _out00 = vaddq_f16(_bias0, vaddq_f16(vaddq_f16(_tmp00, _tmp02a), _tmp02b)); + float16x8_t _out01 = vaddq_f16(_bias0, vfmaq_n_f16(_tmp13a, _tmp13b, 2.f)); + float16x8_t _out02 = vaddq_f16(_bias0, vfmaq_n_f16(_tmp02a, _tmp02b, 4.f)); + float16x8_t _out03 = vaddq_f16(_bias0, vfmaq_n_f16(vaddq_f16(_tmp05, _tmp13a), _tmp13b, 8.f)); + + vst1q_f16(output0, _out00); + vst1q_f16(output0 + 8, _out01); + vst1q_f16(output0 + 16, _out02); + vst1q_f16(output0 + 24, _out03); + + output0 += outw * 8; + } + } + } + } +} diff --git a/src/layer/mips/convolution_3x3_pack4.h b/src/layer/mips/convolution_3x3_pack4.h index 6c5833cb1ea..c98a1acd224 100644 --- a/src/layer/mips/convolution_3x3_pack4.h +++ b/src/layer/mips/convolution_3x3_pack4.h @@ -93,7 +93,7 @@ static void conv3x3s1_winograd64_transform_kernel_pack4_msa(const Mat& kernel, M } } -static void conv3x3s1_winograd64_pack4_msa(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd64_pack4_msa(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -115,177 +115,15 @@ static void conv3x3s1_winograd64_pack4_msa(const Mat& bottom_blob, Mat& top_blob h = outh + 2; 
copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const float* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - - const int tiles = w_tm / 8 * h_tm / 8; - - // bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); - bottom_blob_tm.create(tiles, 64, inch, 4u * elempack, elempack, opt.workspace_allocator); - - // const float itm[8][8] = { - // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, - // - // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, - // - // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, - // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, - // - // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, - // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, - // - // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} - // }; - - // 0 = r00 - r06 + (r04 - r02) * 5.25 - // 7 = r07 - r01 + (r03 - r05) * 5.25 - - // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) - // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) - - // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) - // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) - - // reuse r04 * 1.25 - // reuse r03 * 2.5 - // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) - // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - float tmp[8][8][4]; - - v4f32 _v5_25 = __msa_fill_w_f32(5.25f); - v4f32 _vm4_25 = __msa_fill_w_f32(-4.25f); - v4f32 _vm1_25 = __msa_fill_w_f32(-1.25f); - v4f32 _v0_25 = __msa_fill_w_f32(0.25f); - v4f32 _vm2_5 = __msa_fill_w_f32(-2.5f); - v4f32 _v0_5 = __msa_fill_w_f32(0.5f); - v4f32 _v2 = __msa_fill_w_f32(2.f); - v4f32 _v4 = __msa_fill_w_f32(4.f); - - // tile - for (int i = 0; i < h_tm / 8; i++) - { - for (int j = 0; j < w_tm / 8; j++) - { - const float* r0 = img0.row(i * 6) + (j * 6) * 4; - - for (int m = 0; m < 8; m++) - { - v4f32 _r00 = (v4f32)__msa_ld_w(r0, 0); - v4f32 _r01 = (v4f32)__msa_ld_w(r0 + 4, 0); - v4f32 _r02 = (v4f32)__msa_ld_w(r0 + 4 * 2, 0); - v4f32 _r03 = (v4f32)__msa_ld_w(r0 + 4 * 3, 0); - v4f32 _r04 = (v4f32)__msa_ld_w(r0 + 4 * 4, 0); - v4f32 _r05 = (v4f32)__msa_ld_w(r0 + 4 * 5, 0); - v4f32 _r06 = (v4f32)__msa_ld_w(r0 + 4 * 6, 0); - v4f32 _r07 = (v4f32)__msa_ld_w(r0 + 4 * 7, 0); - - v4f32 _tmp0m = __msa_fmadd_w(__msa_fsub_w(_r00, _r06), _v5_25, __msa_fsub_w(_r04, _r02)); - v4f32 _tmp7m = __msa_fmadd_w(__msa_fsub_w(_r07, _r01), _v5_25, __msa_fsub_w(_r03, _r05)); - __msa_st_w((v4i32)_tmp0m, tmp[0][m], 0); - __msa_st_w((v4i32)_tmp7m, tmp[7][m], 0); - - v4f32 _tmp12a = __msa_fmadd_w(__msa_fadd_w(_r02, _r06), _vm4_25, _r04); - v4f32 _tmp12b = __msa_fmadd_w(__msa_fadd_w(_r01, _r05), _vm4_25, _r03); - - v4f32 _tmp1m = __msa_fadd_w(_tmp12a, _tmp12b); - v4f32 _tmp2m = __msa_fsub_w(_tmp12a, _tmp12b); - __msa_st_w((v4i32)_tmp1m, tmp[1][m], 0); - __msa_st_w((v4i32)_tmp2m, tmp[2][m], 0); - - v4f32 _tmp34a = __msa_fmadd_w(__msa_fmadd_w(_r06, _v0_25, _r02), _vm1_25, _r04); - v4f32 _tmp34b = __msa_fmadd_w(__msa_fmadd_w(__msa_fmul_w(_r01, _v0_5), _vm2_5, _r03), _v2, _r05); - - v4f32 _tmp3m = __msa_fadd_w(_tmp34a, _tmp34b); - v4f32 
_tmp4m = __msa_fsub_w(_tmp34a, _tmp34b); - __msa_st_w((v4i32)_tmp3m, tmp[3][m], 0); - __msa_st_w((v4i32)_tmp4m, tmp[4][m], 0); - - v4f32 _tmp56a = __msa_fmadd_w(_r06, _v4, __msa_fmadd_w(_r02, _vm1_25, _r04)); - v4f32 _tmp56b = __msa_fmadd_w(__msa_fmadd_w(__msa_fmul_w(_r01, _v2), _vm2_5, _r03), _v0_5, _r05); - - v4f32 _tmp5m = __msa_fadd_w(_tmp56a, _tmp56b); - v4f32 _tmp6m = __msa_fsub_w(_tmp56a, _tmp56b); - __msa_st_w((v4i32)_tmp5m, tmp[5][m], 0); - __msa_st_w((v4i32)_tmp6m, tmp[6][m], 0); - - r0 += w * 4; - } - - float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 8 + j) * 4; - float* r0_tm_1 = r0_tm_0 + tiles * 4; - float* r0_tm_2 = r0_tm_0 + tiles * 4 * 2; - float* r0_tm_3 = r0_tm_0 + tiles * 4 * 3; - float* r0_tm_4 = r0_tm_0 + tiles * 4 * 4; - float* r0_tm_5 = r0_tm_0 + tiles * 4 * 5; - float* r0_tm_6 = r0_tm_0 + tiles * 4 * 6; - float* r0_tm_7 = r0_tm_0 + tiles * 4 * 7; + int w_tiles = outw / 6; + int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; - for (int m = 0; m < 8; m++) - { - v4f32 _tmp00 = (v4f32)__msa_ld_w(tmp[m][0], 0); - v4f32 _tmp01 = (v4f32)__msa_ld_w(tmp[m][1], 0); - v4f32 _tmp02 = (v4f32)__msa_ld_w(tmp[m][2], 0); - v4f32 _tmp03 = (v4f32)__msa_ld_w(tmp[m][3], 0); - v4f32 _tmp04 = (v4f32)__msa_ld_w(tmp[m][4], 0); - v4f32 _tmp05 = (v4f32)__msa_ld_w(tmp[m][5], 0); - v4f32 _tmp06 = (v4f32)__msa_ld_w(tmp[m][6], 0); - v4f32 _tmp07 = (v4f32)__msa_ld_w(tmp[m][7], 0); - - v4f32 _r0tm0 = __msa_fmadd_w(__msa_fsub_w(_tmp00, _tmp06), _v5_25, __msa_fsub_w(_tmp04, _tmp02)); - v4f32 _r0tm7 = __msa_fmadd_w(__msa_fsub_w(_tmp07, _tmp01), _v5_25, __msa_fsub_w(_tmp03, _tmp05)); - - v4f32 _tmp12a = __msa_fmadd_w(__msa_fadd_w(_tmp02, _tmp06), _vm4_25, _tmp04); - v4f32 _tmp12b = __msa_fmadd_w(__msa_fadd_w(_tmp01, _tmp05), _vm4_25, _tmp03); - - v4f32 _r0tm1 = __msa_fadd_w(_tmp12a, _tmp12b); - v4f32 _r0tm2 = __msa_fsub_w(_tmp12a, _tmp12b); - - v4f32 _tmp34a = __msa_fmadd_w(__msa_fmadd_w(_tmp06, _v0_25, _tmp02), _vm1_25, _tmp04); - v4f32 _tmp34b = __msa_fmadd_w(__msa_fmadd_w(__msa_fmul_w(_tmp01, _v0_5), _vm2_5, _tmp03), _v2, _tmp05); - - v4f32 _r0tm3 = __msa_fadd_w(_tmp34a, _tmp34b); - v4f32 _r0tm4 = __msa_fsub_w(_tmp34a, _tmp34b); - - v4f32 _tmp56a = __msa_fmadd_w(_tmp06, _v4, __msa_fmadd_w(_tmp02, _vm1_25, _tmp04)); - v4f32 _tmp56b = __msa_fmadd_w(__msa_fmadd_w(__msa_fmul_w(_tmp01, _v2), _vm2_5, _tmp03), _v0_5, _tmp05); - - v4f32 _r0tm5 = __msa_fadd_w(_tmp56a, _tmp56b); - v4f32 _r0tm6 = __msa_fsub_w(_tmp56a, _tmp56b); - - __msa_st_w((v4i32)_r0tm0, r0_tm_0, 0); - __msa_st_w((v4i32)_r0tm1, r0_tm_1, 0); - __msa_st_w((v4i32)_r0tm2, r0_tm_2, 0); - __msa_st_w((v4i32)_r0tm3, r0_tm_3, 0); - __msa_st_w((v4i32)_r0tm4, r0_tm_4, 0); - __msa_st_w((v4i32)_r0tm5, r0_tm_5, 0); - __msa_st_w((v4i32)_r0tm6, r0_tm_6, 0); - __msa_st_w((v4i32)_r0tm7, r0_tm_7, 0); - - r0_tm_0 += tiles * 4 * 8; - r0_tm_1 += tiles * 4 * 8; - r0_tm_2 += tiles * 4 * 8; - r0_tm_3 += tiles * 4 * 8; - r0_tm_4 += tiles * 4 * 8; - r0_tm_5 += tiles * 4 * 8; - r0_tm_6 += tiles * 4 * 8; - r0_tm_7 += tiles * 4 * 8; - } - } - } - } + bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); + conv3x3s1_winograd64_transform_input_pack4_msa(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -739,145 +577,7 @@ static void conv3x3s1_winograd64_pack4_msa(const Mat& bottom_blob, Mat& top_blob top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator); } { - // const float otm[6][8] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, 
- // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} - // }; - - // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 - // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 - // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 - // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 - // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 - // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) - - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - const int tiles = w_tm / 8 * h_tm / 8; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - v4f32 _bias0 = bias ? (v4f32)__msa_ld_w((const float*)bias + p * 4, 0) : (v4f32)__msa_fill_w(0); - - float tmp[6][8][4]; - - v4f32 _v32 = __msa_fill_w_f32(32.f); - v4f32 _v16 = __msa_fill_w_f32(16.f); - v4f32 _v8 = __msa_fill_w_f32(8.f); - v4f32 _v4 = __msa_fill_w_f32(4.f); - v4f32 _v2 = __msa_fill_w_f32(2.f); - - // tile - for (int i = 0; i < outh / 6; i++) - { - for (int j = 0; j < outw / 6; j++) - { - // top_blob_tm.create(tiles, 64, outch, elemsize, elempack); - - const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 8 + j) * 4; - const float* output0_tm_1 = output0_tm_0 + tiles * 4; - const float* output0_tm_2 = output0_tm_0 + tiles * 4 * 2; - const float* output0_tm_3 = output0_tm_0 + tiles * 4 * 3; - const float* output0_tm_4 = output0_tm_0 + tiles * 4 * 4; - const float* output0_tm_5 = output0_tm_0 + tiles * 4 * 5; - const float* output0_tm_6 = output0_tm_0 + tiles * 4 * 6; - const float* output0_tm_7 = output0_tm_0 + tiles * 4 * 7; - - float* output0 = out0.row(i * 6) + (j * 6) * 4; - - // TODO msa optimize - for (int m = 0; m < 8; m++) - { - v4f32 _out0tm0 = (v4f32)__msa_ld_w(output0_tm_0, 0); - v4f32 _out0tm1 = (v4f32)__msa_ld_w(output0_tm_1, 0); - v4f32 _out0tm2 = (v4f32)__msa_ld_w(output0_tm_2, 0); - v4f32 _out0tm3 = (v4f32)__msa_ld_w(output0_tm_3, 0); - v4f32 _out0tm4 = (v4f32)__msa_ld_w(output0_tm_4, 0); - v4f32 _out0tm5 = (v4f32)__msa_ld_w(output0_tm_5, 0); - v4f32 _out0tm6 = (v4f32)__msa_ld_w(output0_tm_6, 0); - v4f32 _out0tm7 = (v4f32)__msa_ld_w(output0_tm_7, 0); - - v4f32 _tmp024a = __msa_fadd_w(_out0tm1, _out0tm2); - v4f32 _tmp135a = __msa_fsub_w(_out0tm1, _out0tm2); - - v4f32 _tmp024b = __msa_fadd_w(_out0tm3, _out0tm4); - v4f32 _tmp135b = __msa_fsub_w(_out0tm3, _out0tm4); - - v4f32 _tmp024c = __msa_fadd_w(_out0tm5, _out0tm6); - v4f32 _tmp135c = __msa_fsub_w(_out0tm5, _out0tm6); - - v4f32 _tmp0m = __msa_fadd_w(__msa_fadd_w(_out0tm0, _tmp024a), __msa_fmadd_w(_tmp024b, _v32, _tmp024c)); - v4f32 _tmp2m = __msa_fmadd_w(__msa_fmadd_w(_tmp024a, _v4, _tmp024b), _v8, _tmp024c); - v4f32 _tmp4m = __msa_fmadd_w(__msa_fmadd_w(_tmp024a, _v16, _tmp024b), _v2, _tmp024c); - __msa_st_w((v4i32)_tmp0m, tmp[0][m], 0); - __msa_st_w((v4i32)_tmp2m, tmp[2][m], 0); - __msa_st_w((v4i32)_tmp4m, tmp[4][m], 0); - - v4f32 _tmp1m = __msa_fmadd_w(__msa_fmadd_w(_tmp135a, _v2, _tmp135b), _v16, _tmp135c); - v4f32 _tmp3m = __msa_fmadd_w(__msa_fmadd_w(_tmp135a, _v8, _tmp135b), _v4, _tmp135c); - v4f32 _tmp5m = __msa_fadd_w(__msa_fadd_w(_out0tm7, _tmp135a), __msa_fmadd_w(_tmp135c, _v32, _tmp135b)); - __msa_st_w((v4i32)_tmp1m, tmp[1][m], 0); - __msa_st_w((v4i32)_tmp3m, 
tmp[3][m], 0); - __msa_st_w((v4i32)_tmp5m, tmp[5][m], 0); - - output0_tm_0 += tiles * 4 * 8; - output0_tm_1 += tiles * 4 * 8; - output0_tm_2 += tiles * 4 * 8; - output0_tm_3 += tiles * 4 * 8; - output0_tm_4 += tiles * 4 * 8; - output0_tm_5 += tiles * 4 * 8; - output0_tm_6 += tiles * 4 * 8; - output0_tm_7 += tiles * 4 * 8; - } - - for (int m = 0; m < 6; m++) - { - v4f32 _tmp00 = (v4f32)__msa_ld_w(tmp[m][0], 0); - v4f32 _tmp01 = (v4f32)__msa_ld_w(tmp[m][1], 0); - v4f32 _tmp02 = (v4f32)__msa_ld_w(tmp[m][2], 0); - v4f32 _tmp03 = (v4f32)__msa_ld_w(tmp[m][3], 0); - v4f32 _tmp04 = (v4f32)__msa_ld_w(tmp[m][4], 0); - v4f32 _tmp05 = (v4f32)__msa_ld_w(tmp[m][5], 0); - v4f32 _tmp06 = (v4f32)__msa_ld_w(tmp[m][6], 0); - v4f32 _tmp07 = (v4f32)__msa_ld_w(tmp[m][7], 0); - - v4f32 _tmp024a = __msa_fadd_w(_tmp01, _tmp02); - v4f32 _tmp135a = __msa_fsub_w(_tmp01, _tmp02); - - v4f32 _tmp024b = __msa_fadd_w(_tmp03, _tmp04); - v4f32 _tmp135b = __msa_fsub_w(_tmp03, _tmp04); - - v4f32 _tmp024c = __msa_fadd_w(_tmp05, _tmp06); - v4f32 _tmp135c = __msa_fsub_w(_tmp05, _tmp06); - - v4f32 _out00 = __msa_fadd_w(_bias0, __msa_fadd_w(__msa_fadd_w(_tmp00, _tmp024a), __msa_fmadd_w(_tmp024b, _v32, _tmp024c))); - v4f32 _out02 = __msa_fadd_w(_bias0, __msa_fmadd_w(__msa_fmadd_w(_tmp024a, _v4, _tmp024b), _v8, _tmp024c)); - v4f32 _out04 = __msa_fadd_w(_bias0, __msa_fmadd_w(__msa_fmadd_w(_tmp024a, _v16, _tmp024b), _v2, _tmp024c)); - __msa_st_w((v4i32)_out00, output0, 0); - __msa_st_w((v4i32)_out02, output0 + 4 * 2, 0); - __msa_st_w((v4i32)_out04, output0 + 4 * 4, 0); - - v4f32 _out01 = __msa_fadd_w(_bias0, __msa_fmadd_w(__msa_fmadd_w(_tmp135a, _v2, _tmp135b), _v16, _tmp135c)); - v4f32 _out03 = __msa_fadd_w(_bias0, __msa_fmadd_w(__msa_fmadd_w(_tmp135a, _v8, _tmp135b), _v4, _tmp135c)); - v4f32 _out05 = __msa_fadd_w(_bias0, __msa_fadd_w(__msa_fadd_w(_tmp07, _tmp135a), __msa_fmadd_w(_tmp135c, _v32, _tmp135b))); - __msa_st_w((v4i32)_out01, output0 + 4, 0); - __msa_st_w((v4i32)_out03, output0 + 4 * 3, 0); - __msa_st_w((v4i32)_out05, output0 + 4 * 5, 0); - - output0 += outw * 4; - } - } - } - } + conv3x3s1_winograd64_transform_output_pack4_msa(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output @@ -963,7 +663,7 @@ static void conv3x3s1_winograd42_transform_kernel_pack4_msa(const Mat& kernel, M } } -static void conv3x3s1_winograd42_pack4_msa(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd42_pack4_msa(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -985,121 +685,15 @@ static void conv3x3s1_winograd42_pack4_msa(const Mat& bottom_blob, Mat& top_blob h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const float* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = w_tm / 6 * h_tm / 6; - - bottom_blob_tm.create(tiles, 36, inch, 4u * elempack, elempack, opt.workspace_allocator); + int w_tiles = outw / 4; + int h_tiles = outh / 4; + const int tiles = w_tiles * h_tiles; - // const float itm[4][4] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * 
r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r04 + r03 - // 2 = 4 * (r01 - r02) + r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - float tmp[6][6][4]; - - v4f32 _vm5 = __msa_fill_w_f32(-5.f); - v4f32 _vm4 = __msa_fill_w_f32(-4.f); - v4f32 _v4 = __msa_fill_w_f32(4.f); - v4f32 _vm2 = __msa_fill_w_f32(-2.f); - v4f32 _v2 = __msa_fill_w_f32(2.f); - - // tile - for (int i = 0; i < h_tm / 6; i++) - { - for (int j = 0; j < w_tm / 6; j++) - { - const float* r0 = img0.row(i * 4) + (j * 4) * 4; - - for (int m = 0; m < 6; m++) - { - v4f32 _r00 = (v4f32)__msa_ld_w(r0, 0); - v4f32 _r01 = (v4f32)__msa_ld_w(r0 + 4, 0); - v4f32 _r02 = (v4f32)__msa_ld_w(r0 + 4 * 2, 0); - v4f32 _r03 = (v4f32)__msa_ld_w(r0 + 4 * 3, 0); - v4f32 _r04 = (v4f32)__msa_ld_w(r0 + 4 * 4, 0); - v4f32 _r05 = (v4f32)__msa_ld_w(r0 + 4 * 5, 0); - - v4f32 _tmp0m = __msa_fmadd_w(__msa_fmadd_w(_r04, _v4, _r00), _vm5, _r02); - v4f32 _tmp1m = __msa_fmadd_w(__msa_fadd_w(_r04, _r03), _vm4, __msa_fadd_w(_r01, _r02)); - v4f32 _tmp2m = __msa_fmadd_w(__msa_fsub_w(_r04, _r03), _v4, __msa_fsub_w(_r01, _r02)); - v4f32 _tmp3m = __msa_fmadd_w(__msa_fsub_w(_r04, _r02), _vm2, __msa_fsub_w(_r01, _r03)); - v4f32 _tmp4m = __msa_fmadd_w(__msa_fsub_w(_r04, _r02), _v2, __msa_fsub_w(_r01, _r03)); - v4f32 _tmp5m = __msa_fmadd_w(__msa_fmadd_w(_r05, _v4, _r01), _vm5, _r03); - - __msa_st_w((v4i32)_tmp0m, tmp[0][m], 0); - __msa_st_w((v4i32)_tmp1m, tmp[1][m], 0); - __msa_st_w((v4i32)_tmp2m, tmp[2][m], 0); - __msa_st_w((v4i32)_tmp3m, tmp[3][m], 0); - __msa_st_w((v4i32)_tmp4m, tmp[4][m], 0); - __msa_st_w((v4i32)_tmp5m, tmp[5][m], 0); - - r0 += w * 4; - } - - float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 6 + j) * 4; - float* r0_tm_1 = r0_tm_0 + tiles * 4; - float* r0_tm_2 = r0_tm_0 + tiles * 4 * 2; - float* r0_tm_3 = r0_tm_0 + tiles * 4 * 3; - float* r0_tm_4 = r0_tm_0 + tiles * 4 * 4; - float* r0_tm_5 = r0_tm_0 + tiles * 4 * 5; - - for (int m = 0; m < 6; m++) - { - v4f32 _tmp00 = (v4f32)__msa_ld_w(tmp[m][0], 0); - v4f32 _tmp01 = (v4f32)__msa_ld_w(tmp[m][1], 0); - v4f32 _tmp02 = (v4f32)__msa_ld_w(tmp[m][2], 0); - v4f32 _tmp03 = (v4f32)__msa_ld_w(tmp[m][3], 0); - v4f32 _tmp04 = (v4f32)__msa_ld_w(tmp[m][4], 0); - v4f32 _tmp05 = (v4f32)__msa_ld_w(tmp[m][5], 0); - - v4f32 _r0tm0 = __msa_fmadd_w(__msa_fmadd_w(_tmp04, _v4, _tmp00), _vm5, _tmp02); - v4f32 _r0tm1 = __msa_fmadd_w(__msa_fadd_w(_tmp04, _tmp03), _vm4, __msa_fadd_w(_tmp01, _tmp02)); - v4f32 _r0tm2 = __msa_fmadd_w(__msa_fsub_w(_tmp04, _tmp03), _v4, __msa_fsub_w(_tmp01, _tmp02)); - v4f32 _r0tm3 = __msa_fmadd_w(__msa_fsub_w(_tmp04, _tmp02), _vm2, __msa_fsub_w(_tmp01, _tmp03)); - v4f32 _r0tm4 = __msa_fmadd_w(__msa_fsub_w(_tmp04, _tmp02), _v2, __msa_fsub_w(_tmp01, _tmp03)); - v4f32 _r0tm5 = __msa_fmadd_w(__msa_fmadd_w(_tmp05, _v4, _tmp01), _vm5, _tmp03); - - __msa_st_w((v4i32)_r0tm0, r0_tm_0, 0); - __msa_st_w((v4i32)_r0tm1, r0_tm_1, 0); - __msa_st_w((v4i32)_r0tm2, r0_tm_2, 0); - __msa_st_w((v4i32)_r0tm3, r0_tm_3, 0); - __msa_st_w((v4i32)_r0tm4, r0_tm_4, 0); - __msa_st_w((v4i32)_r0tm5, r0_tm_5, 0); - - r0_tm_0 += tiles * 4 * 6; - r0_tm_1 += tiles * 4 * 6; - r0_tm_2 += tiles * 4 * 6; - r0_tm_3 += tiles * 4 * 6; - r0_tm_4 += tiles * 4 * 6; - r0_tm_5 += tiles * 4 * 6; - } - } - } - } + bottom_blob_tm.create(tiles, 36, inch, elemsize, elempack, 
opt.workspace_allocator); + conv3x3s1_winograd42_transform_input_pack4_msa(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -1553,117 +1147,7 @@ static void conv3x3s1_winograd42_pack4_msa(const Mat& bottom_blob, Mat& top_blob top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator); } { - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - const int tiles = w_tm / 6 * h_tm / 6; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - v4f32 _bias0 = bias ? (v4f32)__msa_ld_w((const float*)bias + p * 4, 0) : (v4f32)__msa_fill_w(0); - - float tmp[4][6][4]; - - v4f32 _v2 = __msa_fill_w_f32(2.f); - v4f32 _v4 = __msa_fill_w_f32(4.f); - v4f32 _v8 = __msa_fill_w_f32(8.f); - - // tile - for (int i = 0; i < outh / 4; i++) - { - for (int j = 0; j < outw / 4; j++) - { - // top_blob_tm.create(tiles, 36, outch, elemsize, elempack); - - const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 6 + j) * 4; - const float* output0_tm_1 = output0_tm_0 + tiles * 4; - const float* output0_tm_2 = output0_tm_0 + tiles * 4 * 2; - const float* output0_tm_3 = output0_tm_0 + tiles * 4 * 3; - const float* output0_tm_4 = output0_tm_0 + tiles * 4 * 4; - const float* output0_tm_5 = output0_tm_0 + tiles * 4 * 5; - - float* output0 = out0.row(i * 4) + (j * 4) * 4; - - // TODO msa optimize - for (int m = 0; m < 6; m++) - { - v4f32 _out0tm0 = (v4f32)__msa_ld_w(output0_tm_0, 0); - v4f32 _out0tm1 = (v4f32)__msa_ld_w(output0_tm_1, 0); - v4f32 _out0tm2 = (v4f32)__msa_ld_w(output0_tm_2, 0); - v4f32 _out0tm3 = (v4f32)__msa_ld_w(output0_tm_3, 0); - v4f32 _out0tm4 = (v4f32)__msa_ld_w(output0_tm_4, 0); - v4f32 _out0tm5 = (v4f32)__msa_ld_w(output0_tm_5, 0); - - v4f32 _tmp02a = __msa_fadd_w(_out0tm1, _out0tm2); - v4f32 _tmp13a = __msa_fsub_w(_out0tm1, _out0tm2); - - v4f32 _tmp02b = __msa_fadd_w(_out0tm3, _out0tm4); - v4f32 _tmp13b = __msa_fsub_w(_out0tm3, _out0tm4); - - v4f32 _tmp0m = __msa_fadd_w(__msa_fadd_w(_out0tm0, _tmp02a), _tmp02b); - v4f32 _tmp1m = __msa_fmadd_w(_tmp13a, _v2, _tmp13b); - v4f32 _tmp2m = __msa_fmadd_w(_tmp02a, _v4, _tmp02b); - v4f32 _tmp3m = __msa_fmadd_w(__msa_fadd_w(_out0tm5, _tmp13a), _v8, _tmp13b); - - __msa_st_w((v4i32)_tmp0m, tmp[0][m], 0); - __msa_st_w((v4i32)_tmp1m, tmp[1][m], 0); - __msa_st_w((v4i32)_tmp2m, tmp[2][m], 0); - __msa_st_w((v4i32)_tmp3m, tmp[3][m], 0); - - output0_tm_0 += tiles * 4 * 6; - output0_tm_1 += tiles * 4 * 6; - output0_tm_2 += tiles * 4 * 6; - output0_tm_3 += tiles * 4 * 6; - output0_tm_4 += tiles * 4 * 6; - output0_tm_5 += tiles * 4 * 6; - } - - for (int m = 0; m < 4; m++) - { - v4f32 _tmp00 = (v4f32)__msa_ld_w(tmp[m][0], 0); - v4f32 _tmp01 = (v4f32)__msa_ld_w(tmp[m][1], 0); - v4f32 _tmp02 = (v4f32)__msa_ld_w(tmp[m][2], 0); - v4f32 _tmp03 = (v4f32)__msa_ld_w(tmp[m][3], 0); - v4f32 _tmp04 = (v4f32)__msa_ld_w(tmp[m][4], 0); - v4f32 _tmp05 = (v4f32)__msa_ld_w(tmp[m][5], 0); - - v4f32 _tmp02a = __msa_fadd_w(_tmp01, _tmp02); - v4f32 _tmp13a = 
__msa_fsub_w(_tmp01, _tmp02); - - v4f32 _tmp02b = __msa_fadd_w(_tmp03, _tmp04); - v4f32 _tmp13b = __msa_fsub_w(_tmp03, _tmp04); - - v4f32 _out00 = __msa_fadd_w(_bias0, __msa_fadd_w(__msa_fadd_w(_tmp00, _tmp02a), _tmp02b)); - v4f32 _out01 = __msa_fadd_w(_bias0, __msa_fmadd_w(_tmp13a, _v2, _tmp13b)); - v4f32 _out02 = __msa_fadd_w(_bias0, __msa_fmadd_w(_tmp02a, _v4, _tmp02b)); - v4f32 _out03 = __msa_fadd_w(_bias0, __msa_fmadd_w(__msa_fadd_w(_tmp05, _tmp13a), _v8, _tmp13b)); - - __msa_st_w((v4i32)_out00, output0, 0); - __msa_st_w((v4i32)_out01, output0 + 4, 0); - __msa_st_w((v4i32)_out02, output0 + 4 * 2, 0); - __msa_st_w((v4i32)_out03, output0 + 4 * 3, 0); - - output0 += outw * 4; - } - } - } - } + conv3x3s1_winograd42_transform_output_pack4_msa(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output diff --git a/src/layer/mips/convolution_mips.cpp b/src/layer/mips/convolution_mips.cpp index a3e012c5793..9411add0f06 100644 --- a/src/layer/mips/convolution_mips.cpp +++ b/src/layer/mips/convolution_mips.cpp @@ -45,6 +45,7 @@ namespace ncnn { #include "convolution_sgemm_pack4.h" #include "convolution_sgemm_pack4to1.h" +#include "convolution_winograd_transform_pack4.h" #include "convolution_1x1_pack4.h" #include "convolution_1x1_pack4to1.h" #include "convolution_3x3_pack4.h" diff --git a/src/layer/mips/convolution_winograd_transform_pack4.h b/src/layer/mips/convolution_winograd_transform_pack4.h new file mode 100644 index 00000000000..78cad25d3c2 --- /dev/null +++ b/src/layer/mips/convolution_winograd_transform_pack4.h @@ -0,0 +1,560 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
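The input transform below applies B^T on both sides of each 8x8 tile in two passes (rows into a tmp buffer, then columns of tmp), using exactly the factored arithmetic spelled out in the itm[8][8] comments. A scalar sketch of one such pass, assuming a hypothetical helper over a plain float[8] row; the MSA code performs the same steps on 4-lane vectors:

// Scalar reference for one row of the winograd64 input transform (B^T side).
// Illustrative only; the header vectorizes this with MSA intrinsics.
static void winograd64_input_transform_row(const float r[8], float t[8])
{
    t[0] = r[0] - r[6] + 5.25f * (r[4] - r[2]);
    t[7] = r[7] - r[1] + 5.25f * (r[3] - r[5]);

    const float a = r[2] + r[6] - 4.25f * r[4]; // even part of t1/t2
    const float b = r[1] + r[5] - 4.25f * r[3]; // odd part of t1/t2
    t[1] = a + b;
    t[2] = a - b;

    const float c = r[6] + 0.25f * r[2] - 1.25f * r[4];
    const float d = 0.5f * r[1] - 2.5f * r[3] + 2.f * r[5];
    t[3] = c + d;
    t[4] = c - d;

    const float e = r[6] + 4.f * (r[2] - 1.25f * r[4]);
    const float f = 2.f * r[1] - 2.5f * r[3] + 0.5f * r[5];
    t[5] = e + f;
    t[6] = e - f;
}

Running this over the eight rows and then over the eight columns of tmp yields the full B^T d B product; sharing each a/b, c/d, e/f pair between two outputs keeps every step a cheap add or fused multiply-add, which is the point of the factored form.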
+ +static void conv3x3s1_winograd64_transform_input_pack4_msa(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) +{ + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + + const int w_tiles = (w - 2) / 6; + const int h_tiles = (h - 2) / 6; + const int tiles = w_tiles * h_tiles; + + // const float itm[8][8] = { + // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, + // + // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, + // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, + // + // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, + // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, + // + // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, + // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, + // + // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} + // }; + + // 0 = r00 - r06 + (r04 - r02) * 5.25 + // 7 = r07 - r01 + (r03 - r05) * 5.25 + + // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) + // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) + + // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) + // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) + + // reuse r04 * 1.25 + // reuse r03 * 2.5 + // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) + // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < inch; q++) + { + const Mat img0 = bottom_blob.channel(q); + Mat img0_tm = bottom_blob_tm.channel(q); + + float tmp[8][8][4]; + + v4f32 _v5_25 = __msa_fill_w_f32(5.25f); + v4f32 _vm4_25 = __msa_fill_w_f32(-4.25f); + v4f32 _vm1_25 = __msa_fill_w_f32(-1.25f); + v4f32 _v0_25 = __msa_fill_w_f32(0.25f); + v4f32 _vm2_5 = __msa_fill_w_f32(-2.5f); + v4f32 _v0_5 = __msa_fill_w_f32(0.5f); + v4f32 _v2 = __msa_fill_w_f32(2.f); + v4f32 _v4 = __msa_fill_w_f32(4.f); + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* r0 = img0.row(i * 6) + (j * 6) * 4; + + for (int m = 0; m < 8; m++) + { + v4f32 _r00 = (v4f32)__msa_ld_w(r0, 0); + v4f32 _r01 = (v4f32)__msa_ld_w(r0 + 4, 0); + v4f32 _r02 = (v4f32)__msa_ld_w(r0 + 4 * 2, 0); + v4f32 _r03 = (v4f32)__msa_ld_w(r0 + 4 * 3, 0); + v4f32 _r04 = (v4f32)__msa_ld_w(r0 + 4 * 4, 0); + v4f32 _r05 = (v4f32)__msa_ld_w(r0 + 4 * 5, 0); + v4f32 _r06 = (v4f32)__msa_ld_w(r0 + 4 * 6, 0); + v4f32 _r07 = (v4f32)__msa_ld_w(r0 + 4 * 7, 0); + + v4f32 _tmp0m = __msa_fmadd_w(__msa_fsub_w(_r00, _r06), _v5_25, __msa_fsub_w(_r04, _r02)); + v4f32 _tmp7m = __msa_fmadd_w(__msa_fsub_w(_r07, _r01), _v5_25, __msa_fsub_w(_r03, _r05)); + __msa_st_w((v4i32)_tmp0m, tmp[0][m], 0); + __msa_st_w((v4i32)_tmp7m, tmp[7][m], 0); + + v4f32 _tmp12a = __msa_fmadd_w(__msa_fadd_w(_r02, _r06), _vm4_25, _r04); + v4f32 _tmp12b = __msa_fmadd_w(__msa_fadd_w(_r01, _r05), _vm4_25, _r03); + + v4f32 _tmp1m = __msa_fadd_w(_tmp12a, _tmp12b); + v4f32 _tmp2m = __msa_fsub_w(_tmp12a, _tmp12b); + __msa_st_w((v4i32)_tmp1m, tmp[1][m], 0); + __msa_st_w((v4i32)_tmp2m, tmp[2][m], 0); + + v4f32 _tmp34a = __msa_fmadd_w(__msa_fmadd_w(_r06, _v0_25, _r02), _vm1_25, _r04); + v4f32 _tmp34b = __msa_fmadd_w(__msa_fmadd_w(__msa_fmul_w(_r01, _v0_5), _vm2_5, _r03), _v2, _r05); + + v4f32 _tmp3m = __msa_fadd_w(_tmp34a, _tmp34b); + v4f32 _tmp4m = __msa_fsub_w(_tmp34a, _tmp34b); + __msa_st_w((v4i32)_tmp3m, tmp[3][m], 0); + __msa_st_w((v4i32)_tmp4m, tmp[4][m], 0); + + v4f32 _tmp56a = 
__msa_fmadd_w(_r06, _v4, __msa_fmadd_w(_r02, _vm1_25, _r04)); + v4f32 _tmp56b = __msa_fmadd_w(__msa_fmadd_w(__msa_fmul_w(_r01, _v2), _vm2_5, _r03), _v0_5, _r05); + + v4f32 _tmp5m = __msa_fadd_w(_tmp56a, _tmp56b); + v4f32 _tmp6m = __msa_fsub_w(_tmp56a, _tmp56b); + __msa_st_w((v4i32)_tmp5m, tmp[5][m], 0); + __msa_st_w((v4i32)_tmp6m, tmp[6][m], 0); + + r0 += w * 4; + } + + float* r0_tm_0 = (float*)img0_tm + (i * w_tiles + j) * 4; + float* r0_tm_1 = r0_tm_0 + tiles * 4; + float* r0_tm_2 = r0_tm_0 + tiles * 4 * 2; + float* r0_tm_3 = r0_tm_0 + tiles * 4 * 3; + float* r0_tm_4 = r0_tm_0 + tiles * 4 * 4; + float* r0_tm_5 = r0_tm_0 + tiles * 4 * 5; + float* r0_tm_6 = r0_tm_0 + tiles * 4 * 6; + float* r0_tm_7 = r0_tm_0 + tiles * 4 * 7; + + for (int m = 0; m < 8; m++) + { + v4f32 _tmp00 = (v4f32)__msa_ld_w(tmp[m][0], 0); + v4f32 _tmp01 = (v4f32)__msa_ld_w(tmp[m][1], 0); + v4f32 _tmp02 = (v4f32)__msa_ld_w(tmp[m][2], 0); + v4f32 _tmp03 = (v4f32)__msa_ld_w(tmp[m][3], 0); + v4f32 _tmp04 = (v4f32)__msa_ld_w(tmp[m][4], 0); + v4f32 _tmp05 = (v4f32)__msa_ld_w(tmp[m][5], 0); + v4f32 _tmp06 = (v4f32)__msa_ld_w(tmp[m][6], 0); + v4f32 _tmp07 = (v4f32)__msa_ld_w(tmp[m][7], 0); + + v4f32 _r0tm0 = __msa_fmadd_w(__msa_fsub_w(_tmp00, _tmp06), _v5_25, __msa_fsub_w(_tmp04, _tmp02)); + v4f32 _r0tm7 = __msa_fmadd_w(__msa_fsub_w(_tmp07, _tmp01), _v5_25, __msa_fsub_w(_tmp03, _tmp05)); + + v4f32 _tmp12a = __msa_fmadd_w(__msa_fadd_w(_tmp02, _tmp06), _vm4_25, _tmp04); + v4f32 _tmp12b = __msa_fmadd_w(__msa_fadd_w(_tmp01, _tmp05), _vm4_25, _tmp03); + + v4f32 _r0tm1 = __msa_fadd_w(_tmp12a, _tmp12b); + v4f32 _r0tm2 = __msa_fsub_w(_tmp12a, _tmp12b); + + v4f32 _tmp34a = __msa_fmadd_w(__msa_fmadd_w(_tmp06, _v0_25, _tmp02), _vm1_25, _tmp04); + v4f32 _tmp34b = __msa_fmadd_w(__msa_fmadd_w(__msa_fmul_w(_tmp01, _v0_5), _vm2_5, _tmp03), _v2, _tmp05); + + v4f32 _r0tm3 = __msa_fadd_w(_tmp34a, _tmp34b); + v4f32 _r0tm4 = __msa_fsub_w(_tmp34a, _tmp34b); + + v4f32 _tmp56a = __msa_fmadd_w(_tmp06, _v4, __msa_fmadd_w(_tmp02, _vm1_25, _tmp04)); + v4f32 _tmp56b = __msa_fmadd_w(__msa_fmadd_w(__msa_fmul_w(_tmp01, _v2), _vm2_5, _tmp03), _v0_5, _tmp05); + + v4f32 _r0tm5 = __msa_fadd_w(_tmp56a, _tmp56b); + v4f32 _r0tm6 = __msa_fsub_w(_tmp56a, _tmp56b); + + __msa_st_w((v4i32)_r0tm0, r0_tm_0, 0); + __msa_st_w((v4i32)_r0tm1, r0_tm_1, 0); + __msa_st_w((v4i32)_r0tm2, r0_tm_2, 0); + __msa_st_w((v4i32)_r0tm3, r0_tm_3, 0); + __msa_st_w((v4i32)_r0tm4, r0_tm_4, 0); + __msa_st_w((v4i32)_r0tm5, r0_tm_5, 0); + __msa_st_w((v4i32)_r0tm6, r0_tm_6, 0); + __msa_st_w((v4i32)_r0tm7, r0_tm_7, 0); + + r0_tm_0 += tiles * 4 * 8; + r0_tm_1 += tiles * 4 * 8; + r0_tm_2 += tiles * 4 * 8; + r0_tm_3 += tiles * 4 * 8; + r0_tm_4 += tiles * 4 * 8; + r0_tm_5 += tiles * 4 * 8; + r0_tm_6 += tiles * 4 * 8; + r0_tm_7 += tiles * 4 * 8; + } + } + } + } +} + +static void conv3x3s1_winograd64_transform_output_pack4_msa(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 6; + const int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; + + const float* biasptr = bias; + + // const float otm[6][8] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} + 
// }; + + // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 + // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 + // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 + // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 + // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 + // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + v4f32 _bias0 = biasptr ? (v4f32)__msa_ld_w(biasptr + p * 4, 0) : (v4f32)__msa_fill_w(0); + + float tmp[6][8][4]; + + v4f32 _v32 = __msa_fill_w_f32(32.f); + v4f32 _v16 = __msa_fill_w_f32(16.f); + v4f32 _v8 = __msa_fill_w_f32(8.f); + v4f32 _v4 = __msa_fill_w_f32(4.f); + v4f32 _v2 = __msa_fill_w_f32(2.f); + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* output0_tm_0 = (const float*)out0_tm + (i * w_tiles + j) * 4; + const float* output0_tm_1 = output0_tm_0 + tiles * 4; + const float* output0_tm_2 = output0_tm_0 + tiles * 4 * 2; + const float* output0_tm_3 = output0_tm_0 + tiles * 4 * 3; + const float* output0_tm_4 = output0_tm_0 + tiles * 4 * 4; + const float* output0_tm_5 = output0_tm_0 + tiles * 4 * 5; + const float* output0_tm_6 = output0_tm_0 + tiles * 4 * 6; + const float* output0_tm_7 = output0_tm_0 + tiles * 4 * 7; + + float* output0 = out0.row(i * 6) + (j * 6) * 4; + + for (int m = 0; m < 8; m++) + { + v4f32 _out0tm0 = (v4f32)__msa_ld_w(output0_tm_0, 0); + v4f32 _out0tm1 = (v4f32)__msa_ld_w(output0_tm_1, 0); + v4f32 _out0tm2 = (v4f32)__msa_ld_w(output0_tm_2, 0); + v4f32 _out0tm3 = (v4f32)__msa_ld_w(output0_tm_3, 0); + v4f32 _out0tm4 = (v4f32)__msa_ld_w(output0_tm_4, 0); + v4f32 _out0tm5 = (v4f32)__msa_ld_w(output0_tm_5, 0); + v4f32 _out0tm6 = (v4f32)__msa_ld_w(output0_tm_6, 0); + v4f32 _out0tm7 = (v4f32)__msa_ld_w(output0_tm_7, 0); + + v4f32 _tmp024a = __msa_fadd_w(_out0tm1, _out0tm2); + v4f32 _tmp135a = __msa_fsub_w(_out0tm1, _out0tm2); + + v4f32 _tmp024b = __msa_fadd_w(_out0tm3, _out0tm4); + v4f32 _tmp135b = __msa_fsub_w(_out0tm3, _out0tm4); + + v4f32 _tmp024c = __msa_fadd_w(_out0tm5, _out0tm6); + v4f32 _tmp135c = __msa_fsub_w(_out0tm5, _out0tm6); + + v4f32 _tmp0m = __msa_fadd_w(__msa_fadd_w(_out0tm0, _tmp024a), __msa_fmadd_w(_tmp024b, _v32, _tmp024c)); + v4f32 _tmp2m = __msa_fmadd_w(__msa_fmadd_w(_tmp024a, _v4, _tmp024b), _v8, _tmp024c); + v4f32 _tmp4m = __msa_fmadd_w(__msa_fmadd_w(_tmp024a, _v16, _tmp024b), _v2, _tmp024c); + __msa_st_w((v4i32)_tmp0m, tmp[0][m], 0); + __msa_st_w((v4i32)_tmp2m, tmp[2][m], 0); + __msa_st_w((v4i32)_tmp4m, tmp[4][m], 0); + + v4f32 _tmp1m = __msa_fmadd_w(__msa_fmadd_w(_tmp135a, _v2, _tmp135b), _v16, _tmp135c); + v4f32 _tmp3m = __msa_fmadd_w(__msa_fmadd_w(_tmp135a, _v8, _tmp135b), _v4, _tmp135c); + v4f32 _tmp5m = __msa_fadd_w(__msa_fadd_w(_out0tm7, _tmp135a), __msa_fmadd_w(_tmp135c, _v32, _tmp135b)); + __msa_st_w((v4i32)_tmp1m, tmp[1][m], 0); + __msa_st_w((v4i32)_tmp3m, tmp[3][m], 0); + __msa_st_w((v4i32)_tmp5m, tmp[5][m], 0); + + output0_tm_0 += tiles * 4 * 8; + output0_tm_1 += tiles * 4 * 8; + output0_tm_2 += tiles * 4 * 8; + output0_tm_3 += tiles * 4 * 8; + output0_tm_4 += tiles * 4 * 8; + output0_tm_5 += tiles * 4 * 8; + output0_tm_6 += tiles * 4 * 8; + output0_tm_7 += tiles * 4 * 8; + } + + for (int m = 0; m < 6; m++) + { + v4f32 _tmp00 = (v4f32)__msa_ld_w(tmp[m][0], 0); + v4f32 _tmp01 = (v4f32)__msa_ld_w(tmp[m][1], 0); + v4f32 _tmp02 = (v4f32)__msa_ld_w(tmp[m][2], 0); + v4f32 _tmp03 = 
(v4f32)__msa_ld_w(tmp[m][3], 0); + v4f32 _tmp04 = (v4f32)__msa_ld_w(tmp[m][4], 0); + v4f32 _tmp05 = (v4f32)__msa_ld_w(tmp[m][5], 0); + v4f32 _tmp06 = (v4f32)__msa_ld_w(tmp[m][6], 0); + v4f32 _tmp07 = (v4f32)__msa_ld_w(tmp[m][7], 0); + + v4f32 _tmp024a = __msa_fadd_w(_tmp01, _tmp02); + v4f32 _tmp135a = __msa_fsub_w(_tmp01, _tmp02); + + v4f32 _tmp024b = __msa_fadd_w(_tmp03, _tmp04); + v4f32 _tmp135b = __msa_fsub_w(_tmp03, _tmp04); + + v4f32 _tmp024c = __msa_fadd_w(_tmp05, _tmp06); + v4f32 _tmp135c = __msa_fsub_w(_tmp05, _tmp06); + + v4f32 _out00 = __msa_fadd_w(_bias0, __msa_fadd_w(__msa_fadd_w(_tmp00, _tmp024a), __msa_fmadd_w(_tmp024b, _v32, _tmp024c))); + v4f32 _out02 = __msa_fadd_w(_bias0, __msa_fmadd_w(__msa_fmadd_w(_tmp024a, _v4, _tmp024b), _v8, _tmp024c)); + v4f32 _out04 = __msa_fadd_w(_bias0, __msa_fmadd_w(__msa_fmadd_w(_tmp024a, _v16, _tmp024b), _v2, _tmp024c)); + __msa_st_w((v4i32)_out00, output0, 0); + __msa_st_w((v4i32)_out02, output0 + 4 * 2, 0); + __msa_st_w((v4i32)_out04, output0 + 4 * 4, 0); + + v4f32 _out01 = __msa_fadd_w(_bias0, __msa_fmadd_w(__msa_fmadd_w(_tmp135a, _v2, _tmp135b), _v16, _tmp135c)); + v4f32 _out03 = __msa_fadd_w(_bias0, __msa_fmadd_w(__msa_fmadd_w(_tmp135a, _v8, _tmp135b), _v4, _tmp135c)); + v4f32 _out05 = __msa_fadd_w(_bias0, __msa_fadd_w(__msa_fadd_w(_tmp07, _tmp135a), __msa_fmadd_w(_tmp135c, _v32, _tmp135b))); + __msa_st_w((v4i32)_out01, output0 + 4, 0); + __msa_st_w((v4i32)_out03, output0 + 4 * 3, 0); + __msa_st_w((v4i32)_out05, output0 + 4 * 5, 0); + + output0 += outw * 4; + } + } + } + } +} + +static void conv3x3s1_winograd42_transform_input_pack4_msa(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) +{ + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + + const int w_tiles = (w - 2) / 4; + const int h_tiles = (h - 2) / 4; + const int tiles = w_tiles * h_tiles; + + // const float itm[6][6] = { + // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, + // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, + // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, + // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, + // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, + // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} + // }; + + // 0 = 4 * r00 - 5 * r02 + r04 + // 1 = -4 * (r01 + r02) + r04 + r03 + // 2 = 4 * (r01 - r02) + r04 - r03 + // 3 = -2 * (r01 - r03) + r04 - r02 + // 4 = 2 * (r01 - r03) + r04 - r02 + // 5 = 4 * r01 - 5 * r03 + r05 + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < inch; q++) + { + const Mat img0 = bottom_blob.channel(q); + Mat img0_tm = bottom_blob_tm.channel(q); + + float tmp[6][6][4]; + + v4f32 _vm5 = __msa_fill_w_f32(-5.f); + v4f32 _vm4 = __msa_fill_w_f32(-4.f); + v4f32 _v4 = __msa_fill_w_f32(4.f); + v4f32 _vm2 = __msa_fill_w_f32(-2.f); + v4f32 _v2 = __msa_fill_w_f32(2.f); + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* r0 = img0.row(i * 4) + (j * 4) * 4; + + for (int m = 0; m < 6; m++) + { + v4f32 _r00 = (v4f32)__msa_ld_w(r0, 0); + v4f32 _r01 = (v4f32)__msa_ld_w(r0 + 4, 0); + v4f32 _r02 = (v4f32)__msa_ld_w(r0 + 4 * 2, 0); + v4f32 _r03 = (v4f32)__msa_ld_w(r0 + 4 * 3, 0); + v4f32 _r04 = (v4f32)__msa_ld_w(r0 + 4 * 4, 0); + v4f32 _r05 = (v4f32)__msa_ld_w(r0 + 4 * 5, 0); + + v4f32 _tmp0m = __msa_fmadd_w(__msa_fmadd_w(_r04, _v4, _r00), _vm5, _r02); + v4f32 _tmp1m = __msa_fmadd_w(__msa_fadd_w(_r04, _r03), _vm4, __msa_fadd_w(_r01, _r02)); + v4f32 _tmp2m = __msa_fmadd_w(__msa_fsub_w(_r04, _r03), _v4, __msa_fsub_w(_r01, _r02)); + v4f32 _tmp3m = 
__msa_fmadd_w(__msa_fsub_w(_r04, _r02), _vm2, __msa_fsub_w(_r01, _r03)); + v4f32 _tmp4m = __msa_fmadd_w(__msa_fsub_w(_r04, _r02), _v2, __msa_fsub_w(_r01, _r03)); + v4f32 _tmp5m = __msa_fmadd_w(__msa_fmadd_w(_r05, _v4, _r01), _vm5, _r03); + + __msa_st_w((v4i32)_tmp0m, tmp[0][m], 0); + __msa_st_w((v4i32)_tmp1m, tmp[1][m], 0); + __msa_st_w((v4i32)_tmp2m, tmp[2][m], 0); + __msa_st_w((v4i32)_tmp3m, tmp[3][m], 0); + __msa_st_w((v4i32)_tmp4m, tmp[4][m], 0); + __msa_st_w((v4i32)_tmp5m, tmp[5][m], 0); + + r0 += w * 4; + } + + float* r0_tm_0 = (float*)img0_tm + (i * w_tiles + j) * 4; + float* r0_tm_1 = r0_tm_0 + tiles * 4; + float* r0_tm_2 = r0_tm_0 + tiles * 4 * 2; + float* r0_tm_3 = r0_tm_0 + tiles * 4 * 3; + float* r0_tm_4 = r0_tm_0 + tiles * 4 * 4; + float* r0_tm_5 = r0_tm_0 + tiles * 4 * 5; + + for (int m = 0; m < 6; m++) + { + v4f32 _tmp00 = (v4f32)__msa_ld_w(tmp[m][0], 0); + v4f32 _tmp01 = (v4f32)__msa_ld_w(tmp[m][1], 0); + v4f32 _tmp02 = (v4f32)__msa_ld_w(tmp[m][2], 0); + v4f32 _tmp03 = (v4f32)__msa_ld_w(tmp[m][3], 0); + v4f32 _tmp04 = (v4f32)__msa_ld_w(tmp[m][4], 0); + v4f32 _tmp05 = (v4f32)__msa_ld_w(tmp[m][5], 0); + + v4f32 _r0tm0 = __msa_fmadd_w(__msa_fmadd_w(_tmp04, _v4, _tmp00), _vm5, _tmp02); + v4f32 _r0tm1 = __msa_fmadd_w(__msa_fadd_w(_tmp04, _tmp03), _vm4, __msa_fadd_w(_tmp01, _tmp02)); + v4f32 _r0tm2 = __msa_fmadd_w(__msa_fsub_w(_tmp04, _tmp03), _v4, __msa_fsub_w(_tmp01, _tmp02)); + v4f32 _r0tm3 = __msa_fmadd_w(__msa_fsub_w(_tmp04, _tmp02), _vm2, __msa_fsub_w(_tmp01, _tmp03)); + v4f32 _r0tm4 = __msa_fmadd_w(__msa_fsub_w(_tmp04, _tmp02), _v2, __msa_fsub_w(_tmp01, _tmp03)); + v4f32 _r0tm5 = __msa_fmadd_w(__msa_fmadd_w(_tmp05, _v4, _tmp01), _vm5, _tmp03); + + __msa_st_w((v4i32)_r0tm0, r0_tm_0, 0); + __msa_st_w((v4i32)_r0tm1, r0_tm_1, 0); + __msa_st_w((v4i32)_r0tm2, r0_tm_2, 0); + __msa_st_w((v4i32)_r0tm3, r0_tm_3, 0); + __msa_st_w((v4i32)_r0tm4, r0_tm_4, 0); + __msa_st_w((v4i32)_r0tm5, r0_tm_5, 0); + + r0_tm_0 += tiles * 4 * 6; + r0_tm_1 += tiles * 4 * 6; + r0_tm_2 += tiles * 4 * 6; + r0_tm_3 += tiles * 4 * 6; + r0_tm_4 += tiles * 4 * 6; + r0_tm_5 += tiles * 4 * 6; + } + } + } + } +} + +static void conv3x3s1_winograd42_transform_output_pack4_msa(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 4; + const int h_tiles = outh / 4; + const int tiles = w_tiles * h_tiles; + + const float* biasptr = bias; + + // const float otm[4][6] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} + // }; + + // 0 = r00 + (r01 + r02) + (r03 + r04) + // 1 = (r01 - r02) + (r03 - r04) * 2 + // 2 = (r01 + r02) + (r03 + r04) * 4 + // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + v4f32 _bias0 = biasptr ? 
(v4f32)__msa_ld_w(biasptr + p * 4, 0) : (v4f32)__msa_fill_w(0); + + float tmp[4][6][4]; + + v4f32 _v2 = __msa_fill_w_f32(2.f); + v4f32 _v4 = __msa_fill_w_f32(4.f); + v4f32 _v8 = __msa_fill_w_f32(8.f); + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* output0_tm_0 = (const float*)out0_tm + (i * w_tiles + j) * 4; + const float* output0_tm_1 = output0_tm_0 + tiles * 4; + const float* output0_tm_2 = output0_tm_0 + tiles * 4 * 2; + const float* output0_tm_3 = output0_tm_0 + tiles * 4 * 3; + const float* output0_tm_4 = output0_tm_0 + tiles * 4 * 4; + const float* output0_tm_5 = output0_tm_0 + tiles * 4 * 5; + + float* output0 = out0.row(i * 4) + (j * 4) * 4; + + for (int m = 0; m < 6; m++) + { + v4f32 _out0tm0 = (v4f32)__msa_ld_w(output0_tm_0, 0); + v4f32 _out0tm1 = (v4f32)__msa_ld_w(output0_tm_1, 0); + v4f32 _out0tm2 = (v4f32)__msa_ld_w(output0_tm_2, 0); + v4f32 _out0tm3 = (v4f32)__msa_ld_w(output0_tm_3, 0); + v4f32 _out0tm4 = (v4f32)__msa_ld_w(output0_tm_4, 0); + v4f32 _out0tm5 = (v4f32)__msa_ld_w(output0_tm_5, 0); + + v4f32 _tmp02a = __msa_fadd_w(_out0tm1, _out0tm2); + v4f32 _tmp13a = __msa_fsub_w(_out0tm1, _out0tm2); + + v4f32 _tmp02b = __msa_fadd_w(_out0tm3, _out0tm4); + v4f32 _tmp13b = __msa_fsub_w(_out0tm3, _out0tm4); + + v4f32 _tmp0m = __msa_fadd_w(__msa_fadd_w(_out0tm0, _tmp02a), _tmp02b); + v4f32 _tmp1m = __msa_fmadd_w(_tmp13a, _v2, _tmp13b); + v4f32 _tmp2m = __msa_fmadd_w(_tmp02a, _v4, _tmp02b); + v4f32 _tmp3m = __msa_fmadd_w(__msa_fadd_w(_out0tm5, _tmp13a), _v8, _tmp13b); + + __msa_st_w((v4i32)_tmp0m, tmp[0][m], 0); + __msa_st_w((v4i32)_tmp1m, tmp[1][m], 0); + __msa_st_w((v4i32)_tmp2m, tmp[2][m], 0); + __msa_st_w((v4i32)_tmp3m, tmp[3][m], 0); + + output0_tm_0 += tiles * 4 * 6; + output0_tm_1 += tiles * 4 * 6; + output0_tm_2 += tiles * 4 * 6; + output0_tm_3 += tiles * 4 * 6; + output0_tm_4 += tiles * 4 * 6; + output0_tm_5 += tiles * 4 * 6; + } + + for (int m = 0; m < 4; m++) + { + v4f32 _tmp00 = (v4f32)__msa_ld_w(tmp[m][0], 0); + v4f32 _tmp01 = (v4f32)__msa_ld_w(tmp[m][1], 0); + v4f32 _tmp02 = (v4f32)__msa_ld_w(tmp[m][2], 0); + v4f32 _tmp03 = (v4f32)__msa_ld_w(tmp[m][3], 0); + v4f32 _tmp04 = (v4f32)__msa_ld_w(tmp[m][4], 0); + v4f32 _tmp05 = (v4f32)__msa_ld_w(tmp[m][5], 0); + + v4f32 _tmp02a = __msa_fadd_w(_tmp01, _tmp02); + v4f32 _tmp13a = __msa_fsub_w(_tmp01, _tmp02); + + v4f32 _tmp02b = __msa_fadd_w(_tmp03, _tmp04); + v4f32 _tmp13b = __msa_fsub_w(_tmp03, _tmp04); + + v4f32 _out00 = __msa_fadd_w(_bias0, __msa_fadd_w(__msa_fadd_w(_tmp00, _tmp02a), _tmp02b)); + v4f32 _out01 = __msa_fadd_w(_bias0, __msa_fmadd_w(_tmp13a, _v2, _tmp13b)); + v4f32 _out02 = __msa_fadd_w(_bias0, __msa_fmadd_w(_tmp02a, _v4, _tmp02b)); + v4f32 _out03 = __msa_fadd_w(_bias0, __msa_fmadd_w(__msa_fadd_w(_tmp05, _tmp13a), _v8, _tmp13b)); + + __msa_st_w((v4i32)_out00, output0, 0); + __msa_st_w((v4i32)_out01, output0 + 4, 0); + __msa_st_w((v4i32)_out02, output0 + 4 * 2, 0); + __msa_st_w((v4i32)_out03, output0 + 4 * 3, 0); + + output0 += outw * 4; + } + } + } + } +} diff --git a/src/layer/riscv/convolution_3x3_packn.h b/src/layer/riscv/convolution_3x3_packn.h index 534d4622a4e..e4e737ad9ed 100644 --- a/src/layer/riscv/convolution_3x3_packn.h +++ b/src/layer/riscv/convolution_3x3_packn.h @@ -95,7 +95,7 @@ static void conv3x3s1_winograd64_transform_kernel_packn_rvv(const Mat& kernel, M } } -static void conv3x3s1_winograd64_packn_rvv(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void 
conv3x3s1_winograd64_packn_rvv(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { const int packn = csrr_vlenb() / 4; const word_type vl = vsetvl_e32m1(packn); @@ -120,169 +120,15 @@ static void conv3x3s1_winograd64_packn_rvv(const Mat& bottom_blob, Mat& top_blob h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const float* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - - const int tiles = w_tm / 8 * h_tm / 8; + int w_tiles = outw / 6; + int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; - // bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); - bottom_blob_tm.create(tiles, 64, inch, 4u * elempack, elempack, opt.workspace_allocator); - - // const float itm[8][8] = { - // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, - // - // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, - // - // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, - // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, - // - // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, - // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, - // - // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} - // }; - - // 0 = r00 - r06 + (r04 - r02) * 5.25 - // 7 = r07 - r01 + (r03 - r05) * 5.25 - - // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) - // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) - - // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) - // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) - - // reuse r04 * 1.25 - // reuse r03 * 2.5 - // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) - // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - // NOTE c99 variable length array - float tmp[8][8][packn]; - - // tile - for (int i = 0; i < h_tm / 8; i++) - { - for (int j = 0; j < w_tm / 8; j++) - { - const float* r0 = img0.row(i * 6) + (j * 6) * packn; - - for (int m = 0; m < 8; m++) - { - vfloat32m1_t _r00 = vle32_v_f32m1(r0, vl); - vfloat32m1_t _r01 = vle32_v_f32m1(r0 + packn, vl); - vfloat32m1_t _r02 = vle32_v_f32m1(r0 + packn * 2, vl); - vfloat32m1_t _r03 = vle32_v_f32m1(r0 + packn * 3, vl); - vfloat32m1_t _r04 = vle32_v_f32m1(r0 + packn * 4, vl); - vfloat32m1_t _r05 = vle32_v_f32m1(r0 + packn * 5, vl); - vfloat32m1_t _r06 = vle32_v_f32m1(r0 + packn * 6, vl); - vfloat32m1_t _r07 = vle32_v_f32m1(r0 + packn * 7, vl); - - vfloat32m1_t _tmp0m = vfmacc_vf_f32m1(vfsub_vv_f32m1(_r00, _r06, vl), 5.25f, vfsub_vv_f32m1(_r04, _r02, vl), vl); - vfloat32m1_t _tmp7m = vfmacc_vf_f32m1(vfsub_vv_f32m1(_r07, _r01, vl), 5.25f, vfsub_vv_f32m1(_r03, _r05, vl), vl); - vse32_v_f32m1(tmp[0][m], _tmp0m, vl); - vse32_v_f32m1(tmp[7][m], _tmp7m, vl); - - vfloat32m1_t _tmp12a = vfmacc_vf_f32m1(vfadd_vv_f32m1(_r02, _r06, vl), -4.25f, _r04, vl); - vfloat32m1_t _tmp12b = vfmacc_vf_f32m1(vfadd_vv_f32m1(_r01, _r05, vl), -4.25f, _r03, vl); - - vfloat32m1_t _tmp1m = vfadd_vv_f32m1(_tmp12a, _tmp12b, vl); - vfloat32m1_t _tmp2m = vfsub_vv_f32m1(_tmp12a, _tmp12b, vl); - vse32_v_f32m1(tmp[1][m], _tmp1m, vl); - 
vse32_v_f32m1(tmp[2][m], _tmp2m, vl); - - vfloat32m1_t _tmp34a = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_r06, 0.25f, _r02, vl), -1.25f, _r04, vl); - vfloat32m1_t _tmp34b = vfmacc_vf_f32m1(vfmacc_vf_f32m1(vfmul_vf_f32m1(_r01, 0.5f, vl), -2.5f, _r03, vl), 2.f, _r05, vl); - - vfloat32m1_t _tmp3m = vfadd_vv_f32m1(_tmp34a, _tmp34b, vl); - vfloat32m1_t _tmp4m = vfsub_vv_f32m1(_tmp34a, _tmp34b, vl); - vse32_v_f32m1(tmp[3][m], _tmp3m, vl); - vse32_v_f32m1(tmp[4][m], _tmp4m, vl); - - vfloat32m1_t _tmp56a = vfmacc_vf_f32m1(_r06, 4.f, vfmacc_vf_f32m1(_r02, -1.25f, _r04, vl), vl); - vfloat32m1_t _tmp56b = vfmacc_vf_f32m1(vfmacc_vf_f32m1(vfmul_vf_f32m1(_r01, 2.f, vl), -2.5f, _r03, vl), 0.5f, _r05, vl); - - vfloat32m1_t _tmp5m = vfadd_vv_f32m1(_tmp56a, _tmp56b, vl); - vfloat32m1_t _tmp6m = vfsub_vv_f32m1(_tmp56a, _tmp56b, vl); - vse32_v_f32m1(tmp[5][m], _tmp5m, vl); - vse32_v_f32m1(tmp[6][m], _tmp6m, vl); - - r0 += w * packn; - } - - float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 8 + j) * packn; - float* r0_tm_1 = r0_tm_0 + tiles * packn; - float* r0_tm_2 = r0_tm_0 + tiles * packn * 2; - float* r0_tm_3 = r0_tm_0 + tiles * packn * 3; - float* r0_tm_4 = r0_tm_0 + tiles * packn * 4; - float* r0_tm_5 = r0_tm_0 + tiles * packn * 5; - float* r0_tm_6 = r0_tm_0 + tiles * packn * 6; - float* r0_tm_7 = r0_tm_0 + tiles * packn * 7; - - for (int m = 0; m < 8; m++) - { - vfloat32m1_t _tmp00 = vle32_v_f32m1(tmp[m][0], vl); - vfloat32m1_t _tmp01 = vle32_v_f32m1(tmp[m][1], vl); - vfloat32m1_t _tmp02 = vle32_v_f32m1(tmp[m][2], vl); - vfloat32m1_t _tmp03 = vle32_v_f32m1(tmp[m][3], vl); - vfloat32m1_t _tmp04 = vle32_v_f32m1(tmp[m][4], vl); - vfloat32m1_t _tmp05 = vle32_v_f32m1(tmp[m][5], vl); - vfloat32m1_t _tmp06 = vle32_v_f32m1(tmp[m][6], vl); - vfloat32m1_t _tmp07 = vle32_v_f32m1(tmp[m][7], vl); - - vfloat32m1_t _r0tm0 = vfmacc_vf_f32m1(vfsub_vv_f32m1(_tmp00, _tmp06, vl), 5.25f, vfsub_vv_f32m1(_tmp04, _tmp02, vl), vl); - vfloat32m1_t _r0tm7 = vfmacc_vf_f32m1(vfsub_vv_f32m1(_tmp07, _tmp01, vl), 5.25f, vfsub_vv_f32m1(_tmp03, _tmp05, vl), vl); - - vfloat32m1_t _tmp12a = vfmacc_vf_f32m1(vfadd_vv_f32m1(_tmp02, _tmp06, vl), -4.25f, _tmp04, vl); - vfloat32m1_t _tmp12b = vfmacc_vf_f32m1(vfadd_vv_f32m1(_tmp01, _tmp05, vl), -4.25f, _tmp03, vl); - - vfloat32m1_t _r0tm1 = vfadd_vv_f32m1(_tmp12a, _tmp12b, vl); - vfloat32m1_t _r0tm2 = vfsub_vv_f32m1(_tmp12a, _tmp12b, vl); - - vfloat32m1_t _tmp34a = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp06, 0.25f, _tmp02, vl), -1.25f, _tmp04, vl); - vfloat32m1_t _tmp34b = vfmacc_vf_f32m1(vfmacc_vf_f32m1(vfmul_vf_f32m1(_tmp01, 0.5f, vl), -2.5f, _tmp03, vl), 2.f, _tmp05, vl); - - vfloat32m1_t _r0tm3 = vfadd_vv_f32m1(_tmp34a, _tmp34b, vl); - vfloat32m1_t _r0tm4 = vfsub_vv_f32m1(_tmp34a, _tmp34b, vl); - - vfloat32m1_t _tmp56a = vfmacc_vf_f32m1(_tmp06, 4.f, vfmacc_vf_f32m1(_tmp02, -1.25f, _tmp04, vl), vl); - vfloat32m1_t _tmp56b = vfmacc_vf_f32m1(vfmacc_vf_f32m1(vfmul_vf_f32m1(_tmp01, 2.f, vl), -2.5f, _tmp03, vl), 0.5f, _tmp05, vl); - - vfloat32m1_t _r0tm5 = vfadd_vv_f32m1(_tmp56a, _tmp56b, vl); - vfloat32m1_t _r0tm6 = vfsub_vv_f32m1(_tmp56a, _tmp56b, vl); - - vse32_v_f32m1(r0_tm_0, _r0tm0, vl); - vse32_v_f32m1(r0_tm_1, _r0tm1, vl); - vse32_v_f32m1(r0_tm_2, _r0tm2, vl); - vse32_v_f32m1(r0_tm_3, _r0tm3, vl); - vse32_v_f32m1(r0_tm_4, _r0tm4, vl); - vse32_v_f32m1(r0_tm_5, _r0tm5, vl); - vse32_v_f32m1(r0_tm_6, _r0tm6, vl); - vse32_v_f32m1(r0_tm_7, _r0tm7, vl); - - r0_tm_0 += tiles * packn * 8; - r0_tm_1 += tiles * packn * 8; - r0_tm_2 += tiles * packn * 8; - r0_tm_3 += tiles * packn * 8; - r0_tm_4 += tiles * packn * 
8; - r0_tm_5 += tiles * packn * 8; - r0_tm_6 += tiles * packn * 8; - r0_tm_7 += tiles * packn * 8; - } - } - } - } + bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); + conv3x3s1_winograd64_transform_input_packn_rvv(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -602,140 +448,7 @@ static void conv3x3s1_winograd64_packn_rvv(const Mat& bottom_blob, Mat& top_blob top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator); } { - // const float otm[6][8] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} - // }; - - // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 - // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 - // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 - // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 - // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 - // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) - - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - const int tiles = w_tm / 8 * h_tm / 8; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - vfloat32m1_t _bias0 = bias ? vle32_v_f32m1((const float*)bias + p * packn, vl) : vfmv_v_f_f32m1(0.f, vl); - - // NOTE c99 variable length array - float tmp[6][8][packn]; - - // tile - for (int i = 0; i < outh / 6; i++) - { - for (int j = 0; j < outw / 6; j++) - { - // top_blob_tm.create(tiles, 64, outch, elemsize, elempack); - - const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 8 + j) * packn; - const float* output0_tm_1 = output0_tm_0 + tiles * packn; - const float* output0_tm_2 = output0_tm_0 + tiles * packn * 2; - const float* output0_tm_3 = output0_tm_0 + tiles * packn * 3; - const float* output0_tm_4 = output0_tm_0 + tiles * packn * 4; - const float* output0_tm_5 = output0_tm_0 + tiles * packn * 5; - const float* output0_tm_6 = output0_tm_0 + tiles * packn * 6; - const float* output0_tm_7 = output0_tm_0 + tiles * packn * 7; - - float* output0 = out0.row(i * 6) + (j * 6) * packn; - - // TODO rvv optimize - for (int m = 0; m < 8; m++) - { - vfloat32m1_t _out0tm0 = vle32_v_f32m1(output0_tm_0, vl); - vfloat32m1_t _out0tm1 = vle32_v_f32m1(output0_tm_1, vl); - vfloat32m1_t _out0tm2 = vle32_v_f32m1(output0_tm_2, vl); - vfloat32m1_t _out0tm3 = vle32_v_f32m1(output0_tm_3, vl); - vfloat32m1_t _out0tm4 = vle32_v_f32m1(output0_tm_4, vl); - vfloat32m1_t _out0tm5 = vle32_v_f32m1(output0_tm_5, vl); - vfloat32m1_t _out0tm6 = vle32_v_f32m1(output0_tm_6, vl); - vfloat32m1_t _out0tm7 = vle32_v_f32m1(output0_tm_7, vl); - - vfloat32m1_t _tmp024a = vfadd_vv_f32m1(_out0tm1, _out0tm2, vl); - vfloat32m1_t _tmp135a = vfsub_vv_f32m1(_out0tm1, _out0tm2, vl); - - vfloat32m1_t _tmp024b = vfadd_vv_f32m1(_out0tm3, _out0tm4, vl); - vfloat32m1_t _tmp135b = vfsub_vv_f32m1(_out0tm3, _out0tm4, vl); - - vfloat32m1_t _tmp024c = vfadd_vv_f32m1(_out0tm5, _out0tm6, vl); - vfloat32m1_t _tmp135c = vfsub_vv_f32m1(_out0tm5, _out0tm6, vl); - - vfloat32m1_t _tmp0m = vfadd_vv_f32m1(vfadd_vv_f32m1(_out0tm0, _tmp024a, vl), vfmacc_vf_f32m1(_tmp024b, 32.f, _tmp024c, vl), vl); - 
vfloat32m1_t _tmp2m = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp024a, 4.f, _tmp024b, vl), 8.f, _tmp024c, vl); - vfloat32m1_t _tmp4m = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp024a, 16.f, _tmp024b, vl), 2.f, _tmp024c, vl); - vse32_v_f32m1(tmp[0][m], _tmp0m, vl); - vse32_v_f32m1(tmp[2][m], _tmp2m, vl); - vse32_v_f32m1(tmp[4][m], _tmp4m, vl); - - vfloat32m1_t _tmp1m = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp135a, 2.f, _tmp135b, vl), 16.f, _tmp135c, vl); - vfloat32m1_t _tmp3m = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp135a, 8.f, _tmp135b, vl), 4.f, _tmp135c, vl); - vfloat32m1_t _tmp5m = vfadd_vv_f32m1(vfadd_vv_f32m1(_out0tm7, _tmp135a, vl), vfmacc_vf_f32m1(_tmp135c, 32.f, _tmp135b, vl), vl); - vse32_v_f32m1(tmp[1][m], _tmp1m, vl); - vse32_v_f32m1(tmp[3][m], _tmp3m, vl); - vse32_v_f32m1(tmp[5][m], _tmp5m, vl); - - output0_tm_0 += tiles * packn * 8; - output0_tm_1 += tiles * packn * 8; - output0_tm_2 += tiles * packn * 8; - output0_tm_3 += tiles * packn * 8; - output0_tm_4 += tiles * packn * 8; - output0_tm_5 += tiles * packn * 8; - output0_tm_6 += tiles * packn * 8; - output0_tm_7 += tiles * packn * 8; - } - - for (int m = 0; m < 6; m++) - { - vfloat32m1_t _tmp00 = vle32_v_f32m1(tmp[m][0], vl); - vfloat32m1_t _tmp01 = vle32_v_f32m1(tmp[m][1], vl); - vfloat32m1_t _tmp02 = vle32_v_f32m1(tmp[m][2], vl); - vfloat32m1_t _tmp03 = vle32_v_f32m1(tmp[m][3], vl); - vfloat32m1_t _tmp04 = vle32_v_f32m1(tmp[m][4], vl); - vfloat32m1_t _tmp05 = vle32_v_f32m1(tmp[m][5], vl); - vfloat32m1_t _tmp06 = vle32_v_f32m1(tmp[m][6], vl); - vfloat32m1_t _tmp07 = vle32_v_f32m1(tmp[m][7], vl); - - vfloat32m1_t _tmp024a = vfadd_vv_f32m1(_tmp01, _tmp02, vl); - vfloat32m1_t _tmp135a = vfsub_vv_f32m1(_tmp01, _tmp02, vl); - - vfloat32m1_t _tmp024b = vfadd_vv_f32m1(_tmp03, _tmp04, vl); - vfloat32m1_t _tmp135b = vfsub_vv_f32m1(_tmp03, _tmp04, vl); - - vfloat32m1_t _tmp024c = vfadd_vv_f32m1(_tmp05, _tmp06, vl); - vfloat32m1_t _tmp135c = vfsub_vv_f32m1(_tmp05, _tmp06, vl); - - vfloat32m1_t _out00 = vfadd_vv_f32m1(_bias0, vfadd_vv_f32m1(vfadd_vv_f32m1(_tmp00, _tmp024a, vl), vfmacc_vf_f32m1(_tmp024b, 32.f, _tmp024c, vl), vl), vl); - vfloat32m1_t _out02 = vfadd_vv_f32m1(_bias0, vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp024a, 4.f, _tmp024b, vl), 8.f, _tmp024c, vl), vl); - vfloat32m1_t _out04 = vfadd_vv_f32m1(_bias0, vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp024a, 16.f, _tmp024b, vl), 2.f, _tmp024c, vl), vl); - vse32_v_f32m1(output0, _out00, vl); - vse32_v_f32m1(output0 + packn * 2, _out02, vl); - vse32_v_f32m1(output0 + packn * 4, _out04, vl); - - vfloat32m1_t _out01 = vfadd_vv_f32m1(_bias0, vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp135a, 2.f, _tmp135b, vl), 16.f, _tmp135c, vl), vl); - vfloat32m1_t _out03 = vfadd_vv_f32m1(_bias0, vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp135a, 8.f, _tmp135b, vl), 4.f, _tmp135c, vl), vl); - vfloat32m1_t _out05 = vfadd_vv_f32m1(_bias0, vfadd_vv_f32m1(vfadd_vv_f32m1(_tmp07, _tmp135a, vl), vfmacc_vf_f32m1(_tmp135c, 32.f, _tmp135b, vl), vl), vl); - vse32_v_f32m1(output0 + packn, _out01, vl); - vse32_v_f32m1(output0 + packn * 3, _out03, vl); - vse32_v_f32m1(output0 + packn * 5, _out05, vl); - - output0 += outw * packn; - } - } - } - } + conv3x3s1_winograd64_transform_output_packn_rvv(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output @@ -823,7 +536,7 @@ static void conv3x3s1_winograd42_transform_kernel_packn_rvv(const Mat& kernel, M } } -static void conv3x3s1_winograd42_packn_rvv(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd42_packn_rvv(const Mat& 
bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { const int packn = csrr_vlenb() / 4; const word_type vl = vsetvl_e32m1(packn); @@ -848,116 +561,15 @@ static void conv3x3s1_winograd42_packn_rvv(const Mat& bottom_blob, Mat& top_blob h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const float* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = w_tm / 6 * h_tm / 6; - - bottom_blob_tm.create(tiles, 36, inch, 4u * elempack, elempack, opt.workspace_allocator); - - // const float itm[4][4] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r04 + r03 - // 2 = 4 * (r01 - r02) + r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - // NOTE c99 variable length array - float tmp[6][6][packn]; - - // tile - for (int i = 0; i < h_tm / 6; i++) - { - for (int j = 0; j < w_tm / 6; j++) - { - const float* r0 = img0.row(i * 4) + (j * 4) * packn; - - for (int m = 0; m < 6; m++) - { - vfloat32m1_t _r00 = vle32_v_f32m1(r0, vl); - vfloat32m1_t _r01 = vle32_v_f32m1(r0 + packn, vl); - vfloat32m1_t _r02 = vle32_v_f32m1(r0 + packn * 2, vl); - vfloat32m1_t _r03 = vle32_v_f32m1(r0 + packn * 3, vl); - vfloat32m1_t _r04 = vle32_v_f32m1(r0 + packn * 4, vl); - vfloat32m1_t _r05 = vle32_v_f32m1(r0 + packn * 5, vl); - - vfloat32m1_t _tmp0m = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_r04, 4.f, _r00, vl), -5.f, _r02, vl); - vfloat32m1_t _tmp1m = vfmacc_vf_f32m1(vfadd_vv_f32m1(_r04, _r03, vl), -4.f, vfadd_vv_f32m1(_r01, _r02, vl), vl); - vfloat32m1_t _tmp2m = vfmacc_vf_f32m1(vfsub_vv_f32m1(_r04, _r03, vl), 4.f, vfsub_vv_f32m1(_r01, _r02, vl), vl); - vfloat32m1_t _tmp3m = vfmacc_vf_f32m1(vfsub_vv_f32m1(_r04, _r02, vl), -2.f, vfsub_vv_f32m1(_r01, _r03, vl), vl); - vfloat32m1_t _tmp4m = vfmacc_vf_f32m1(vfsub_vv_f32m1(_r04, _r02, vl), 2.f, vfsub_vv_f32m1(_r01, _r03, vl), vl); - vfloat32m1_t _tmp5m = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_r05, 4.f, _r01, vl), -5.f, _r03, vl); - - vse32_v_f32m1(tmp[0][m], _tmp0m, vl); - vse32_v_f32m1(tmp[1][m], _tmp1m, vl); - vse32_v_f32m1(tmp[2][m], _tmp2m, vl); - vse32_v_f32m1(tmp[3][m], _tmp3m, vl); - vse32_v_f32m1(tmp[4][m], _tmp4m, vl); - vse32_v_f32m1(tmp[5][m], _tmp5m, vl); - - r0 += w * packn; - } + int w_tiles = outw / 4; + int h_tiles = outh / 4; + const int tiles = w_tiles * h_tiles; - float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 6 + j) * packn; - float* r0_tm_1 = r0_tm_0 + tiles * packn; - float* r0_tm_2 = r0_tm_0 + tiles * packn * 2; - float* r0_tm_3 = r0_tm_0 + tiles * packn * 3; - float* r0_tm_4 = r0_tm_0 + tiles * packn * 4; - float* r0_tm_5 = r0_tm_0 + tiles * packn * 5; - - for (int m = 0; m < 6; m++) - { - vfloat32m1_t _tmp00 = vle32_v_f32m1(tmp[m][0], vl); - vfloat32m1_t _tmp01 = vle32_v_f32m1(tmp[m][1], vl); - vfloat32m1_t _tmp02 = vle32_v_f32m1(tmp[m][2], vl); - vfloat32m1_t _tmp03 = vle32_v_f32m1(tmp[m][3], vl); - vfloat32m1_t _tmp04 = 
vle32_v_f32m1(tmp[m][4], vl); - vfloat32m1_t _tmp05 = vle32_v_f32m1(tmp[m][5], vl); - - vfloat32m1_t _r0tm0 = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp04, 4.f, _tmp00, vl), -5.f, _tmp02, vl); - vfloat32m1_t _r0tm1 = vfmacc_vf_f32m1(vfadd_vv_f32m1(_tmp04, _tmp03, vl), -4.f, vfadd_vv_f32m1(_tmp01, _tmp02, vl), vl); - vfloat32m1_t _r0tm2 = vfmacc_vf_f32m1(vfsub_vv_f32m1(_tmp04, _tmp03, vl), 4.f, vfsub_vv_f32m1(_tmp01, _tmp02, vl), vl); - vfloat32m1_t _r0tm3 = vfmacc_vf_f32m1(vfsub_vv_f32m1(_tmp04, _tmp02, vl), -2.f, vfsub_vv_f32m1(_tmp01, _tmp03, vl), vl); - vfloat32m1_t _r0tm4 = vfmacc_vf_f32m1(vfsub_vv_f32m1(_tmp04, _tmp02, vl), 2.f, vfsub_vv_f32m1(_tmp01, _tmp03, vl), vl); - vfloat32m1_t _r0tm5 = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp05, 4.f, _tmp01, vl), -5.f, _tmp03, vl); - - vse32_v_f32m1(r0_tm_0, _r0tm0, vl); - vse32_v_f32m1(r0_tm_1, _r0tm1, vl); - vse32_v_f32m1(r0_tm_2, _r0tm2, vl); - vse32_v_f32m1(r0_tm_3, _r0tm3, vl); - vse32_v_f32m1(r0_tm_4, _r0tm4, vl); - vse32_v_f32m1(r0_tm_5, _r0tm5, vl); - - r0_tm_0 += tiles * packn * 6; - r0_tm_1 += tiles * packn * 6; - r0_tm_2 += tiles * packn * 6; - r0_tm_3 += tiles * packn * 6; - r0_tm_4 += tiles * packn * 6; - r0_tm_5 += tiles * packn * 6; - } - } - } - } + bottom_blob_tm.create(tiles, 36, inch, elemsize, elempack, opt.workspace_allocator); + conv3x3s1_winograd42_transform_input_packn_rvv(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -1277,114 +889,7 @@ static void conv3x3s1_winograd42_packn_rvv(const Mat& bottom_blob, Mat& top_blob top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator); } { - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - const int tiles = w_tm / 6 * h_tm / 6; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - vfloat32m1_t _bias0 = bias ? 
vle32_v_f32m1((const float*)bias + p * packn, vl) : vfmv_v_f_f32m1(0.f, vl); - - // NOTE variable length array - float tmp[4][6][packn]; - - // tile - for (int i = 0; i < outh / 4; i++) - { - for (int j = 0; j < outw / 4; j++) - { - // top_blob_tm.create(tiles, 36, outch, elemsize, elempack); - - const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 6 + j) * packn; - const float* output0_tm_1 = output0_tm_0 + tiles * packn; - const float* output0_tm_2 = output0_tm_0 + tiles * packn * 2; - const float* output0_tm_3 = output0_tm_0 + tiles * packn * 3; - const float* output0_tm_4 = output0_tm_0 + tiles * packn * 4; - const float* output0_tm_5 = output0_tm_0 + tiles * packn * 5; - - float* output0 = out0.row(i * 4) + (j * 4) * packn; - - // TODO rvv optimize - for (int m = 0; m < 6; m++) - { - vfloat32m1_t _out0tm0 = vle32_v_f32m1(output0_tm_0, vl); - vfloat32m1_t _out0tm1 = vle32_v_f32m1(output0_tm_1, vl); - vfloat32m1_t _out0tm2 = vle32_v_f32m1(output0_tm_2, vl); - vfloat32m1_t _out0tm3 = vle32_v_f32m1(output0_tm_3, vl); - vfloat32m1_t _out0tm4 = vle32_v_f32m1(output0_tm_4, vl); - vfloat32m1_t _out0tm5 = vle32_v_f32m1(output0_tm_5, vl); - - vfloat32m1_t _tmp02a = vfadd_vv_f32m1(_out0tm1, _out0tm2, vl); - vfloat32m1_t _tmp13a = vfsub_vv_f32m1(_out0tm1, _out0tm2, vl); - - vfloat32m1_t _tmp02b = vfadd_vv_f32m1(_out0tm3, _out0tm4, vl); - vfloat32m1_t _tmp13b = vfsub_vv_f32m1(_out0tm3, _out0tm4, vl); - - vfloat32m1_t _tmp0m = vfadd_vv_f32m1(vfadd_vv_f32m1(_out0tm0, _tmp02a, vl), _tmp02b, vl); - vfloat32m1_t _tmp1m = vfmacc_vf_f32m1(_tmp13a, 2.f, _tmp13b, vl); - vfloat32m1_t _tmp2m = vfmacc_vf_f32m1(_tmp02a, 4.f, _tmp02b, vl); - vfloat32m1_t _tmp3m = vfmacc_vf_f32m1(vfadd_vv_f32m1(_out0tm5, _tmp13a, vl), 8.f, _tmp13b, vl); - - vse32_v_f32m1(tmp[0][m], _tmp0m, vl); - vse32_v_f32m1(tmp[1][m], _tmp1m, vl); - vse32_v_f32m1(tmp[2][m], _tmp2m, vl); - vse32_v_f32m1(tmp[3][m], _tmp3m, vl); - - output0_tm_0 += tiles * packn * 6; - output0_tm_1 += tiles * packn * 6; - output0_tm_2 += tiles * packn * 6; - output0_tm_3 += tiles * packn * 6; - output0_tm_4 += tiles * packn * 6; - output0_tm_5 += tiles * packn * 6; - } - - for (int m = 0; m < 4; m++) - { - vfloat32m1_t _tmp00 = vle32_v_f32m1(tmp[m][0], vl); - vfloat32m1_t _tmp01 = vle32_v_f32m1(tmp[m][1], vl); - vfloat32m1_t _tmp02 = vle32_v_f32m1(tmp[m][2], vl); - vfloat32m1_t _tmp03 = vle32_v_f32m1(tmp[m][3], vl); - vfloat32m1_t _tmp04 = vle32_v_f32m1(tmp[m][4], vl); - vfloat32m1_t _tmp05 = vle32_v_f32m1(tmp[m][5], vl); - - vfloat32m1_t _tmp02a = vfadd_vv_f32m1(_tmp01, _tmp02, vl); - vfloat32m1_t _tmp13a = vfsub_vv_f32m1(_tmp01, _tmp02, vl); - - vfloat32m1_t _tmp02b = vfadd_vv_f32m1(_tmp03, _tmp04, vl); - vfloat32m1_t _tmp13b = vfsub_vv_f32m1(_tmp03, _tmp04, vl); - - vfloat32m1_t _out00 = vfadd_vv_f32m1(_bias0, vfadd_vv_f32m1(vfadd_vv_f32m1(_tmp00, _tmp02a, vl), _tmp02b, vl), vl); - vfloat32m1_t _out01 = vfadd_vv_f32m1(_bias0, vfmacc_vf_f32m1(_tmp13a, 2.f, _tmp13b, vl), vl); - vfloat32m1_t _out02 = vfadd_vv_f32m1(_bias0, vfmacc_vf_f32m1(_tmp02a, 4.f, _tmp02b, vl), vl); - vfloat32m1_t _out03 = vfadd_vv_f32m1(_bias0, vfmacc_vf_f32m1(vfadd_vv_f32m1(_tmp05, _tmp13a, vl), 8.f, _tmp13b, vl), vl); - - vse32_v_f32m1(output0, _out00, vl); - vse32_v_f32m1(output0 + packn, _out01, vl); - vse32_v_f32m1(output0 + packn * 2, _out02, vl); - vse32_v_f32m1(output0 + packn * 3, _out03, vl); - - output0 += outw * packn; - } - } - } - } + conv3x3s1_winograd42_transform_output_packn_rvv(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output diff --git 
a/src/layer/riscv/convolution_3x3_packn_fp16s.h b/src/layer/riscv/convolution_3x3_packn_fp16s.h index 26d814b0d58..6b1e6943548 100644 --- a/src/layer/riscv/convolution_3x3_packn_fp16s.h +++ b/src/layer/riscv/convolution_3x3_packn_fp16s.h @@ -95,7 +95,7 @@ static void conv3x3s1_winograd64_transform_kernel_packn_fp16sa_rvv(const Mat& ke } } -static void conv3x3s1_winograd64_packn_fp16sa_rvv(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd64_packn_fp16sa_rvv(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { const int packn = csrr_vlenb() / 2; const word_type vl = vsetvl_e16m1(packn); @@ -120,169 +120,15 @@ static void conv3x3s1_winograd64_packn_fp16sa_rvv(const Mat& bottom_blob, Mat& t h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const __fp16* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - - const int tiles = w_tm / 8 * h_tm / 8; + int w_tiles = outw / 6; + int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; - // bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); - bottom_blob_tm.create(tiles, 64, inch, 2u * elempack, elempack, opt.workspace_allocator); - - // const float itm[8][8] = { - // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, - // - // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, - // - // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, - // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, - // - // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, - // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, - // - // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} - // }; - - // 0 = r00 - r06 + (r04 - r02) * 5.25 - // 7 = r07 - r01 + (r03 - r05) * 5.25 - - // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) - // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) - - // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) - // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) - - // reuse r04 * 1.25 - // reuse r03 * 2.5 - // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) - // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - // NOTE c99 variable length array - __fp16 tmp[8][8][packn]; - - // tile - for (int i = 0; i < h_tm / 8; i++) - { - for (int j = 0; j < w_tm / 8; j++) - { - const __fp16* r0 = img0.row(i * 6) + (j * 6) * packn; - - for (int m = 0; m < 8; m++) - { - vfloat16m1_t _r00 = vle16_v_f16m1(r0, vl); - vfloat16m1_t _r01 = vle16_v_f16m1(r0 + packn, vl); - vfloat16m1_t _r02 = vle16_v_f16m1(r0 + packn * 2, vl); - vfloat16m1_t _r03 = vle16_v_f16m1(r0 + packn * 3, vl); - vfloat16m1_t _r04 = vle16_v_f16m1(r0 + packn * 4, vl); - vfloat16m1_t _r05 = vle16_v_f16m1(r0 + packn * 5, vl); - vfloat16m1_t _r06 = vle16_v_f16m1(r0 + packn * 6, vl); - vfloat16m1_t _r07 = vle16_v_f16m1(r0 + packn * 7, vl); - - vfloat16m1_t _tmp0m = vfmacc_vf_f16m1(vfsub_vv_f16m1(_r00, _r06, vl), 5.25f, vfsub_vv_f16m1(_r04, _r02, vl), vl); - vfloat16m1_t _tmp7m = 
vfmacc_vf_f16m1(vfsub_vv_f16m1(_r07, _r01, vl), 5.25f, vfsub_vv_f16m1(_r03, _r05, vl), vl); - vse16_v_f16m1(tmp[0][m], _tmp0m, vl); - vse16_v_f16m1(tmp[7][m], _tmp7m, vl); - - vfloat16m1_t _tmp12a = vfmacc_vf_f16m1(vfadd_vv_f16m1(_r02, _r06, vl), -4.25f, _r04, vl); - vfloat16m1_t _tmp12b = vfmacc_vf_f16m1(vfadd_vv_f16m1(_r01, _r05, vl), -4.25f, _r03, vl); - - vfloat16m1_t _tmp1m = vfadd_vv_f16m1(_tmp12a, _tmp12b, vl); - vfloat16m1_t _tmp2m = vfsub_vv_f16m1(_tmp12a, _tmp12b, vl); - vse16_v_f16m1(tmp[1][m], _tmp1m, vl); - vse16_v_f16m1(tmp[2][m], _tmp2m, vl); - - vfloat16m1_t _tmp34a = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_r06, 0.25f, _r02, vl), -1.25f, _r04, vl); - vfloat16m1_t _tmp34b = vfmacc_vf_f16m1(vfmacc_vf_f16m1(vfmul_vf_f16m1(_r01, 0.5f, vl), -2.5f, _r03, vl), 2.f, _r05, vl); - - vfloat16m1_t _tmp3m = vfadd_vv_f16m1(_tmp34a, _tmp34b, vl); - vfloat16m1_t _tmp4m = vfsub_vv_f16m1(_tmp34a, _tmp34b, vl); - vse16_v_f16m1(tmp[3][m], _tmp3m, vl); - vse16_v_f16m1(tmp[4][m], _tmp4m, vl); - - vfloat16m1_t _tmp56a = vfmacc_vf_f16m1(_r06, 4.f, vfmacc_vf_f16m1(_r02, -1.25f, _r04, vl), vl); - vfloat16m1_t _tmp56b = vfmacc_vf_f16m1(vfmacc_vf_f16m1(vfmul_vf_f16m1(_r01, 2.f, vl), -2.5f, _r03, vl), 0.5f, _r05, vl); - - vfloat16m1_t _tmp5m = vfadd_vv_f16m1(_tmp56a, _tmp56b, vl); - vfloat16m1_t _tmp6m = vfsub_vv_f16m1(_tmp56a, _tmp56b, vl); - vse16_v_f16m1(tmp[5][m], _tmp5m, vl); - vse16_v_f16m1(tmp[6][m], _tmp6m, vl); - - r0 += w * packn; - } - - __fp16* r0_tm_0 = (__fp16*)img0_tm + (i * w_tm / 8 + j) * packn; - __fp16* r0_tm_1 = r0_tm_0 + tiles * packn; - __fp16* r0_tm_2 = r0_tm_0 + tiles * packn * 2; - __fp16* r0_tm_3 = r0_tm_0 + tiles * packn * 3; - __fp16* r0_tm_4 = r0_tm_0 + tiles * packn * 4; - __fp16* r0_tm_5 = r0_tm_0 + tiles * packn * 5; - __fp16* r0_tm_6 = r0_tm_0 + tiles * packn * 6; - __fp16* r0_tm_7 = r0_tm_0 + tiles * packn * 7; - - for (int m = 0; m < 8; m++) - { - vfloat16m1_t _tmp00 = vle16_v_f16m1(tmp[m][0], vl); - vfloat16m1_t _tmp01 = vle16_v_f16m1(tmp[m][1], vl); - vfloat16m1_t _tmp02 = vle16_v_f16m1(tmp[m][2], vl); - vfloat16m1_t _tmp03 = vle16_v_f16m1(tmp[m][3], vl); - vfloat16m1_t _tmp04 = vle16_v_f16m1(tmp[m][4], vl); - vfloat16m1_t _tmp05 = vle16_v_f16m1(tmp[m][5], vl); - vfloat16m1_t _tmp06 = vle16_v_f16m1(tmp[m][6], vl); - vfloat16m1_t _tmp07 = vle16_v_f16m1(tmp[m][7], vl); - - vfloat16m1_t _r0tm0 = vfmacc_vf_f16m1(vfsub_vv_f16m1(_tmp00, _tmp06, vl), 5.25f, vfsub_vv_f16m1(_tmp04, _tmp02, vl), vl); - vfloat16m1_t _r0tm7 = vfmacc_vf_f16m1(vfsub_vv_f16m1(_tmp07, _tmp01, vl), 5.25f, vfsub_vv_f16m1(_tmp03, _tmp05, vl), vl); - - vfloat16m1_t _tmp12a = vfmacc_vf_f16m1(vfadd_vv_f16m1(_tmp02, _tmp06, vl), -4.25f, _tmp04, vl); - vfloat16m1_t _tmp12b = vfmacc_vf_f16m1(vfadd_vv_f16m1(_tmp01, _tmp05, vl), -4.25f, _tmp03, vl); - - vfloat16m1_t _r0tm1 = vfadd_vv_f16m1(_tmp12a, _tmp12b, vl); - vfloat16m1_t _r0tm2 = vfsub_vv_f16m1(_tmp12a, _tmp12b, vl); - - vfloat16m1_t _tmp34a = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp06, 0.25f, _tmp02, vl), -1.25f, _tmp04, vl); - vfloat16m1_t _tmp34b = vfmacc_vf_f16m1(vfmacc_vf_f16m1(vfmul_vf_f16m1(_tmp01, 0.5f, vl), -2.5f, _tmp03, vl), 2.f, _tmp05, vl); - - vfloat16m1_t _r0tm3 = vfadd_vv_f16m1(_tmp34a, _tmp34b, vl); - vfloat16m1_t _r0tm4 = vfsub_vv_f16m1(_tmp34a, _tmp34b, vl); - - vfloat16m1_t _tmp56a = vfmacc_vf_f16m1(_tmp06, 4.f, vfmacc_vf_f16m1(_tmp02, -1.25f, _tmp04, vl), vl); - vfloat16m1_t _tmp56b = vfmacc_vf_f16m1(vfmacc_vf_f16m1(vfmul_vf_f16m1(_tmp01, 2.f, vl), -2.5f, _tmp03, vl), 0.5f, _tmp05, vl); - - vfloat16m1_t _r0tm5 = vfadd_vv_f16m1(_tmp56a, 
_tmp56b, vl); - vfloat16m1_t _r0tm6 = vfsub_vv_f16m1(_tmp56a, _tmp56b, vl); - - vse16_v_f16m1(r0_tm_0, _r0tm0, vl); - vse16_v_f16m1(r0_tm_1, _r0tm1, vl); - vse16_v_f16m1(r0_tm_2, _r0tm2, vl); - vse16_v_f16m1(r0_tm_3, _r0tm3, vl); - vse16_v_f16m1(r0_tm_4, _r0tm4, vl); - vse16_v_f16m1(r0_tm_5, _r0tm5, vl); - vse16_v_f16m1(r0_tm_6, _r0tm6, vl); - vse16_v_f16m1(r0_tm_7, _r0tm7, vl); - - r0_tm_0 += tiles * packn * 8; - r0_tm_1 += tiles * packn * 8; - r0_tm_2 += tiles * packn * 8; - r0_tm_3 += tiles * packn * 8; - r0_tm_4 += tiles * packn * 8; - r0_tm_5 += tiles * packn * 8; - r0_tm_6 += tiles * packn * 8; - r0_tm_7 += tiles * packn * 8; - } - } - } - } + bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); + conv3x3s1_winograd64_transform_input_packn_fp16sa_rvv(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -602,140 +448,7 @@ static void conv3x3s1_winograd64_packn_fp16sa_rvv(const Mat& bottom_blob, Mat& t top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator); } { - // const float otm[6][8] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} - // }; - - // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 - // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 - // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 - // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 - // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 - // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) - - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - const int tiles = w_tm / 8 * h_tm / 8; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - vfloat16m1_t _bias0 = bias ? 
vle16_v_f16m1((const __fp16*)bias + p * packn, vl) : vfmv_v_f_f16m1(0.f, vl); - - // NOTE c99 variable length array - __fp16 tmp[6][8][packn]; - - // tile - for (int i = 0; i < outh / 6; i++) - { - for (int j = 0; j < outw / 6; j++) - { - // top_blob_tm.create(tiles, 64, outch, elemsize, elempack); - - const __fp16* output0_tm_0 = (const __fp16*)out0_tm + (i * w_tm / 8 + j) * packn; - const __fp16* output0_tm_1 = output0_tm_0 + tiles * packn; - const __fp16* output0_tm_2 = output0_tm_0 + tiles * packn * 2; - const __fp16* output0_tm_3 = output0_tm_0 + tiles * packn * 3; - const __fp16* output0_tm_4 = output0_tm_0 + tiles * packn * 4; - const __fp16* output0_tm_5 = output0_tm_0 + tiles * packn * 5; - const __fp16* output0_tm_6 = output0_tm_0 + tiles * packn * 6; - const __fp16* output0_tm_7 = output0_tm_0 + tiles * packn * 7; - - __fp16* output0 = out0.row<__fp16>(i * 6) + (j * 6) * packn; - - // TODO rvv optimize - for (int m = 0; m < 8; m++) - { - vfloat16m1_t _out0tm0 = vle16_v_f16m1(output0_tm_0, vl); - vfloat16m1_t _out0tm1 = vle16_v_f16m1(output0_tm_1, vl); - vfloat16m1_t _out0tm2 = vle16_v_f16m1(output0_tm_2, vl); - vfloat16m1_t _out0tm3 = vle16_v_f16m1(output0_tm_3, vl); - vfloat16m1_t _out0tm4 = vle16_v_f16m1(output0_tm_4, vl); - vfloat16m1_t _out0tm5 = vle16_v_f16m1(output0_tm_5, vl); - vfloat16m1_t _out0tm6 = vle16_v_f16m1(output0_tm_6, vl); - vfloat16m1_t _out0tm7 = vle16_v_f16m1(output0_tm_7, vl); - - vfloat16m1_t _tmp024a = vfadd_vv_f16m1(_out0tm1, _out0tm2, vl); - vfloat16m1_t _tmp135a = vfsub_vv_f16m1(_out0tm1, _out0tm2, vl); - - vfloat16m1_t _tmp024b = vfadd_vv_f16m1(_out0tm3, _out0tm4, vl); - vfloat16m1_t _tmp135b = vfsub_vv_f16m1(_out0tm3, _out0tm4, vl); - - vfloat16m1_t _tmp024c = vfadd_vv_f16m1(_out0tm5, _out0tm6, vl); - vfloat16m1_t _tmp135c = vfsub_vv_f16m1(_out0tm5, _out0tm6, vl); - - vfloat16m1_t _tmp0m = vfadd_vv_f16m1(vfadd_vv_f16m1(_out0tm0, _tmp024a, vl), vfmacc_vf_f16m1(_tmp024b, 32.f, _tmp024c, vl), vl); - vfloat16m1_t _tmp2m = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp024a, 4.f, _tmp024b, vl), 8.f, _tmp024c, vl); - vfloat16m1_t _tmp4m = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp024a, 16.f, _tmp024b, vl), 2.f, _tmp024c, vl); - vse16_v_f16m1(tmp[0][m], _tmp0m, vl); - vse16_v_f16m1(tmp[2][m], _tmp2m, vl); - vse16_v_f16m1(tmp[4][m], _tmp4m, vl); - - vfloat16m1_t _tmp1m = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp135a, 2.f, _tmp135b, vl), 16.f, _tmp135c, vl); - vfloat16m1_t _tmp3m = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp135a, 8.f, _tmp135b, vl), 4.f, _tmp135c, vl); - vfloat16m1_t _tmp5m = vfadd_vv_f16m1(vfadd_vv_f16m1(_out0tm7, _tmp135a, vl), vfmacc_vf_f16m1(_tmp135c, 32.f, _tmp135b, vl), vl); - vse16_v_f16m1(tmp[1][m], _tmp1m, vl); - vse16_v_f16m1(tmp[3][m], _tmp3m, vl); - vse16_v_f16m1(tmp[5][m], _tmp5m, vl); - - output0_tm_0 += tiles * packn * 8; - output0_tm_1 += tiles * packn * 8; - output0_tm_2 += tiles * packn * 8; - output0_tm_3 += tiles * packn * 8; - output0_tm_4 += tiles * packn * 8; - output0_tm_5 += tiles * packn * 8; - output0_tm_6 += tiles * packn * 8; - output0_tm_7 += tiles * packn * 8; - } - - for (int m = 0; m < 6; m++) - { - vfloat16m1_t _tmp00 = vle16_v_f16m1(tmp[m][0], vl); - vfloat16m1_t _tmp01 = vle16_v_f16m1(tmp[m][1], vl); - vfloat16m1_t _tmp02 = vle16_v_f16m1(tmp[m][2], vl); - vfloat16m1_t _tmp03 = vle16_v_f16m1(tmp[m][3], vl); - vfloat16m1_t _tmp04 = vle16_v_f16m1(tmp[m][4], vl); - vfloat16m1_t _tmp05 = vle16_v_f16m1(tmp[m][5], vl); - vfloat16m1_t _tmp06 = vle16_v_f16m1(tmp[m][6], vl); - vfloat16m1_t _tmp07 = vle16_v_f16m1(tmp[m][7], vl); - - vfloat16m1_t 
_tmp024a = vfadd_vv_f16m1(_tmp01, _tmp02, vl); - vfloat16m1_t _tmp135a = vfsub_vv_f16m1(_tmp01, _tmp02, vl); - - vfloat16m1_t _tmp024b = vfadd_vv_f16m1(_tmp03, _tmp04, vl); - vfloat16m1_t _tmp135b = vfsub_vv_f16m1(_tmp03, _tmp04, vl); - - vfloat16m1_t _tmp024c = vfadd_vv_f16m1(_tmp05, _tmp06, vl); - vfloat16m1_t _tmp135c = vfsub_vv_f16m1(_tmp05, _tmp06, vl); - - vfloat16m1_t _out00 = vfadd_vv_f16m1(_bias0, vfadd_vv_f16m1(vfadd_vv_f16m1(_tmp00, _tmp024a, vl), vfmacc_vf_f16m1(_tmp024b, 32.f, _tmp024c, vl), vl), vl); - vfloat16m1_t _out02 = vfadd_vv_f16m1(_bias0, vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp024a, 4.f, _tmp024b, vl), 8.f, _tmp024c, vl), vl); - vfloat16m1_t _out04 = vfadd_vv_f16m1(_bias0, vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp024a, 16.f, _tmp024b, vl), 2.f, _tmp024c, vl), vl); - vse16_v_f16m1(output0, _out00, vl); - vse16_v_f16m1(output0 + packn * 2, _out02, vl); - vse16_v_f16m1(output0 + packn * 4, _out04, vl); - - vfloat16m1_t _out01 = vfadd_vv_f16m1(_bias0, vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp135a, 2.f, _tmp135b, vl), 16.f, _tmp135c, vl), vl); - vfloat16m1_t _out03 = vfadd_vv_f16m1(_bias0, vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp135a, 8.f, _tmp135b, vl), 4.f, _tmp135c, vl), vl); - vfloat16m1_t _out05 = vfadd_vv_f16m1(_bias0, vfadd_vv_f16m1(vfadd_vv_f16m1(_tmp07, _tmp135a, vl), vfmacc_vf_f16m1(_tmp135c, 32.f, _tmp135b, vl), vl), vl); - vse16_v_f16m1(output0 + packn, _out01, vl); - vse16_v_f16m1(output0 + packn * 3, _out03, vl); - vse16_v_f16m1(output0 + packn * 5, _out05, vl); - - output0 += outw * packn; - } - } - } - } + conv3x3s1_winograd64_transform_output_packn_fp16sa_rvv(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output @@ -823,7 +536,7 @@ static void conv3x3s1_winograd42_transform_kernel_packn_fp16sa_rvv(const Mat& ke } } -static void conv3x3s1_winograd42_packn_fp16sa_rvv(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd42_packn_fp16sa_rvv(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { const int packn = csrr_vlenb() / 2; const word_type vl = vsetvl_e16m1(packn); @@ -848,116 +561,15 @@ static void conv3x3s1_winograd42_packn_fp16sa_rvv(const Mat& bottom_blob, Mat& t h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const __fp16* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = w_tm / 6 * h_tm / 6; - - bottom_blob_tm.create(tiles, 36, inch, 2u * elempack, elempack, opt.workspace_allocator); - - // const float itm[4][4] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r04 + r03 - // 2 = 4 * (r01 - r02) + r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - - // NOTE c99 variable length array - __fp16 tmp[6][6][packn]; - - // tile - for (int i = 0; i < h_tm / 6; i++) - { - for (int j = 0; j < w_tm / 6; j++) - { - const __fp16* 
r0 = img0.row(i * 4) + (j * 4) * packn; - - for (int m = 0; m < 6; m++) - { - vfloat16m1_t _r00 = vle16_v_f16m1(r0, vl); - vfloat16m1_t _r01 = vle16_v_f16m1(r0 + packn, vl); - vfloat16m1_t _r02 = vle16_v_f16m1(r0 + packn * 2, vl); - vfloat16m1_t _r03 = vle16_v_f16m1(r0 + packn * 3, vl); - vfloat16m1_t _r04 = vle16_v_f16m1(r0 + packn * 4, vl); - vfloat16m1_t _r05 = vle16_v_f16m1(r0 + packn * 5, vl); - - vfloat16m1_t _tmp0m = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_r04, 4.f, _r00, vl), -5.f, _r02, vl); - vfloat16m1_t _tmp1m = vfmacc_vf_f16m1(vfadd_vv_f16m1(_r04, _r03, vl), -4.f, vfadd_vv_f16m1(_r01, _r02, vl), vl); - vfloat16m1_t _tmp2m = vfmacc_vf_f16m1(vfsub_vv_f16m1(_r04, _r03, vl), 4.f, vfsub_vv_f16m1(_r01, _r02, vl), vl); - vfloat16m1_t _tmp3m = vfmacc_vf_f16m1(vfsub_vv_f16m1(_r04, _r02, vl), -2.f, vfsub_vv_f16m1(_r01, _r03, vl), vl); - vfloat16m1_t _tmp4m = vfmacc_vf_f16m1(vfsub_vv_f16m1(_r04, _r02, vl), 2.f, vfsub_vv_f16m1(_r01, _r03, vl), vl); - vfloat16m1_t _tmp5m = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_r05, 4.f, _r01, vl), -5.f, _r03, vl); - - vse16_v_f16m1(tmp[0][m], _tmp0m, vl); - vse16_v_f16m1(tmp[1][m], _tmp1m, vl); - vse16_v_f16m1(tmp[2][m], _tmp2m, vl); - vse16_v_f16m1(tmp[3][m], _tmp3m, vl); - vse16_v_f16m1(tmp[4][m], _tmp4m, vl); - vse16_v_f16m1(tmp[5][m], _tmp5m, vl); - - r0 += w * packn; - } + int w_tiles = outw / 4; + int h_tiles = outh / 4; + const int tiles = w_tiles * h_tiles; - __fp16* r0_tm_0 = (__fp16*)img0_tm + (i * w_tm / 6 + j) * packn; - __fp16* r0_tm_1 = r0_tm_0 + tiles * packn; - __fp16* r0_tm_2 = r0_tm_0 + tiles * packn * 2; - __fp16* r0_tm_3 = r0_tm_0 + tiles * packn * 3; - __fp16* r0_tm_4 = r0_tm_0 + tiles * packn * 4; - __fp16* r0_tm_5 = r0_tm_0 + tiles * packn * 5; - - for (int m = 0; m < 6; m++) - { - vfloat16m1_t _tmp00 = vle16_v_f16m1(tmp[m][0], vl); - vfloat16m1_t _tmp01 = vle16_v_f16m1(tmp[m][1], vl); - vfloat16m1_t _tmp02 = vle16_v_f16m1(tmp[m][2], vl); - vfloat16m1_t _tmp03 = vle16_v_f16m1(tmp[m][3], vl); - vfloat16m1_t _tmp04 = vle16_v_f16m1(tmp[m][4], vl); - vfloat16m1_t _tmp05 = vle16_v_f16m1(tmp[m][5], vl); - - vfloat16m1_t _r0tm0 = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp04, 4.f, _tmp00, vl), -5.f, _tmp02, vl); - vfloat16m1_t _r0tm1 = vfmacc_vf_f16m1(vfadd_vv_f16m1(_tmp04, _tmp03, vl), -4.f, vfadd_vv_f16m1(_tmp01, _tmp02, vl), vl); - vfloat16m1_t _r0tm2 = vfmacc_vf_f16m1(vfsub_vv_f16m1(_tmp04, _tmp03, vl), 4.f, vfsub_vv_f16m1(_tmp01, _tmp02, vl), vl); - vfloat16m1_t _r0tm3 = vfmacc_vf_f16m1(vfsub_vv_f16m1(_tmp04, _tmp02, vl), -2.f, vfsub_vv_f16m1(_tmp01, _tmp03, vl), vl); - vfloat16m1_t _r0tm4 = vfmacc_vf_f16m1(vfsub_vv_f16m1(_tmp04, _tmp02, vl), 2.f, vfsub_vv_f16m1(_tmp01, _tmp03, vl), vl); - vfloat16m1_t _r0tm5 = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp05, 4.f, _tmp01, vl), -5.f, _tmp03, vl); - - vse16_v_f16m1(r0_tm_0, _r0tm0, vl); - vse16_v_f16m1(r0_tm_1, _r0tm1, vl); - vse16_v_f16m1(r0_tm_2, _r0tm2, vl); - vse16_v_f16m1(r0_tm_3, _r0tm3, vl); - vse16_v_f16m1(r0_tm_4, _r0tm4, vl); - vse16_v_f16m1(r0_tm_5, _r0tm5, vl); - - r0_tm_0 += tiles * packn * 6; - r0_tm_1 += tiles * packn * 6; - r0_tm_2 += tiles * packn * 6; - r0_tm_3 += tiles * packn * 6; - r0_tm_4 += tiles * packn * 6; - r0_tm_5 += tiles * packn * 6; - } - } - } - } + bottom_blob_tm.create(tiles, 36, inch, elemsize, elempack, opt.workspace_allocator); + conv3x3s1_winograd42_transform_input_packn_fp16sa_rvv(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -1277,114 +889,7 @@ static void conv3x3s1_winograd42_packn_fp16sa_rvv(const Mat& bottom_blob, 
Mat& t top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator); } { - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - const int tiles = w_tm / 6 * h_tm / 6; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - vfloat16m1_t _bias0 = bias ? vle16_v_f16m1((const __fp16*)bias + p * packn, vl) : vfmv_v_f_f16m1(0.f, vl); - - // NOTE variable length array - __fp16 tmp[4][6][packn]; - - // tile - for (int i = 0; i < outh / 4; i++) - { - for (int j = 0; j < outw / 4; j++) - { - // top_blob_tm.create(tiles, 36, outch, elemsize, elempack); - - const __fp16* output0_tm_0 = (const __fp16*)out0_tm + (i * w_tm / 6 + j) * packn; - const __fp16* output0_tm_1 = output0_tm_0 + tiles * packn; - const __fp16* output0_tm_2 = output0_tm_0 + tiles * packn * 2; - const __fp16* output0_tm_3 = output0_tm_0 + tiles * packn * 3; - const __fp16* output0_tm_4 = output0_tm_0 + tiles * packn * 4; - const __fp16* output0_tm_5 = output0_tm_0 + tiles * packn * 5; - - __fp16* output0 = out0.row<__fp16>(i * 4) + (j * 4) * packn; - - // TODO rvv optimize - for (int m = 0; m < 6; m++) - { - vfloat16m1_t _out0tm0 = vle16_v_f16m1(output0_tm_0, vl); - vfloat16m1_t _out0tm1 = vle16_v_f16m1(output0_tm_1, vl); - vfloat16m1_t _out0tm2 = vle16_v_f16m1(output0_tm_2, vl); - vfloat16m1_t _out0tm3 = vle16_v_f16m1(output0_tm_3, vl); - vfloat16m1_t _out0tm4 = vle16_v_f16m1(output0_tm_4, vl); - vfloat16m1_t _out0tm5 = vle16_v_f16m1(output0_tm_5, vl); - - vfloat16m1_t _tmp02a = vfadd_vv_f16m1(_out0tm1, _out0tm2, vl); - vfloat16m1_t _tmp13a = vfsub_vv_f16m1(_out0tm1, _out0tm2, vl); - - vfloat16m1_t _tmp02b = vfadd_vv_f16m1(_out0tm3, _out0tm4, vl); - vfloat16m1_t _tmp13b = vfsub_vv_f16m1(_out0tm3, _out0tm4, vl); - - vfloat16m1_t _tmp0m = vfadd_vv_f16m1(vfadd_vv_f16m1(_out0tm0, _tmp02a, vl), _tmp02b, vl); - vfloat16m1_t _tmp1m = vfmacc_vf_f16m1(_tmp13a, 2.f, _tmp13b, vl); - vfloat16m1_t _tmp2m = vfmacc_vf_f16m1(_tmp02a, 4.f, _tmp02b, vl); - vfloat16m1_t _tmp3m = vfmacc_vf_f16m1(vfadd_vv_f16m1(_out0tm5, _tmp13a, vl), 8.f, _tmp13b, vl); - - vse16_v_f16m1(tmp[0][m], _tmp0m, vl); - vse16_v_f16m1(tmp[1][m], _tmp1m, vl); - vse16_v_f16m1(tmp[2][m], _tmp2m, vl); - vse16_v_f16m1(tmp[3][m], _tmp3m, vl); - - output0_tm_0 += tiles * packn * 6; - output0_tm_1 += tiles * packn * 6; - output0_tm_2 += tiles * packn * 6; - output0_tm_3 += tiles * packn * 6; - output0_tm_4 += tiles * packn * 6; - output0_tm_5 += tiles * packn * 6; - } - - for (int m = 0; m < 4; m++) - { - vfloat16m1_t _tmp00 = vle16_v_f16m1(tmp[m][0], vl); - vfloat16m1_t _tmp01 = vle16_v_f16m1(tmp[m][1], vl); - vfloat16m1_t _tmp02 = vle16_v_f16m1(tmp[m][2], vl); - vfloat16m1_t _tmp03 = vle16_v_f16m1(tmp[m][3], vl); - vfloat16m1_t _tmp04 = vle16_v_f16m1(tmp[m][4], vl); - vfloat16m1_t _tmp05 = vle16_v_f16m1(tmp[m][5], vl); - - vfloat16m1_t _tmp02a = vfadd_vv_f16m1(_tmp01, _tmp02, vl); - vfloat16m1_t _tmp13a = vfsub_vv_f16m1(_tmp01, _tmp02, vl); - - vfloat16m1_t _tmp02b = vfadd_vv_f16m1(_tmp03, _tmp04, vl); - 
vfloat16m1_t _tmp13b = vfsub_vv_f16m1(_tmp03, _tmp04, vl); - - vfloat16m1_t _out00 = vfadd_vv_f16m1(_bias0, vfadd_vv_f16m1(vfadd_vv_f16m1(_tmp00, _tmp02a, vl), _tmp02b, vl), vl); - vfloat16m1_t _out01 = vfadd_vv_f16m1(_bias0, vfmacc_vf_f16m1(_tmp13a, 2.f, _tmp13b, vl), vl); - vfloat16m1_t _out02 = vfadd_vv_f16m1(_bias0, vfmacc_vf_f16m1(_tmp02a, 4.f, _tmp02b, vl), vl); - vfloat16m1_t _out03 = vfadd_vv_f16m1(_bias0, vfmacc_vf_f16m1(vfadd_vv_f16m1(_tmp05, _tmp13a, vl), 8.f, _tmp13b, vl), vl); - - vse16_v_f16m1(output0, _out00, vl); - vse16_v_f16m1(output0 + packn, _out01, vl); - vse16_v_f16m1(output0 + packn * 2, _out02, vl); - vse16_v_f16m1(output0 + packn * 3, _out03, vl); - - output0 += outw * packn; - } - } - } - } + conv3x3s1_winograd42_transform_output_packn_fp16sa_rvv(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output diff --git a/src/layer/riscv/convolution_riscv.cpp b/src/layer/riscv/convolution_riscv.cpp index 38d73e07fea..ddddaff2375 100644 --- a/src/layer/riscv/convolution_riscv.cpp +++ b/src/layer/riscv/convolution_riscv.cpp @@ -44,6 +44,7 @@ namespace ncnn { #include "convolution_sgemm_packn.h" #include "convolution_sgemm_pack1ton.h" #include "convolution_sgemm_packnto1.h" +#include "convolution_winograd_transform_packn.h" #include "convolution_1x1_packn.h" #include "convolution_1x1_pack1ton.h" #include "convolution_1x1_packnto1.h" @@ -61,6 +62,7 @@ namespace ncnn { #include "convolution_sgemm_packn_fp16s.h" #include "convolution_sgemm_pack1ton_fp16s.h" #include "convolution_sgemm_packnto1_fp16s.h" +#include "convolution_winograd_transform_packn_fp16s.h" #include "convolution_1x1_fp16s.h" #include "convolution_1x1_packn_fp16s.h" #include "convolution_1x1_pack1ton_fp16s.h" diff --git a/src/layer/riscv/convolution_winograd_transform_packn.h b/src/layer/riscv/convolution_winograd_transform_packn.h new file mode 100644 index 00000000000..e0a947d9df9 --- /dev/null +++ b/src/layer/riscv/convolution_winograd_transform_packn.h @@ -0,0 +1,551 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
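+
+// Editor's note: a minimal scalar sketch of the 1-D F(6x6, 3x3) input
+// transform (the itm / B^T rows documented below), which the vector code
+// applies twice per tile -- once across rows into tmp[][][], once across
+// columns into the transformed blob. Illustration only; the in[8]/out[8]
+// names are hypothetical and this helper is not part of the patch:
+//
+//   static void winograd64_input_transform_1d(const float in[8], float out[8])
+//   {
+//       out[0] = in[0] - in[6] + (in[4] - in[2]) * 5.25f;
+//       out[7] = in[7] - in[1] + (in[3] - in[5]) * 5.25f;
+//       const float t12a = in[2] + in[6] - in[4] * 4.25f;
+//       const float t12b = in[1] + in[5] - in[3] * 4.25f;
+//       out[1] = t12a + t12b;
+//       out[2] = t12a - t12b;
+//       const float t34a = in[6] + in[2] * 0.25f - in[4] * 1.25f;
+//       const float t34b = in[1] * 0.5f - in[3] * 2.5f + in[5] * 2.f;
+//       out[3] = t34a + t34b;
+//       out[4] = t34a - t34b;
+//       const float t56a = in[6] + (in[2] - in[4] * 1.25f) * 4.f;
+//       const float t56b = in[1] * 2.f - in[3] * 2.5f + in[5] * 0.5f;
+//       out[5] = t56a + t56b;
+//       out[6] = t56a - t56b;
+//   }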
+ +static void conv3x3s1_winograd64_transform_input_packn_rvv(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) +{ + const int packn = csrr_vlenb() / 4; + const word_type vl = vsetvl_e32m1(packn); + + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + + const int w_tiles = (w - 2) / 6; + const int h_tiles = (h - 2) / 6; + const int tiles = w_tiles * h_tiles; + + // const float itm[8][8] = { + // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, + // + // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, + // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, + // + // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, + // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, + // + // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, + // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, + // + // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} + // }; + + // 0 = r00 - r06 + (r04 - r02) * 5.25 + // 7 = r07 - r01 + (r03 - r05) * 5.25 + + // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) + // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) + + // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) + // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) + + // reuse r04 * 1.25 + // reuse r03 * 2.5 + // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) + // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < inch; q++) + { + const Mat img0 = bottom_blob.channel(q); + Mat img0_tm = bottom_blob_tm.channel(q); + + // NOTE c99 variable length array + float tmp[8][8][packn]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* r0 = img0.row(i * 6) + (j * 6) * packn; + + for (int m = 0; m < 8; m++) + { + vfloat32m1_t _r00 = vle32_v_f32m1(r0, vl); + vfloat32m1_t _r01 = vle32_v_f32m1(r0 + packn, vl); + vfloat32m1_t _r02 = vle32_v_f32m1(r0 + packn * 2, vl); + vfloat32m1_t _r03 = vle32_v_f32m1(r0 + packn * 3, vl); + vfloat32m1_t _r04 = vle32_v_f32m1(r0 + packn * 4, vl); + vfloat32m1_t _r05 = vle32_v_f32m1(r0 + packn * 5, vl); + vfloat32m1_t _r06 = vle32_v_f32m1(r0 + packn * 6, vl); + vfloat32m1_t _r07 = vle32_v_f32m1(r0 + packn * 7, vl); + + vfloat32m1_t _tmp0m = vfmacc_vf_f32m1(vfsub_vv_f32m1(_r00, _r06, vl), 5.25f, vfsub_vv_f32m1(_r04, _r02, vl), vl); + vfloat32m1_t _tmp7m = vfmacc_vf_f32m1(vfsub_vv_f32m1(_r07, _r01, vl), 5.25f, vfsub_vv_f32m1(_r03, _r05, vl), vl); + vse32_v_f32m1(tmp[0][m], _tmp0m, vl); + vse32_v_f32m1(tmp[7][m], _tmp7m, vl); + + vfloat32m1_t _tmp12a = vfmacc_vf_f32m1(vfadd_vv_f32m1(_r02, _r06, vl), -4.25f, _r04, vl); + vfloat32m1_t _tmp12b = vfmacc_vf_f32m1(vfadd_vv_f32m1(_r01, _r05, vl), -4.25f, _r03, vl); + + vfloat32m1_t _tmp1m = vfadd_vv_f32m1(_tmp12a, _tmp12b, vl); + vfloat32m1_t _tmp2m = vfsub_vv_f32m1(_tmp12a, _tmp12b, vl); + vse32_v_f32m1(tmp[1][m], _tmp1m, vl); + vse32_v_f32m1(tmp[2][m], _tmp2m, vl); + + vfloat32m1_t _tmp34a = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_r06, 0.25f, _r02, vl), -1.25f, _r04, vl); + vfloat32m1_t _tmp34b = vfmacc_vf_f32m1(vfmacc_vf_f32m1(vfmul_vf_f32m1(_r01, 0.5f, vl), -2.5f, _r03, vl), 2.f, _r05, vl); + + vfloat32m1_t _tmp3m = vfadd_vv_f32m1(_tmp34a, _tmp34b, vl); + vfloat32m1_t _tmp4m = vfsub_vv_f32m1(_tmp34a, _tmp34b, vl); + vse32_v_f32m1(tmp[3][m], _tmp3m, vl); + vse32_v_f32m1(tmp[4][m], _tmp4m, vl); + + vfloat32m1_t 
_tmp56a = vfmacc_vf_f32m1(_r06, 4.f, vfmacc_vf_f32m1(_r02, -1.25f, _r04, vl), vl); + vfloat32m1_t _tmp56b = vfmacc_vf_f32m1(vfmacc_vf_f32m1(vfmul_vf_f32m1(_r01, 2.f, vl), -2.5f, _r03, vl), 0.5f, _r05, vl); + + vfloat32m1_t _tmp5m = vfadd_vv_f32m1(_tmp56a, _tmp56b, vl); + vfloat32m1_t _tmp6m = vfsub_vv_f32m1(_tmp56a, _tmp56b, vl); + vse32_v_f32m1(tmp[5][m], _tmp5m, vl); + vse32_v_f32m1(tmp[6][m], _tmp6m, vl); + + r0 += w * packn; + } + + float* r0_tm_0 = (float*)img0_tm + (i * w_tiles + j) * packn; + float* r0_tm_1 = r0_tm_0 + tiles * packn; + float* r0_tm_2 = r0_tm_0 + tiles * packn * 2; + float* r0_tm_3 = r0_tm_0 + tiles * packn * 3; + float* r0_tm_4 = r0_tm_0 + tiles * packn * 4; + float* r0_tm_5 = r0_tm_0 + tiles * packn * 5; + float* r0_tm_6 = r0_tm_0 + tiles * packn * 6; + float* r0_tm_7 = r0_tm_0 + tiles * packn * 7; + + for (int m = 0; m < 8; m++) + { + vfloat32m1_t _tmp00 = vle32_v_f32m1(tmp[m][0], vl); + vfloat32m1_t _tmp01 = vle32_v_f32m1(tmp[m][1], vl); + vfloat32m1_t _tmp02 = vle32_v_f32m1(tmp[m][2], vl); + vfloat32m1_t _tmp03 = vle32_v_f32m1(tmp[m][3], vl); + vfloat32m1_t _tmp04 = vle32_v_f32m1(tmp[m][4], vl); + vfloat32m1_t _tmp05 = vle32_v_f32m1(tmp[m][5], vl); + vfloat32m1_t _tmp06 = vle32_v_f32m1(tmp[m][6], vl); + vfloat32m1_t _tmp07 = vle32_v_f32m1(tmp[m][7], vl); + + vfloat32m1_t _r0tm0 = vfmacc_vf_f32m1(vfsub_vv_f32m1(_tmp00, _tmp06, vl), 5.25f, vfsub_vv_f32m1(_tmp04, _tmp02, vl), vl); + vfloat32m1_t _r0tm7 = vfmacc_vf_f32m1(vfsub_vv_f32m1(_tmp07, _tmp01, vl), 5.25f, vfsub_vv_f32m1(_tmp03, _tmp05, vl), vl); + + vfloat32m1_t _tmp12a = vfmacc_vf_f32m1(vfadd_vv_f32m1(_tmp02, _tmp06, vl), -4.25f, _tmp04, vl); + vfloat32m1_t _tmp12b = vfmacc_vf_f32m1(vfadd_vv_f32m1(_tmp01, _tmp05, vl), -4.25f, _tmp03, vl); + + vfloat32m1_t _r0tm1 = vfadd_vv_f32m1(_tmp12a, _tmp12b, vl); + vfloat32m1_t _r0tm2 = vfsub_vv_f32m1(_tmp12a, _tmp12b, vl); + + vfloat32m1_t _tmp34a = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp06, 0.25f, _tmp02, vl), -1.25f, _tmp04, vl); + vfloat32m1_t _tmp34b = vfmacc_vf_f32m1(vfmacc_vf_f32m1(vfmul_vf_f32m1(_tmp01, 0.5f, vl), -2.5f, _tmp03, vl), 2.f, _tmp05, vl); + + vfloat32m1_t _r0tm3 = vfadd_vv_f32m1(_tmp34a, _tmp34b, vl); + vfloat32m1_t _r0tm4 = vfsub_vv_f32m1(_tmp34a, _tmp34b, vl); + + vfloat32m1_t _tmp56a = vfmacc_vf_f32m1(_tmp06, 4.f, vfmacc_vf_f32m1(_tmp02, -1.25f, _tmp04, vl), vl); + vfloat32m1_t _tmp56b = vfmacc_vf_f32m1(vfmacc_vf_f32m1(vfmul_vf_f32m1(_tmp01, 2.f, vl), -2.5f, _tmp03, vl), 0.5f, _tmp05, vl); + + vfloat32m1_t _r0tm5 = vfadd_vv_f32m1(_tmp56a, _tmp56b, vl); + vfloat32m1_t _r0tm6 = vfsub_vv_f32m1(_tmp56a, _tmp56b, vl); + + vse32_v_f32m1(r0_tm_0, _r0tm0, vl); + vse32_v_f32m1(r0_tm_1, _r0tm1, vl); + vse32_v_f32m1(r0_tm_2, _r0tm2, vl); + vse32_v_f32m1(r0_tm_3, _r0tm3, vl); + vse32_v_f32m1(r0_tm_4, _r0tm4, vl); + vse32_v_f32m1(r0_tm_5, _r0tm5, vl); + vse32_v_f32m1(r0_tm_6, _r0tm6, vl); + vse32_v_f32m1(r0_tm_7, _r0tm7, vl); + + r0_tm_0 += tiles * packn * 8; + r0_tm_1 += tiles * packn * 8; + r0_tm_2 += tiles * packn * 8; + r0_tm_3 += tiles * packn * 8; + r0_tm_4 += tiles * packn * 8; + r0_tm_5 += tiles * packn * 8; + r0_tm_6 += tiles * packn * 8; + r0_tm_7 += tiles * packn * 8; + } + } + } + } +} + +static void conv3x3s1_winograd64_transform_output_packn_rvv(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int packn = csrr_vlenb() / 4; + const word_type vl = vsetvl_e32m1(packn); + + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 6; + const int 
h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; + + const float* biasptr = bias; + + // const float otm[6][8] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} + // }; + + // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 + // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 + // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 + // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 + // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 + // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + vfloat32m1_t _bias0 = biasptr ? vle32_v_f32m1(biasptr + p * packn, vl) : vfmv_v_f_f32m1(0.f, vl); + + // NOTE c99 variable length array + float tmp[6][8][packn]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* output0_tm_0 = (const float*)out0_tm + (i * w_tiles + j) * packn; + const float* output0_tm_1 = output0_tm_0 + tiles * packn; + const float* output0_tm_2 = output0_tm_0 + tiles * packn * 2; + const float* output0_tm_3 = output0_tm_0 + tiles * packn * 3; + const float* output0_tm_4 = output0_tm_0 + tiles * packn * 4; + const float* output0_tm_5 = output0_tm_0 + tiles * packn * 5; + const float* output0_tm_6 = output0_tm_0 + tiles * packn * 6; + const float* output0_tm_7 = output0_tm_0 + tiles * packn * 7; + + float* output0 = out0.row(i * 6) + (j * 6) * packn; + + for (int m = 0; m < 8; m++) + { + vfloat32m1_t _out0tm0 = vle32_v_f32m1(output0_tm_0, vl); + vfloat32m1_t _out0tm1 = vle32_v_f32m1(output0_tm_1, vl); + vfloat32m1_t _out0tm2 = vle32_v_f32m1(output0_tm_2, vl); + vfloat32m1_t _out0tm3 = vle32_v_f32m1(output0_tm_3, vl); + vfloat32m1_t _out0tm4 = vle32_v_f32m1(output0_tm_4, vl); + vfloat32m1_t _out0tm5 = vle32_v_f32m1(output0_tm_5, vl); + vfloat32m1_t _out0tm6 = vle32_v_f32m1(output0_tm_6, vl); + vfloat32m1_t _out0tm7 = vle32_v_f32m1(output0_tm_7, vl); + + vfloat32m1_t _tmp024a = vfadd_vv_f32m1(_out0tm1, _out0tm2, vl); + vfloat32m1_t _tmp135a = vfsub_vv_f32m1(_out0tm1, _out0tm2, vl); + + vfloat32m1_t _tmp024b = vfadd_vv_f32m1(_out0tm3, _out0tm4, vl); + vfloat32m1_t _tmp135b = vfsub_vv_f32m1(_out0tm3, _out0tm4, vl); + + vfloat32m1_t _tmp024c = vfadd_vv_f32m1(_out0tm5, _out0tm6, vl); + vfloat32m1_t _tmp135c = vfsub_vv_f32m1(_out0tm5, _out0tm6, vl); + + vfloat32m1_t _tmp0m = vfadd_vv_f32m1(vfadd_vv_f32m1(_out0tm0, _tmp024a, vl), vfmacc_vf_f32m1(_tmp024b, 32.f, _tmp024c, vl), vl); + vfloat32m1_t _tmp2m = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp024a, 4.f, _tmp024b, vl), 8.f, _tmp024c, vl); + vfloat32m1_t _tmp4m = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp024a, 16.f, _tmp024b, vl), 2.f, _tmp024c, vl); + vse32_v_f32m1(tmp[0][m], _tmp0m, vl); + vse32_v_f32m1(tmp[2][m], _tmp2m, vl); + vse32_v_f32m1(tmp[4][m], _tmp4m, vl); + + vfloat32m1_t _tmp1m = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp135a, 2.f, _tmp135b, vl), 16.f, _tmp135c, vl); + vfloat32m1_t _tmp3m = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp135a, 8.f, _tmp135b, vl), 4.f, _tmp135c, vl); + vfloat32m1_t _tmp5m = vfadd_vv_f32m1(vfadd_vv_f32m1(_out0tm7, _tmp135a, vl), vfmacc_vf_f32m1(_tmp135c, 32.f, _tmp135b, vl), vl); + vse32_v_f32m1(tmp[1][m], 
_tmp1m, vl); + vse32_v_f32m1(tmp[3][m], _tmp3m, vl); + vse32_v_f32m1(tmp[5][m], _tmp5m, vl); + + output0_tm_0 += tiles * packn * 8; + output0_tm_1 += tiles * packn * 8; + output0_tm_2 += tiles * packn * 8; + output0_tm_3 += tiles * packn * 8; + output0_tm_4 += tiles * packn * 8; + output0_tm_5 += tiles * packn * 8; + output0_tm_6 += tiles * packn * 8; + output0_tm_7 += tiles * packn * 8; + } + + for (int m = 0; m < 6; m++) + { + vfloat32m1_t _tmp00 = vle32_v_f32m1(tmp[m][0], vl); + vfloat32m1_t _tmp01 = vle32_v_f32m1(tmp[m][1], vl); + vfloat32m1_t _tmp02 = vle32_v_f32m1(tmp[m][2], vl); + vfloat32m1_t _tmp03 = vle32_v_f32m1(tmp[m][3], vl); + vfloat32m1_t _tmp04 = vle32_v_f32m1(tmp[m][4], vl); + vfloat32m1_t _tmp05 = vle32_v_f32m1(tmp[m][5], vl); + vfloat32m1_t _tmp06 = vle32_v_f32m1(tmp[m][6], vl); + vfloat32m1_t _tmp07 = vle32_v_f32m1(tmp[m][7], vl); + + vfloat32m1_t _tmp024a = vfadd_vv_f32m1(_tmp01, _tmp02, vl); + vfloat32m1_t _tmp135a = vfsub_vv_f32m1(_tmp01, _tmp02, vl); + + vfloat32m1_t _tmp024b = vfadd_vv_f32m1(_tmp03, _tmp04, vl); + vfloat32m1_t _tmp135b = vfsub_vv_f32m1(_tmp03, _tmp04, vl); + + vfloat32m1_t _tmp024c = vfadd_vv_f32m1(_tmp05, _tmp06, vl); + vfloat32m1_t _tmp135c = vfsub_vv_f32m1(_tmp05, _tmp06, vl); + + vfloat32m1_t _out00 = vfadd_vv_f32m1(_bias0, vfadd_vv_f32m1(vfadd_vv_f32m1(_tmp00, _tmp024a, vl), vfmacc_vf_f32m1(_tmp024b, 32.f, _tmp024c, vl), vl), vl); + vfloat32m1_t _out02 = vfadd_vv_f32m1(_bias0, vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp024a, 4.f, _tmp024b, vl), 8.f, _tmp024c, vl), vl); + vfloat32m1_t _out04 = vfadd_vv_f32m1(_bias0, vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp024a, 16.f, _tmp024b, vl), 2.f, _tmp024c, vl), vl); + vse32_v_f32m1(output0, _out00, vl); + vse32_v_f32m1(output0 + packn * 2, _out02, vl); + vse32_v_f32m1(output0 + packn * 4, _out04, vl); + + vfloat32m1_t _out01 = vfadd_vv_f32m1(_bias0, vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp135a, 2.f, _tmp135b, vl), 16.f, _tmp135c, vl), vl); + vfloat32m1_t _out03 = vfadd_vv_f32m1(_bias0, vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp135a, 8.f, _tmp135b, vl), 4.f, _tmp135c, vl), vl); + vfloat32m1_t _out05 = vfadd_vv_f32m1(_bias0, vfadd_vv_f32m1(vfadd_vv_f32m1(_tmp07, _tmp135a, vl), vfmacc_vf_f32m1(_tmp135c, 32.f, _tmp135b, vl), vl), vl); + vse32_v_f32m1(output0 + packn, _out01, vl); + vse32_v_f32m1(output0 + packn * 3, _out03, vl); + vse32_v_f32m1(output0 + packn * 5, _out05, vl); + + output0 += outw * packn; + } + } + } + } +} + +static void conv3x3s1_winograd42_transform_input_packn_rvv(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) +{ + const int packn = csrr_vlenb() / 4; + const word_type vl = vsetvl_e32m1(packn); + + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + + const int w_tiles = (w - 2) / 4; + const int h_tiles = (h - 2) / 4; + const int tiles = w_tiles * h_tiles; + + // const float itm[6][6] = { + // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, + // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, + // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, + // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, + // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, + // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} + // }; + + // 0 = 4 * r00 - 5 * r02 + r04 + // 1 = -4 * (r01 + r02) + r04 + r03 + // 2 = 4 * (r01 - r02) + r04 - r03 + // 3 = -2 * (r01 - r03) + r04 - r02 + // 4 = 2 * (r01 - r03) + r04 - r02 + // 5 = 4 * r01 - 5 * r03 + r05 + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < inch; q++) + { + const Mat img0 = bottom_blob.channel(q); + Mat img0_tm = bottom_blob_tm.channel(q); + + 
// NOTE c99 variable length array + float tmp[6][6][packn]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* r0 = img0.row(i * 4) + (j * 4) * packn; + + for (int m = 0; m < 6; m++) + { + vfloat32m1_t _r00 = vle32_v_f32m1(r0, vl); + vfloat32m1_t _r01 = vle32_v_f32m1(r0 + packn, vl); + vfloat32m1_t _r02 = vle32_v_f32m1(r0 + packn * 2, vl); + vfloat32m1_t _r03 = vle32_v_f32m1(r0 + packn * 3, vl); + vfloat32m1_t _r04 = vle32_v_f32m1(r0 + packn * 4, vl); + vfloat32m1_t _r05 = vle32_v_f32m1(r0 + packn * 5, vl); + + vfloat32m1_t _tmp0m = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_r04, 4.f, _r00, vl), -5.f, _r02, vl); + vfloat32m1_t _tmp1m = vfmacc_vf_f32m1(vfadd_vv_f32m1(_r04, _r03, vl), -4.f, vfadd_vv_f32m1(_r01, _r02, vl), vl); + vfloat32m1_t _tmp2m = vfmacc_vf_f32m1(vfsub_vv_f32m1(_r04, _r03, vl), 4.f, vfsub_vv_f32m1(_r01, _r02, vl), vl); + vfloat32m1_t _tmp3m = vfmacc_vf_f32m1(vfsub_vv_f32m1(_r04, _r02, vl), -2.f, vfsub_vv_f32m1(_r01, _r03, vl), vl); + vfloat32m1_t _tmp4m = vfmacc_vf_f32m1(vfsub_vv_f32m1(_r04, _r02, vl), 2.f, vfsub_vv_f32m1(_r01, _r03, vl), vl); + vfloat32m1_t _tmp5m = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_r05, 4.f, _r01, vl), -5.f, _r03, vl); + + vse32_v_f32m1(tmp[0][m], _tmp0m, vl); + vse32_v_f32m1(tmp[1][m], _tmp1m, vl); + vse32_v_f32m1(tmp[2][m], _tmp2m, vl); + vse32_v_f32m1(tmp[3][m], _tmp3m, vl); + vse32_v_f32m1(tmp[4][m], _tmp4m, vl); + vse32_v_f32m1(tmp[5][m], _tmp5m, vl); + + r0 += w * packn; + } + + float* r0_tm_0 = (float*)img0_tm + (i * w_tiles + j) * packn; + float* r0_tm_1 = r0_tm_0 + tiles * packn; + float* r0_tm_2 = r0_tm_0 + tiles * packn * 2; + float* r0_tm_3 = r0_tm_0 + tiles * packn * 3; + float* r0_tm_4 = r0_tm_0 + tiles * packn * 4; + float* r0_tm_5 = r0_tm_0 + tiles * packn * 5; + + for (int m = 0; m < 6; m++) + { + vfloat32m1_t _tmp00 = vle32_v_f32m1(tmp[m][0], vl); + vfloat32m1_t _tmp01 = vle32_v_f32m1(tmp[m][1], vl); + vfloat32m1_t _tmp02 = vle32_v_f32m1(tmp[m][2], vl); + vfloat32m1_t _tmp03 = vle32_v_f32m1(tmp[m][3], vl); + vfloat32m1_t _tmp04 = vle32_v_f32m1(tmp[m][4], vl); + vfloat32m1_t _tmp05 = vle32_v_f32m1(tmp[m][5], vl); + + vfloat32m1_t _r0tm0 = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp04, 4.f, _tmp00, vl), -5.f, _tmp02, vl); + vfloat32m1_t _r0tm1 = vfmacc_vf_f32m1(vfadd_vv_f32m1(_tmp04, _tmp03, vl), -4.f, vfadd_vv_f32m1(_tmp01, _tmp02, vl), vl); + vfloat32m1_t _r0tm2 = vfmacc_vf_f32m1(vfsub_vv_f32m1(_tmp04, _tmp03, vl), 4.f, vfsub_vv_f32m1(_tmp01, _tmp02, vl), vl); + vfloat32m1_t _r0tm3 = vfmacc_vf_f32m1(vfsub_vv_f32m1(_tmp04, _tmp02, vl), -2.f, vfsub_vv_f32m1(_tmp01, _tmp03, vl), vl); + vfloat32m1_t _r0tm4 = vfmacc_vf_f32m1(vfsub_vv_f32m1(_tmp04, _tmp02, vl), 2.f, vfsub_vv_f32m1(_tmp01, _tmp03, vl), vl); + vfloat32m1_t _r0tm5 = vfmacc_vf_f32m1(vfmacc_vf_f32m1(_tmp05, 4.f, _tmp01, vl), -5.f, _tmp03, vl); + + vse32_v_f32m1(r0_tm_0, _r0tm0, vl); + vse32_v_f32m1(r0_tm_1, _r0tm1, vl); + vse32_v_f32m1(r0_tm_2, _r0tm2, vl); + vse32_v_f32m1(r0_tm_3, _r0tm3, vl); + vse32_v_f32m1(r0_tm_4, _r0tm4, vl); + vse32_v_f32m1(r0_tm_5, _r0tm5, vl); + + r0_tm_0 += tiles * packn * 6; + r0_tm_1 += tiles * packn * 6; + r0_tm_2 += tiles * packn * 6; + r0_tm_3 += tiles * packn * 6; + r0_tm_4 += tiles * packn * 6; + r0_tm_5 += tiles * packn * 6; + } + } + } + } +} + +static void conv3x3s1_winograd42_transform_output_packn_rvv(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int packn = csrr_vlenb() / 4; + const word_type vl = vsetvl_e32m1(packn); + + const int outw = top_blob.w; + const int 
outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 4; + const int h_tiles = outh / 4; + const int tiles = w_tiles * h_tiles; + + const float* biasptr = bias; + + // const float otm[4][6] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} + // }; + + // 0 = r00 + (r01 + r02) + (r03 + r04) + // 1 = (r01 - r02) + (r03 - r04) * 2 + // 2 = (r01 + r02) + (r03 + r04) * 4 + // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + vfloat32m1_t _bias0 = biasptr ? vle32_v_f32m1(biasptr + p * packn, vl) : vfmv_v_f_f32m1(0.f, vl); + + // NOTE variable length array + float tmp[4][6][packn]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* output0_tm_0 = (const float*)out0_tm + (i * w_tiles + j) * packn; + const float* output0_tm_1 = output0_tm_0 + tiles * packn; + const float* output0_tm_2 = output0_tm_0 + tiles * packn * 2; + const float* output0_tm_3 = output0_tm_0 + tiles * packn * 3; + const float* output0_tm_4 = output0_tm_0 + tiles * packn * 4; + const float* output0_tm_5 = output0_tm_0 + tiles * packn * 5; + + float* output0 = out0.row(i * 4) + (j * 4) * packn; + + for (int m = 0; m < 6; m++) + { + vfloat32m1_t _out0tm0 = vle32_v_f32m1(output0_tm_0, vl); + vfloat32m1_t _out0tm1 = vle32_v_f32m1(output0_tm_1, vl); + vfloat32m1_t _out0tm2 = vle32_v_f32m1(output0_tm_2, vl); + vfloat32m1_t _out0tm3 = vle32_v_f32m1(output0_tm_3, vl); + vfloat32m1_t _out0tm4 = vle32_v_f32m1(output0_tm_4, vl); + vfloat32m1_t _out0tm5 = vle32_v_f32m1(output0_tm_5, vl); + + vfloat32m1_t _tmp02a = vfadd_vv_f32m1(_out0tm1, _out0tm2, vl); + vfloat32m1_t _tmp13a = vfsub_vv_f32m1(_out0tm1, _out0tm2, vl); + + vfloat32m1_t _tmp02b = vfadd_vv_f32m1(_out0tm3, _out0tm4, vl); + vfloat32m1_t _tmp13b = vfsub_vv_f32m1(_out0tm3, _out0tm4, vl); + + vfloat32m1_t _tmp0m = vfadd_vv_f32m1(vfadd_vv_f32m1(_out0tm0, _tmp02a, vl), _tmp02b, vl); + vfloat32m1_t _tmp1m = vfmacc_vf_f32m1(_tmp13a, 2.f, _tmp13b, vl); + vfloat32m1_t _tmp2m = vfmacc_vf_f32m1(_tmp02a, 4.f, _tmp02b, vl); + vfloat32m1_t _tmp3m = vfmacc_vf_f32m1(vfadd_vv_f32m1(_out0tm5, _tmp13a, vl), 8.f, _tmp13b, vl); + + vse32_v_f32m1(tmp[0][m], _tmp0m, vl); + vse32_v_f32m1(tmp[1][m], _tmp1m, vl); + vse32_v_f32m1(tmp[2][m], _tmp2m, vl); + vse32_v_f32m1(tmp[3][m], _tmp3m, vl); + + output0_tm_0 += tiles * packn * 6; + output0_tm_1 += tiles * packn * 6; + output0_tm_2 += tiles * packn * 6; + output0_tm_3 += tiles * packn * 6; + output0_tm_4 += tiles * packn * 6; + output0_tm_5 += tiles * packn * 6; + } + + for (int m = 0; m < 4; m++) + { + vfloat32m1_t _tmp00 = vle32_v_f32m1(tmp[m][0], vl); + vfloat32m1_t _tmp01 = vle32_v_f32m1(tmp[m][1], vl); + vfloat32m1_t _tmp02 = vle32_v_f32m1(tmp[m][2], vl); + vfloat32m1_t _tmp03 = vle32_v_f32m1(tmp[m][3], vl); + vfloat32m1_t _tmp04 = vle32_v_f32m1(tmp[m][4], vl); + vfloat32m1_t _tmp05 = vle32_v_f32m1(tmp[m][5], vl); + + vfloat32m1_t _tmp02a = vfadd_vv_f32m1(_tmp01, _tmp02, vl); + vfloat32m1_t _tmp13a = vfsub_vv_f32m1(_tmp01, _tmp02, vl); + + vfloat32m1_t _tmp02b = vfadd_vv_f32m1(_tmp03, _tmp04, vl); + vfloat32m1_t _tmp13b = vfsub_vv_f32m1(_tmp03, _tmp04, vl); + + vfloat32m1_t _out00 = vfadd_vv_f32m1(_bias0, vfadd_vv_f32m1(vfadd_vv_f32m1(_tmp00, _tmp02a, vl), _tmp02b, vl), vl); 
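+ // the remaining otm rows reuse the same sums and differences:
+ // 1 = (r01 - r02) + (r03 - r04) * 2, 2 = (r01 + r02) + (r03 + r04) * 4,
+ // 3 = r05 + (r01 - r02) + (r03 - r04) * 8, each offset by the packn-wide bias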
+ vfloat32m1_t _out01 = vfadd_vv_f32m1(_bias0, vfmacc_vf_f32m1(_tmp13a, 2.f, _tmp13b, vl), vl); + vfloat32m1_t _out02 = vfadd_vv_f32m1(_bias0, vfmacc_vf_f32m1(_tmp02a, 4.f, _tmp02b, vl), vl); + vfloat32m1_t _out03 = vfadd_vv_f32m1(_bias0, vfmacc_vf_f32m1(vfadd_vv_f32m1(_tmp05, _tmp13a, vl), 8.f, _tmp13b, vl), vl); + + vse32_v_f32m1(output0, _out00, vl); + vse32_v_f32m1(output0 + packn, _out01, vl); + vse32_v_f32m1(output0 + packn * 2, _out02, vl); + vse32_v_f32m1(output0 + packn * 3, _out03, vl); + + output0 += outw * packn; + } + } + } + } +} diff --git a/src/layer/riscv/convolution_winograd_transform_packn_fp16s.h b/src/layer/riscv/convolution_winograd_transform_packn_fp16s.h new file mode 100644 index 00000000000..f7671a28c8d --- /dev/null +++ b/src/layer/riscv/convolution_winograd_transform_packn_fp16s.h @@ -0,0 +1,551 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void conv3x3s1_winograd64_transform_input_packn_fp16sa_rvv(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) +{ + const int packn = csrr_vlenb() / 2; + const word_type vl = vsetvl_e16m1(packn); + + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + + const int w_tiles = (w - 2) / 6; + const int h_tiles = (h - 2) / 6; + const int tiles = w_tiles * h_tiles; + + // const float itm[8][8] = { + // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, + // + // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, + // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, + // + // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, + // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, + // + // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, + // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, + // + // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} + // }; + + // 0 = r00 - r06 + (r04 - r02) * 5.25 + // 7 = r07 - r01 + (r03 - r05) * 5.25 + + // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) + // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) + + // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) + // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) + + // reuse r04 * 1.25 + // reuse r03 * 2.5 + // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) + // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < inch; q++) + { + const Mat img0 = bottom_blob.channel(q); + Mat img0_tm = bottom_blob_tm.channel(q); + + // NOTE c99 variable length array + __fp16 tmp[8][8][packn]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const __fp16* r0 = img0.row(i * 6) + (j * 6) * packn; + + for (int m = 0; m < 8; m++) + { + vfloat16m1_t _r00 = 
vle16_v_f16m1(r0, vl); + vfloat16m1_t _r01 = vle16_v_f16m1(r0 + packn, vl); + vfloat16m1_t _r02 = vle16_v_f16m1(r0 + packn * 2, vl); + vfloat16m1_t _r03 = vle16_v_f16m1(r0 + packn * 3, vl); + vfloat16m1_t _r04 = vle16_v_f16m1(r0 + packn * 4, vl); + vfloat16m1_t _r05 = vle16_v_f16m1(r0 + packn * 5, vl); + vfloat16m1_t _r06 = vle16_v_f16m1(r0 + packn * 6, vl); + vfloat16m1_t _r07 = vle16_v_f16m1(r0 + packn * 7, vl); + + vfloat16m1_t _tmp0m = vfmacc_vf_f16m1(vfsub_vv_f16m1(_r00, _r06, vl), 5.25f, vfsub_vv_f16m1(_r04, _r02, vl), vl); + vfloat16m1_t _tmp7m = vfmacc_vf_f16m1(vfsub_vv_f16m1(_r07, _r01, vl), 5.25f, vfsub_vv_f16m1(_r03, _r05, vl), vl); + vse16_v_f16m1(tmp[0][m], _tmp0m, vl); + vse16_v_f16m1(tmp[7][m], _tmp7m, vl); + + vfloat16m1_t _tmp12a = vfmacc_vf_f16m1(vfadd_vv_f16m1(_r02, _r06, vl), -4.25f, _r04, vl); + vfloat16m1_t _tmp12b = vfmacc_vf_f16m1(vfadd_vv_f16m1(_r01, _r05, vl), -4.25f, _r03, vl); + + vfloat16m1_t _tmp1m = vfadd_vv_f16m1(_tmp12a, _tmp12b, vl); + vfloat16m1_t _tmp2m = vfsub_vv_f16m1(_tmp12a, _tmp12b, vl); + vse16_v_f16m1(tmp[1][m], _tmp1m, vl); + vse16_v_f16m1(tmp[2][m], _tmp2m, vl); + + vfloat16m1_t _tmp34a = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_r06, 0.25f, _r02, vl), -1.25f, _r04, vl); + vfloat16m1_t _tmp34b = vfmacc_vf_f16m1(vfmacc_vf_f16m1(vfmul_vf_f16m1(_r01, 0.5f, vl), -2.5f, _r03, vl), 2.f, _r05, vl); + + vfloat16m1_t _tmp3m = vfadd_vv_f16m1(_tmp34a, _tmp34b, vl); + vfloat16m1_t _tmp4m = vfsub_vv_f16m1(_tmp34a, _tmp34b, vl); + vse16_v_f16m1(tmp[3][m], _tmp3m, vl); + vse16_v_f16m1(tmp[4][m], _tmp4m, vl); + + vfloat16m1_t _tmp56a = vfmacc_vf_f16m1(_r06, 4.f, vfmacc_vf_f16m1(_r02, -1.25f, _r04, vl), vl); + vfloat16m1_t _tmp56b = vfmacc_vf_f16m1(vfmacc_vf_f16m1(vfmul_vf_f16m1(_r01, 2.f, vl), -2.5f, _r03, vl), 0.5f, _r05, vl); + + vfloat16m1_t _tmp5m = vfadd_vv_f16m1(_tmp56a, _tmp56b, vl); + vfloat16m1_t _tmp6m = vfsub_vv_f16m1(_tmp56a, _tmp56b, vl); + vse16_v_f16m1(tmp[5][m], _tmp5m, vl); + vse16_v_f16m1(tmp[6][m], _tmp6m, vl); + + r0 += w * packn; + } + + __fp16* r0_tm_0 = (__fp16*)img0_tm + (i * w_tiles + j) * packn; + __fp16* r0_tm_1 = r0_tm_0 + tiles * packn; + __fp16* r0_tm_2 = r0_tm_0 + tiles * packn * 2; + __fp16* r0_tm_3 = r0_tm_0 + tiles * packn * 3; + __fp16* r0_tm_4 = r0_tm_0 + tiles * packn * 4; + __fp16* r0_tm_5 = r0_tm_0 + tiles * packn * 5; + __fp16* r0_tm_6 = r0_tm_0 + tiles * packn * 6; + __fp16* r0_tm_7 = r0_tm_0 + tiles * packn * 7; + + for (int m = 0; m < 8; m++) + { + vfloat16m1_t _tmp00 = vle16_v_f16m1(tmp[m][0], vl); + vfloat16m1_t _tmp01 = vle16_v_f16m1(tmp[m][1], vl); + vfloat16m1_t _tmp02 = vle16_v_f16m1(tmp[m][2], vl); + vfloat16m1_t _tmp03 = vle16_v_f16m1(tmp[m][3], vl); + vfloat16m1_t _tmp04 = vle16_v_f16m1(tmp[m][4], vl); + vfloat16m1_t _tmp05 = vle16_v_f16m1(tmp[m][5], vl); + vfloat16m1_t _tmp06 = vle16_v_f16m1(tmp[m][6], vl); + vfloat16m1_t _tmp07 = vle16_v_f16m1(tmp[m][7], vl); + + vfloat16m1_t _r0tm0 = vfmacc_vf_f16m1(vfsub_vv_f16m1(_tmp00, _tmp06, vl), 5.25f, vfsub_vv_f16m1(_tmp04, _tmp02, vl), vl); + vfloat16m1_t _r0tm7 = vfmacc_vf_f16m1(vfsub_vv_f16m1(_tmp07, _tmp01, vl), 5.25f, vfsub_vv_f16m1(_tmp03, _tmp05, vl), vl); + + vfloat16m1_t _tmp12a = vfmacc_vf_f16m1(vfadd_vv_f16m1(_tmp02, _tmp06, vl), -4.25f, _tmp04, vl); + vfloat16m1_t _tmp12b = vfmacc_vf_f16m1(vfadd_vv_f16m1(_tmp01, _tmp05, vl), -4.25f, _tmp03, vl); + + vfloat16m1_t _r0tm1 = vfadd_vv_f16m1(_tmp12a, _tmp12b, vl); + vfloat16m1_t _r0tm2 = vfsub_vv_f16m1(_tmp12a, _tmp12b, vl); + + vfloat16m1_t _tmp34a = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp06, 0.25f, _tmp02, vl), -1.25f, 
_tmp04, vl); + vfloat16m1_t _tmp34b = vfmacc_vf_f16m1(vfmacc_vf_f16m1(vfmul_vf_f16m1(_tmp01, 0.5f, vl), -2.5f, _tmp03, vl), 2.f, _tmp05, vl); + + vfloat16m1_t _r0tm3 = vfadd_vv_f16m1(_tmp34a, _tmp34b, vl); + vfloat16m1_t _r0tm4 = vfsub_vv_f16m1(_tmp34a, _tmp34b, vl); + + vfloat16m1_t _tmp56a = vfmacc_vf_f16m1(_tmp06, 4.f, vfmacc_vf_f16m1(_tmp02, -1.25f, _tmp04, vl), vl); + vfloat16m1_t _tmp56b = vfmacc_vf_f16m1(vfmacc_vf_f16m1(vfmul_vf_f16m1(_tmp01, 2.f, vl), -2.5f, _tmp03, vl), 0.5f, _tmp05, vl); + + vfloat16m1_t _r0tm5 = vfadd_vv_f16m1(_tmp56a, _tmp56b, vl); + vfloat16m1_t _r0tm6 = vfsub_vv_f16m1(_tmp56a, _tmp56b, vl); + + vse16_v_f16m1(r0_tm_0, _r0tm0, vl); + vse16_v_f16m1(r0_tm_1, _r0tm1, vl); + vse16_v_f16m1(r0_tm_2, _r0tm2, vl); + vse16_v_f16m1(r0_tm_3, _r0tm3, vl); + vse16_v_f16m1(r0_tm_4, _r0tm4, vl); + vse16_v_f16m1(r0_tm_5, _r0tm5, vl); + vse16_v_f16m1(r0_tm_6, _r0tm6, vl); + vse16_v_f16m1(r0_tm_7, _r0tm7, vl); + + r0_tm_0 += tiles * packn * 8; + r0_tm_1 += tiles * packn * 8; + r0_tm_2 += tiles * packn * 8; + r0_tm_3 += tiles * packn * 8; + r0_tm_4 += tiles * packn * 8; + r0_tm_5 += tiles * packn * 8; + r0_tm_6 += tiles * packn * 8; + r0_tm_7 += tiles * packn * 8; + } + } + } + } +} + +static void conv3x3s1_winograd64_transform_output_packn_fp16sa_rvv(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int packn = csrr_vlenb() / 2; + const word_type vl = vsetvl_e16m1(packn); + + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 6; + const int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; + + const __fp16* biasptr = bias; + + // const float otm[6][8] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} + // }; + + // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 + // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 + // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 + // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 + // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 + // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + vfloat16m1_t _bias0 = biasptr ? 
vle16_v_f16m1(biasptr + p * packn, vl) : vfmv_v_f_f16m1(0.f, vl); + + // NOTE c99 variable length array + __fp16 tmp[6][8][packn]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const __fp16* output0_tm_0 = (const __fp16*)out0_tm + (i * w_tiles + j) * packn; + const __fp16* output0_tm_1 = output0_tm_0 + tiles * packn; + const __fp16* output0_tm_2 = output0_tm_0 + tiles * packn * 2; + const __fp16* output0_tm_3 = output0_tm_0 + tiles * packn * 3; + const __fp16* output0_tm_4 = output0_tm_0 + tiles * packn * 4; + const __fp16* output0_tm_5 = output0_tm_0 + tiles * packn * 5; + const __fp16* output0_tm_6 = output0_tm_0 + tiles * packn * 6; + const __fp16* output0_tm_7 = output0_tm_0 + tiles * packn * 7; + + __fp16* output0 = out0.row<__fp16>(i * 6) + (j * 6) * packn; + + for (int m = 0; m < 8; m++) + { + vfloat16m1_t _out0tm0 = vle16_v_f16m1(output0_tm_0, vl); + vfloat16m1_t _out0tm1 = vle16_v_f16m1(output0_tm_1, vl); + vfloat16m1_t _out0tm2 = vle16_v_f16m1(output0_tm_2, vl); + vfloat16m1_t _out0tm3 = vle16_v_f16m1(output0_tm_3, vl); + vfloat16m1_t _out0tm4 = vle16_v_f16m1(output0_tm_4, vl); + vfloat16m1_t _out0tm5 = vle16_v_f16m1(output0_tm_5, vl); + vfloat16m1_t _out0tm6 = vle16_v_f16m1(output0_tm_6, vl); + vfloat16m1_t _out0tm7 = vle16_v_f16m1(output0_tm_7, vl); + + vfloat16m1_t _tmp024a = vfadd_vv_f16m1(_out0tm1, _out0tm2, vl); + vfloat16m1_t _tmp135a = vfsub_vv_f16m1(_out0tm1, _out0tm2, vl); + + vfloat16m1_t _tmp024b = vfadd_vv_f16m1(_out0tm3, _out0tm4, vl); + vfloat16m1_t _tmp135b = vfsub_vv_f16m1(_out0tm3, _out0tm4, vl); + + vfloat16m1_t _tmp024c = vfadd_vv_f16m1(_out0tm5, _out0tm6, vl); + vfloat16m1_t _tmp135c = vfsub_vv_f16m1(_out0tm5, _out0tm6, vl); + + vfloat16m1_t _tmp0m = vfadd_vv_f16m1(vfadd_vv_f16m1(_out0tm0, _tmp024a, vl), vfmacc_vf_f16m1(_tmp024b, 32.f, _tmp024c, vl), vl); + vfloat16m1_t _tmp2m = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp024a, 4.f, _tmp024b, vl), 8.f, _tmp024c, vl); + vfloat16m1_t _tmp4m = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp024a, 16.f, _tmp024b, vl), 2.f, _tmp024c, vl); + vse16_v_f16m1(tmp[0][m], _tmp0m, vl); + vse16_v_f16m1(tmp[2][m], _tmp2m, vl); + vse16_v_f16m1(tmp[4][m], _tmp4m, vl); + + vfloat16m1_t _tmp1m = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp135a, 2.f, _tmp135b, vl), 16.f, _tmp135c, vl); + vfloat16m1_t _tmp3m = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp135a, 8.f, _tmp135b, vl), 4.f, _tmp135c, vl); + vfloat16m1_t _tmp5m = vfadd_vv_f16m1(vfadd_vv_f16m1(_out0tm7, _tmp135a, vl), vfmacc_vf_f16m1(_tmp135c, 32.f, _tmp135b, vl), vl); + vse16_v_f16m1(tmp[1][m], _tmp1m, vl); + vse16_v_f16m1(tmp[3][m], _tmp3m, vl); + vse16_v_f16m1(tmp[5][m], _tmp5m, vl); + + output0_tm_0 += tiles * packn * 8; + output0_tm_1 += tiles * packn * 8; + output0_tm_2 += tiles * packn * 8; + output0_tm_3 += tiles * packn * 8; + output0_tm_4 += tiles * packn * 8; + output0_tm_5 += tiles * packn * 8; + output0_tm_6 += tiles * packn * 8; + output0_tm_7 += tiles * packn * 8; + } + + for (int m = 0; m < 6; m++) + { + vfloat16m1_t _tmp00 = vle16_v_f16m1(tmp[m][0], vl); + vfloat16m1_t _tmp01 = vle16_v_f16m1(tmp[m][1], vl); + vfloat16m1_t _tmp02 = vle16_v_f16m1(tmp[m][2], vl); + vfloat16m1_t _tmp03 = vle16_v_f16m1(tmp[m][3], vl); + vfloat16m1_t _tmp04 = vle16_v_f16m1(tmp[m][4], vl); + vfloat16m1_t _tmp05 = vle16_v_f16m1(tmp[m][5], vl); + vfloat16m1_t _tmp06 = vle16_v_f16m1(tmp[m][6], vl); + vfloat16m1_t _tmp07 = vle16_v_f16m1(tmp[m][7], vl); + + vfloat16m1_t _tmp024a = vfadd_vv_f16m1(_tmp01, _tmp02, vl); + vfloat16m1_t _tmp135a = vfsub_vv_f16m1(_tmp01, _tmp02, 
vl); + + vfloat16m1_t _tmp024b = vfadd_vv_f16m1(_tmp03, _tmp04, vl); + vfloat16m1_t _tmp135b = vfsub_vv_f16m1(_tmp03, _tmp04, vl); + + vfloat16m1_t _tmp024c = vfadd_vv_f16m1(_tmp05, _tmp06, vl); + vfloat16m1_t _tmp135c = vfsub_vv_f16m1(_tmp05, _tmp06, vl); + + vfloat16m1_t _out00 = vfadd_vv_f16m1(_bias0, vfadd_vv_f16m1(vfadd_vv_f16m1(_tmp00, _tmp024a, vl), vfmacc_vf_f16m1(_tmp024b, 32.f, _tmp024c, vl), vl), vl); + vfloat16m1_t _out02 = vfadd_vv_f16m1(_bias0, vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp024a, 4.f, _tmp024b, vl), 8.f, _tmp024c, vl), vl); + vfloat16m1_t _out04 = vfadd_vv_f16m1(_bias0, vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp024a, 16.f, _tmp024b, vl), 2.f, _tmp024c, vl), vl); + vse16_v_f16m1(output0, _out00, vl); + vse16_v_f16m1(output0 + packn * 2, _out02, vl); + vse16_v_f16m1(output0 + packn * 4, _out04, vl); + + vfloat16m1_t _out01 = vfadd_vv_f16m1(_bias0, vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp135a, 2.f, _tmp135b, vl), 16.f, _tmp135c, vl), vl); + vfloat16m1_t _out03 = vfadd_vv_f16m1(_bias0, vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp135a, 8.f, _tmp135b, vl), 4.f, _tmp135c, vl), vl); + vfloat16m1_t _out05 = vfadd_vv_f16m1(_bias0, vfadd_vv_f16m1(vfadd_vv_f16m1(_tmp07, _tmp135a, vl), vfmacc_vf_f16m1(_tmp135c, 32.f, _tmp135b, vl), vl), vl); + vse16_v_f16m1(output0 + packn, _out01, vl); + vse16_v_f16m1(output0 + packn * 3, _out03, vl); + vse16_v_f16m1(output0 + packn * 5, _out05, vl); + + output0 += outw * packn; + } + } + } + } +} + +static void conv3x3s1_winograd42_transform_input_packn_fp16sa_rvv(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) +{ + const int packn = csrr_vlenb() / 2; + const word_type vl = vsetvl_e16m1(packn); + + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + + const int w_tiles = (w - 2) / 4; + const int h_tiles = (h - 2) / 4; + const int tiles = w_tiles * h_tiles; + + // const float itm[6][6] = { + // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, + // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, + // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, + // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, + // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, + // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} + // }; + + // 0 = 4 * r00 - 5 * r02 + r04 + // 1 = -4 * (r01 + r02) + r04 + r03 + // 2 = 4 * (r01 - r02) + r04 - r03 + // 3 = -2 * (r01 - r03) + r04 - r02 + // 4 = 2 * (r01 - r03) + r04 - r02 + // 5 = 4 * r01 - 5 * r03 + r05 + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < inch; q++) + { + const Mat img0 = bottom_blob.channel(q); + Mat img0_tm = bottom_blob_tm.channel(q); + + // NOTE c99 variable length array + __fp16 tmp[6][6][packn]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const __fp16* r0 = img0.row(i * 4) + (j * 4) * packn; + + for (int m = 0; m < 6; m++) + { + vfloat16m1_t _r00 = vle16_v_f16m1(r0, vl); + vfloat16m1_t _r01 = vle16_v_f16m1(r0 + packn, vl); + vfloat16m1_t _r02 = vle16_v_f16m1(r0 + packn * 2, vl); + vfloat16m1_t _r03 = vle16_v_f16m1(r0 + packn * 3, vl); + vfloat16m1_t _r04 = vle16_v_f16m1(r0 + packn * 4, vl); + vfloat16m1_t _r05 = vle16_v_f16m1(r0 + packn * 5, vl); + + vfloat16m1_t _tmp0m = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_r04, 4.f, _r00, vl), -5.f, _r02, vl); + vfloat16m1_t _tmp1m = vfmacc_vf_f16m1(vfadd_vv_f16m1(_r04, _r03, vl), -4.f, vfadd_vv_f16m1(_r01, _r02, vl), vl); + vfloat16m1_t _tmp2m = vfmacc_vf_f16m1(vfsub_vv_f16m1(_r04, _r03, vl), 4.f, vfsub_vv_f16m1(_r01, _r02, vl), vl); + vfloat16m1_t _tmp3m = vfmacc_vf_f16m1(vfsub_vv_f16m1(_r04, _r02, vl), -2.f, 
vfsub_vv_f16m1(_r01, _r03, vl), vl); + vfloat16m1_t _tmp4m = vfmacc_vf_f16m1(vfsub_vv_f16m1(_r04, _r02, vl), 2.f, vfsub_vv_f16m1(_r01, _r03, vl), vl); + vfloat16m1_t _tmp5m = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_r05, 4.f, _r01, vl), -5.f, _r03, vl); + + vse16_v_f16m1(tmp[0][m], _tmp0m, vl); + vse16_v_f16m1(tmp[1][m], _tmp1m, vl); + vse16_v_f16m1(tmp[2][m], _tmp2m, vl); + vse16_v_f16m1(tmp[3][m], _tmp3m, vl); + vse16_v_f16m1(tmp[4][m], _tmp4m, vl); + vse16_v_f16m1(tmp[5][m], _tmp5m, vl); + + r0 += w * packn; + } + + __fp16* r0_tm_0 = (__fp16*)img0_tm + (i * w_tiles + j) * packn; + __fp16* r0_tm_1 = r0_tm_0 + tiles * packn; + __fp16* r0_tm_2 = r0_tm_0 + tiles * packn * 2; + __fp16* r0_tm_3 = r0_tm_0 + tiles * packn * 3; + __fp16* r0_tm_4 = r0_tm_0 + tiles * packn * 4; + __fp16* r0_tm_5 = r0_tm_0 + tiles * packn * 5; + + for (int m = 0; m < 6; m++) + { + vfloat16m1_t _tmp00 = vle16_v_f16m1(tmp[m][0], vl); + vfloat16m1_t _tmp01 = vle16_v_f16m1(tmp[m][1], vl); + vfloat16m1_t _tmp02 = vle16_v_f16m1(tmp[m][2], vl); + vfloat16m1_t _tmp03 = vle16_v_f16m1(tmp[m][3], vl); + vfloat16m1_t _tmp04 = vle16_v_f16m1(tmp[m][4], vl); + vfloat16m1_t _tmp05 = vle16_v_f16m1(tmp[m][5], vl); + + vfloat16m1_t _r0tm0 = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp04, 4.f, _tmp00, vl), -5.f, _tmp02, vl); + vfloat16m1_t _r0tm1 = vfmacc_vf_f16m1(vfadd_vv_f16m1(_tmp04, _tmp03, vl), -4.f, vfadd_vv_f16m1(_tmp01, _tmp02, vl), vl); + vfloat16m1_t _r0tm2 = vfmacc_vf_f16m1(vfsub_vv_f16m1(_tmp04, _tmp03, vl), 4.f, vfsub_vv_f16m1(_tmp01, _tmp02, vl), vl); + vfloat16m1_t _r0tm3 = vfmacc_vf_f16m1(vfsub_vv_f16m1(_tmp04, _tmp02, vl), -2.f, vfsub_vv_f16m1(_tmp01, _tmp03, vl), vl); + vfloat16m1_t _r0tm4 = vfmacc_vf_f16m1(vfsub_vv_f16m1(_tmp04, _tmp02, vl), 2.f, vfsub_vv_f16m1(_tmp01, _tmp03, vl), vl); + vfloat16m1_t _r0tm5 = vfmacc_vf_f16m1(vfmacc_vf_f16m1(_tmp05, 4.f, _tmp01, vl), -5.f, _tmp03, vl); + + vse16_v_f16m1(r0_tm_0, _r0tm0, vl); + vse16_v_f16m1(r0_tm_1, _r0tm1, vl); + vse16_v_f16m1(r0_tm_2, _r0tm2, vl); + vse16_v_f16m1(r0_tm_3, _r0tm3, vl); + vse16_v_f16m1(r0_tm_4, _r0tm4, vl); + vse16_v_f16m1(r0_tm_5, _r0tm5, vl); + + r0_tm_0 += tiles * packn * 6; + r0_tm_1 += tiles * packn * 6; + r0_tm_2 += tiles * packn * 6; + r0_tm_3 += tiles * packn * 6; + r0_tm_4 += tiles * packn * 6; + r0_tm_5 += tiles * packn * 6; + } + } + } + } +} + +static void conv3x3s1_winograd42_transform_output_packn_fp16sa_rvv(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int packn = csrr_vlenb() / 2; + const word_type vl = vsetvl_e16m1(packn); + + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 4; + const int h_tiles = outh / 4; + const int tiles = w_tiles * h_tiles; + + const __fp16* biasptr = bias; + + // const float otm[4][6] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} + // }; + + // 0 = r00 + (r01 + r02) + (r03 + r04) + // 1 = (r01 - r02) + (r03 - r04) * 2 + // 2 = (r01 + r02) + (r03 + r04) * 4 + // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + vfloat16m1_t _bias0 = biasptr ? 
vle16_v_f16m1(biasptr + p * packn, vl) : vfmv_v_f_f16m1(0.f, vl); + + // NOTE variable length array + __fp16 tmp[4][6][packn]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const __fp16* output0_tm_0 = (const __fp16*)out0_tm + (i * w_tiles + j) * packn; + const __fp16* output0_tm_1 = output0_tm_0 + tiles * packn; + const __fp16* output0_tm_2 = output0_tm_0 + tiles * packn * 2; + const __fp16* output0_tm_3 = output0_tm_0 + tiles * packn * 3; + const __fp16* output0_tm_4 = output0_tm_0 + tiles * packn * 4; + const __fp16* output0_tm_5 = output0_tm_0 + tiles * packn * 5; + + __fp16* output0 = out0.row<__fp16>(i * 4) + (j * 4) * packn; + + for (int m = 0; m < 6; m++) + { + vfloat16m1_t _out0tm0 = vle16_v_f16m1(output0_tm_0, vl); + vfloat16m1_t _out0tm1 = vle16_v_f16m1(output0_tm_1, vl); + vfloat16m1_t _out0tm2 = vle16_v_f16m1(output0_tm_2, vl); + vfloat16m1_t _out0tm3 = vle16_v_f16m1(output0_tm_3, vl); + vfloat16m1_t _out0tm4 = vle16_v_f16m1(output0_tm_4, vl); + vfloat16m1_t _out0tm5 = vle16_v_f16m1(output0_tm_5, vl); + + vfloat16m1_t _tmp02a = vfadd_vv_f16m1(_out0tm1, _out0tm2, vl); + vfloat16m1_t _tmp13a = vfsub_vv_f16m1(_out0tm1, _out0tm2, vl); + + vfloat16m1_t _tmp02b = vfadd_vv_f16m1(_out0tm3, _out0tm4, vl); + vfloat16m1_t _tmp13b = vfsub_vv_f16m1(_out0tm3, _out0tm4, vl); + + vfloat16m1_t _tmp0m = vfadd_vv_f16m1(vfadd_vv_f16m1(_out0tm0, _tmp02a, vl), _tmp02b, vl); + vfloat16m1_t _tmp1m = vfmacc_vf_f16m1(_tmp13a, 2.f, _tmp13b, vl); + vfloat16m1_t _tmp2m = vfmacc_vf_f16m1(_tmp02a, 4.f, _tmp02b, vl); + vfloat16m1_t _tmp3m = vfmacc_vf_f16m1(vfadd_vv_f16m1(_out0tm5, _tmp13a, vl), 8.f, _tmp13b, vl); + + vse16_v_f16m1(tmp[0][m], _tmp0m, vl); + vse16_v_f16m1(tmp[1][m], _tmp1m, vl); + vse16_v_f16m1(tmp[2][m], _tmp2m, vl); + vse16_v_f16m1(tmp[3][m], _tmp3m, vl); + + output0_tm_0 += tiles * packn * 6; + output0_tm_1 += tiles * packn * 6; + output0_tm_2 += tiles * packn * 6; + output0_tm_3 += tiles * packn * 6; + output0_tm_4 += tiles * packn * 6; + output0_tm_5 += tiles * packn * 6; + } + + for (int m = 0; m < 4; m++) + { + vfloat16m1_t _tmp00 = vle16_v_f16m1(tmp[m][0], vl); + vfloat16m1_t _tmp01 = vle16_v_f16m1(tmp[m][1], vl); + vfloat16m1_t _tmp02 = vle16_v_f16m1(tmp[m][2], vl); + vfloat16m1_t _tmp03 = vle16_v_f16m1(tmp[m][3], vl); + vfloat16m1_t _tmp04 = vle16_v_f16m1(tmp[m][4], vl); + vfloat16m1_t _tmp05 = vle16_v_f16m1(tmp[m][5], vl); + + vfloat16m1_t _tmp02a = vfadd_vv_f16m1(_tmp01, _tmp02, vl); + vfloat16m1_t _tmp13a = vfsub_vv_f16m1(_tmp01, _tmp02, vl); + + vfloat16m1_t _tmp02b = vfadd_vv_f16m1(_tmp03, _tmp04, vl); + vfloat16m1_t _tmp13b = vfsub_vv_f16m1(_tmp03, _tmp04, vl); + + vfloat16m1_t _out00 = vfadd_vv_f16m1(_bias0, vfadd_vv_f16m1(vfadd_vv_f16m1(_tmp00, _tmp02a, vl), _tmp02b, vl), vl); + vfloat16m1_t _out01 = vfadd_vv_f16m1(_bias0, vfmacc_vf_f16m1(_tmp13a, 2.f, _tmp13b, vl), vl); + vfloat16m1_t _out02 = vfadd_vv_f16m1(_bias0, vfmacc_vf_f16m1(_tmp02a, 4.f, _tmp02b, vl), vl); + vfloat16m1_t _out03 = vfadd_vv_f16m1(_bias0, vfmacc_vf_f16m1(vfadd_vv_f16m1(_tmp05, _tmp13a, vl), 8.f, _tmp13b, vl), vl); + + vse16_v_f16m1(output0, _out00, vl); + vse16_v_f16m1(output0 + packn, _out01, vl); + vse16_v_f16m1(output0 + packn * 2, _out02, vl); + vse16_v_f16m1(output0 + packn * 3, _out03, vl); + + output0 += outw * packn; + } + } + } + } +} diff --git a/src/layer/x86/convolution_3x3_pack16.h b/src/layer/x86/convolution_3x3_pack16.h index 2722aac7b4f..f65fae3f6b7 100644 --- a/src/layer/x86/convolution_3x3_pack16.h +++ 
b/src/layer/x86/convolution_3x3_pack16.h @@ -93,7 +93,7 @@ static void conv3x3s1_winograd64_transform_kernel_pack16_avx512(const Mat& kerne } } -static void conv3x3s1_winograd64_pack16_avx512(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd64_pack16_avx512(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -114,174 +114,19 @@ static void conv3x3s1_winograd64_pack16_avx512(const Mat& bottom_blob, Mat& top_ h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const float* bias = _bias; // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - - const int tiles = w_tm / 8 * h_tm / 8; + int w_tiles = outw / 6; + int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); - - // const float itm[8][8] = { - // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, - // - // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, - // - // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, - // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, - // - // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, - // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, - // - // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} - // }; - - // 0 = r00 - r06 + (r04 - r02) * 5.25 - // 7 = r07 - r01 + (r03 - r05) * 5.25 - - // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) - // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) - - // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) - // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) - - // reuse r04 * 1.25 - // reuse r03 * 2.5 - // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) - // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - -#ifdef _MSC_VER - __declspec(align(64)) -#else - __attribute__((aligned(64))) -#endif - float tmp[8][8][16]; - - // tile - for (int i = 0; i < h_tm / 8; i++) - { - for (int j = 0; j < w_tm / 8; j++) - { - const float* r0 = img0.row(i * 6) + (j * 6) * 16; - - for (int m = 0; m < 8; m++) - { - __m512 _r00 = _mm512_load_ps(r0); - __m512 _r01 = _mm512_load_ps(r0 + 16); - __m512 _r02 = _mm512_load_ps(r0 + 16 * 2); - __m512 _r03 = _mm512_load_ps(r0 + 16 * 3); - __m512 _r04 = _mm512_load_ps(r0 + 16 * 4); - __m512 _r05 = _mm512_load_ps(r0 + 16 * 5); - __m512 _r06 = _mm512_load_ps(r0 + 16 * 6); - __m512 _r07 = _mm512_load_ps(r0 + 16 * 7); - - __m512 _tmp0m = _mm512_fmadd_ps(_mm512_set1_ps(5.25f), _mm512_sub_ps(_r04, _r02), _mm512_sub_ps(_r00, _r06)); - __m512 _tmp7m = _mm512_fmadd_ps(_mm512_set1_ps(5.25f), _mm512_sub_ps(_r03, _r05), _mm512_sub_ps(_r07, _r01)); - _mm512_store_ps(tmp[0][m], _tmp0m); - _mm512_store_ps(tmp[7][m], _tmp7m); - - __m512 _tmp12a = _mm512_fmadd_ps(_mm512_set1_ps(-4.25f), _r04, _mm512_add_ps(_r02, _r06)); - __m512 _tmp12b = _mm512_fmadd_ps(_mm512_set1_ps(-4.25f), _r03, _mm512_add_ps(_r01, _r05)); - - __m512 _tmp1m = _mm512_add_ps(_tmp12a, 
_tmp12b); - __m512 _tmp2m = _mm512_sub_ps(_tmp12a, _tmp12b); - _mm512_store_ps(tmp[1][m], _tmp1m); - _mm512_store_ps(tmp[2][m], _tmp2m); - - __m512 _tmp34a = _mm512_fmadd_ps(_mm512_set1_ps(-1.25f), _r04, _mm512_fmadd_ps(_mm512_set1_ps(0.25f), _r02, _r06)); - __m512 _tmp34b = _mm512_fmadd_ps(_mm512_set1_ps(2.f), _r05, _mm512_fmadd_ps(_mm512_set1_ps(-2.5f), _r03, _mm512_mul_ps(_r01, _mm512_set1_ps(0.5f)))); - - __m512 _tmp3m = _mm512_add_ps(_tmp34a, _tmp34b); - __m512 _tmp4m = _mm512_sub_ps(_tmp34a, _tmp34b); - _mm512_store_ps(tmp[3][m], _tmp3m); - _mm512_store_ps(tmp[4][m], _tmp4m); - - __m512 _tmp56a = _mm512_fmadd_ps(_mm512_set1_ps(4.f), _mm512_fmadd_ps(_mm512_set1_ps(-1.25f), _r04, _r02), _r06); - __m512 _tmp56b = _mm512_fmadd_ps(_mm512_set1_ps(0.5f), _r05, _mm512_fmadd_ps(_mm512_set1_ps(-2.5f), _r03, _mm512_mul_ps(_r01, _mm512_set1_ps(2.f)))); - - __m512 _tmp5m = _mm512_add_ps(_tmp56a, _tmp56b); - __m512 _tmp6m = _mm512_sub_ps(_tmp56a, _tmp56b); - _mm512_store_ps(tmp[5][m], _tmp5m); - _mm512_store_ps(tmp[6][m], _tmp6m); - - r0 += w * 16; - } - - float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 8 + j) * 16; - float* r0_tm_1 = r0_tm_0 + tiles * 16; - float* r0_tm_2 = r0_tm_0 + tiles * 16 * 2; - float* r0_tm_3 = r0_tm_0 + tiles * 16 * 3; - float* r0_tm_4 = r0_tm_0 + tiles * 16 * 4; - float* r0_tm_5 = r0_tm_0 + tiles * 16 * 5; - float* r0_tm_6 = r0_tm_0 + tiles * 16 * 6; - float* r0_tm_7 = r0_tm_0 + tiles * 16 * 7; - - for (int m = 0; m < 8; m++) - { - __m512 _tmp00 = _mm512_load_ps(tmp[m][0]); - __m512 _tmp01 = _mm512_load_ps(tmp[m][1]); - __m512 _tmp02 = _mm512_load_ps(tmp[m][2]); - __m512 _tmp03 = _mm512_load_ps(tmp[m][3]); - __m512 _tmp04 = _mm512_load_ps(tmp[m][4]); - __m512 _tmp05 = _mm512_load_ps(tmp[m][5]); - __m512 _tmp06 = _mm512_load_ps(tmp[m][6]); - __m512 _tmp07 = _mm512_load_ps(tmp[m][7]); - - __m512 _r0tm0 = _mm512_fmadd_ps(_mm512_set1_ps(5.25f), _mm512_sub_ps(_tmp04, _tmp02), _mm512_sub_ps(_tmp00, _tmp06)); - __m512 _r0tm7 = _mm512_fmadd_ps(_mm512_set1_ps(5.25f), _mm512_sub_ps(_tmp03, _tmp05), _mm512_sub_ps(_tmp07, _tmp01)); - - __m512 _tmp12a = _mm512_fmadd_ps(_mm512_set1_ps(-4.25f), _tmp04, _mm512_add_ps(_tmp02, _tmp06)); - __m512 _tmp12b = _mm512_fmadd_ps(_mm512_set1_ps(-4.25f), _tmp03, _mm512_add_ps(_tmp01, _tmp05)); - - __m512 _r0tm1 = _mm512_add_ps(_tmp12a, _tmp12b); - __m512 _r0tm2 = _mm512_sub_ps(_tmp12a, _tmp12b); - - __m512 _tmp34a = _mm512_fmadd_ps(_mm512_set1_ps(-1.25f), _tmp04, _mm512_fmadd_ps(_mm512_set1_ps(0.25f), _tmp02, _tmp06)); - __m512 _tmp34b = _mm512_fmadd_ps(_mm512_set1_ps(2.f), _tmp05, _mm512_fmadd_ps(_mm512_set1_ps(-2.5f), _tmp03, _mm512_mul_ps(_tmp01, _mm512_set1_ps(0.5f)))); - - __m512 _r0tm3 = _mm512_add_ps(_tmp34a, _tmp34b); - __m512 _r0tm4 = _mm512_sub_ps(_tmp34a, _tmp34b); - - __m512 _tmp56a = _mm512_fmadd_ps(_mm512_set1_ps(4.f), _mm512_fmadd_ps(_mm512_set1_ps(-1.25f), _tmp04, _tmp02), _tmp06); - __m512 _tmp56b = _mm512_fmadd_ps(_mm512_set1_ps(0.5f), _tmp05, _mm512_fmadd_ps(_mm512_set1_ps(-2.5f), _tmp03, _mm512_mul_ps(_tmp01, _mm512_set1_ps(2.f)))); - - __m512 _r0tm5 = _mm512_add_ps(_tmp56a, _tmp56b); - __m512 _r0tm6 = _mm512_sub_ps(_tmp56a, _tmp56b); - - _mm512_store_ps(r0_tm_0, _r0tm0); - _mm512_store_ps(r0_tm_1, _r0tm1); - _mm512_store_ps(r0_tm_2, _r0tm2); - _mm512_store_ps(r0_tm_3, _r0tm3); - _mm512_store_ps(r0_tm_4, _r0tm4); - _mm512_store_ps(r0_tm_5, _r0tm5); - _mm512_store_ps(r0_tm_6, _r0tm6); - _mm512_store_ps(r0_tm_7, _r0tm7); - - r0_tm_0 += tiles * 128; - r0_tm_1 += tiles * 128; - r0_tm_2 += tiles * 128; - r0_tm_3 += tiles * 128; - 
r0_tm_4 += tiles * 128; - r0_tm_5 += tiles * 128; - r0_tm_6 += tiles * 128; - r0_tm_7 += tiles * 128; - } - } - } - } + conv3x3s1_winograd64_transform_input_pack16_avx512(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input + // BEGIN dot Mat top_blob_tm; { @@ -809,143 +654,7 @@ static void conv3x3s1_winograd64_pack16_avx512(const Mat& bottom_blob, Mat& top_ top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator); } { - // const float otm[6][8] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} - // }; - - // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 - // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 - // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 - // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 - // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 - // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) - - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - const int tiles = w_tm / 8 * h_tm / 8; - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - __m512 _bias0 = bias ? _mm512_loadu_ps((const float*)bias + p * 16) : _mm512_setzero_ps(); - -#ifdef _MSC_VER - __declspec(align(64)) -#else - __attribute__((aligned(64))) -#endif - float tmp[6][8][16]; - - // tile - for (int i = 0; i < outh / 6; i++) - { - for (int j = 0; j < outw / 6; j++) - { - // top_blob_tm.create(tiles, 64, outch, elemsize, elempack); - - const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 8 + j) * 16; - const float* output0_tm_1 = output0_tm_0 + tiles * 16; - const float* output0_tm_2 = output0_tm_0 + tiles * 16 * 2; - const float* output0_tm_3 = output0_tm_0 + tiles * 16 * 3; - const float* output0_tm_4 = output0_tm_0 + tiles * 16 * 4; - const float* output0_tm_5 = output0_tm_0 + tiles * 16 * 5; - const float* output0_tm_6 = output0_tm_0 + tiles * 16 * 6; - const float* output0_tm_7 = output0_tm_0 + tiles * 16 * 7; - - float* output0 = out0.row(i * 6) + (j * 6) * 16; - - // TODO neon optimize - for (int m = 0; m < 8; m++) - { - __m512 _out0tm0 = _mm512_load_ps(output0_tm_0); - __m512 _out0tm1 = _mm512_load_ps(output0_tm_1); - __m512 _out0tm2 = _mm512_load_ps(output0_tm_2); - __m512 _out0tm3 = _mm512_load_ps(output0_tm_3); - __m512 _out0tm4 = _mm512_load_ps(output0_tm_4); - __m512 _out0tm5 = _mm512_load_ps(output0_tm_5); - __m512 _out0tm6 = _mm512_load_ps(output0_tm_6); - __m512 _out0tm7 = _mm512_load_ps(output0_tm_7); - - __m512 _tmp024a = _mm512_add_ps(_out0tm1, _out0tm2); - __m512 _tmp135a = _mm512_sub_ps(_out0tm1, _out0tm2); - - __m512 _tmp024b = _mm512_add_ps(_out0tm3, _out0tm4); - __m512 _tmp135b = _mm512_sub_ps(_out0tm3, _out0tm4); - - __m512 _tmp024c = _mm512_add_ps(_out0tm5, _out0tm6); - __m512 _tmp135c = _mm512_sub_ps(_out0tm5, _out0tm6); - - __m512 _tmp0m = _mm512_add_ps(_mm512_add_ps(_out0tm0, _tmp024a), _mm512_fmadd_ps(_mm512_set1_ps(32.f), _tmp024c, _tmp024b)); - __m512 _tmp2m = _mm512_fmadd_ps(_mm512_set1_ps(8.f), _tmp024c, _mm512_fmadd_ps(_mm512_set1_ps(4.f), _tmp024b, _tmp024a)); - __m512 _tmp4m = _mm512_fmadd_ps(_mm512_set1_ps(2.f), 
_tmp024c, _mm512_fmadd_ps(_mm512_set1_ps(16.f), _tmp024b, _tmp024a)); - _mm512_store_ps(tmp[0][m], _tmp0m); - _mm512_store_ps(tmp[2][m], _tmp2m); - _mm512_store_ps(tmp[4][m], _tmp4m); - - __m512 _tmp1m = _mm512_fmadd_ps(_mm512_set1_ps(16.f), _tmp135c, _mm512_fmadd_ps(_mm512_set1_ps(2.f), _tmp135b, _tmp135a)); - __m512 _tmp3m = _mm512_fmadd_ps(_mm512_set1_ps(4.f), _tmp135c, _mm512_fmadd_ps(_mm512_set1_ps(8.f), _tmp135b, _tmp135a)); - __m512 _tmp5m = _mm512_add_ps(_mm512_add_ps(_out0tm7, _tmp135a), _mm512_fmadd_ps(_mm512_set1_ps(32.f), _tmp135b, _tmp135c)); - _mm512_store_ps(tmp[1][m], _tmp1m); - _mm512_store_ps(tmp[3][m], _tmp3m); - _mm512_store_ps(tmp[5][m], _tmp5m); - - output0_tm_0 += tiles * 128; - output0_tm_1 += tiles * 128; - output0_tm_2 += tiles * 128; - output0_tm_3 += tiles * 128; - output0_tm_4 += tiles * 128; - output0_tm_5 += tiles * 128; - output0_tm_6 += tiles * 128; - output0_tm_7 += tiles * 128; - } - - for (int m = 0; m < 6; m++) - { - __m512 _tmp00 = _mm512_load_ps(tmp[m][0]); - __m512 _tmp01 = _mm512_load_ps(tmp[m][1]); - __m512 _tmp02 = _mm512_load_ps(tmp[m][2]); - __m512 _tmp03 = _mm512_load_ps(tmp[m][3]); - __m512 _tmp04 = _mm512_load_ps(tmp[m][4]); - __m512 _tmp05 = _mm512_load_ps(tmp[m][5]); - __m512 _tmp06 = _mm512_load_ps(tmp[m][6]); - __m512 _tmp07 = _mm512_load_ps(tmp[m][7]); - - __m512 _tmp024a = _mm512_add_ps(_tmp01, _tmp02); - __m512 _tmp135a = _mm512_sub_ps(_tmp01, _tmp02); - - __m512 _tmp024b = _mm512_add_ps(_tmp03, _tmp04); - __m512 _tmp135b = _mm512_sub_ps(_tmp03, _tmp04); - - __m512 _tmp024c = _mm512_add_ps(_tmp05, _tmp06); - __m512 _tmp135c = _mm512_sub_ps(_tmp05, _tmp06); - - __m512 _out00 = _mm512_add_ps(_bias0, _mm512_add_ps(_mm512_add_ps(_tmp00, _tmp024a), _mm512_fmadd_ps(_mm512_set1_ps(32.f), _tmp024c, _tmp024b))); - __m512 _out02 = _mm512_add_ps(_bias0, _mm512_fmadd_ps(_mm512_set1_ps(8.f), _tmp024c, _mm512_fmadd_ps(_mm512_set1_ps(4.f), _tmp024b, _tmp024a))); - __m512 _out04 = _mm512_add_ps(_bias0, _mm512_fmadd_ps(_mm512_set1_ps(2.f), _tmp024c, _mm512_fmadd_ps(_mm512_set1_ps(16.f), _tmp024b, _tmp024a))); - _mm512_store_ps(output0, _out00); - _mm512_store_ps(output0 + 32, _out02); - _mm512_store_ps(output0 + 64, _out04); - - __m512 _out01 = _mm512_add_ps(_bias0, _mm512_fmadd_ps(_mm512_set1_ps(16.f), _tmp135c, _mm512_fmadd_ps(_mm512_set1_ps(2.f), _tmp135b, _tmp135a))); - __m512 _out03 = _mm512_add_ps(_bias0, _mm512_fmadd_ps(_mm512_set1_ps(4.f), _tmp135c, _mm512_fmadd_ps(_mm512_set1_ps(8.f), _tmp135b, _tmp135a))); - __m512 _out05 = _mm512_add_ps(_bias0, _mm512_add_ps(_mm512_add_ps(_tmp07, _tmp135a), _mm512_fmadd_ps(_mm512_set1_ps(32.f), _tmp135b, _tmp135c))); - _mm512_store_ps(output0 + 16, _out01); - _mm512_store_ps(output0 + 48, _out03); - _mm512_store_ps(output0 + 80, _out05); - - output0 += outw * 16; - } - } - } - } + conv3x3s1_winograd64_transform_output_pack16_avx512(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output @@ -1031,7 +740,7 @@ static void conv3x3s1_winograd42_transform_kernel_pack16_avx512(const Mat& kerne } } -static void conv3x3s1_winograd42_pack16_avx512(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd42_pack16_avx512(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -1053,120 +762,15 @@ static void conv3x3s1_winograd42_pack16_avx512(const Mat& bottom_blob, Mat& top_ h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, 
h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const float* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = w_tm / 6 * h_tm / 6; - - bottom_blob_tm.create(tiles, 36, inch, 4u * elempack, elempack, opt.workspace_allocator); - - // const float itm[4][4] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r04 + r03 - // 2 = 4 * (r01 - r02) + r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - -#ifdef _MSC_VER - __declspec(align(64)) -#else - __attribute__((aligned(64))) -#endif - float tmp[6][6][16]; - - // tile - for (int i = 0; i < h_tm / 6; i++) - { - for (int j = 0; j < w_tm / 6; j++) - { - const float* r0 = img0.row(i * 4) + (j * 4) * 16; - - for (int m = 0; m < 6; m++) - { - __m512 _r00 = _mm512_load_ps(r0); - __m512 _r01 = _mm512_load_ps(r0 + 16); - __m512 _r02 = _mm512_load_ps(r0 + 16 * 2); - __m512 _r03 = _mm512_load_ps(r0 + 16 * 3); - __m512 _r04 = _mm512_load_ps(r0 + 16 * 4); - __m512 _r05 = _mm512_load_ps(r0 + 16 * 5); - - __m512 _tmp0m = _mm512_fmadd_ps(_mm512_set1_ps(-5.f), _r02, _mm512_fmadd_ps(_mm512_set1_ps(4.f), _r00, _r04)); - __m512 _tmp1m = _mm512_fmadd_ps(_mm512_set1_ps(-4.f), _mm512_add_ps(_r01, _r02), _mm512_add_ps(_r04, _r03)); - __m512 _tmp2m = _mm512_fmadd_ps(_mm512_set1_ps(4.f), _mm512_sub_ps(_r01, _r02), _mm512_sub_ps(_r04, _r03)); - __m512 _tmp3m = _mm512_fmadd_ps(_mm512_set1_ps(-2.f), _mm512_sub_ps(_r01, _r03), _mm512_sub_ps(_r04, _r02)); - __m512 _tmp4m = _mm512_fmadd_ps(_mm512_set1_ps(2.f), _mm512_sub_ps(_r01, _r03), _mm512_sub_ps(_r04, _r02)); - __m512 _tmp5m = _mm512_fmadd_ps(_mm512_set1_ps(-5.f), _r03, _mm512_fmadd_ps(_mm512_set1_ps(4.f), _r01, _r05)); - - _mm512_store_ps(tmp[0][m], _tmp0m); - _mm512_store_ps(tmp[1][m], _tmp1m); - _mm512_store_ps(tmp[2][m], _tmp2m); - _mm512_store_ps(tmp[3][m], _tmp3m); - _mm512_store_ps(tmp[4][m], _tmp4m); - _mm512_store_ps(tmp[5][m], _tmp5m); - - r0 += w * 16; - } + int w_tiles = outw / 4; + int h_tiles = outh / 4; + const int tiles = w_tiles * h_tiles; - float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 6 + j) * 16; - float* r0_tm_1 = r0_tm_0 + tiles * 16; - float* r0_tm_2 = r0_tm_0 + tiles * 16 * 2; - float* r0_tm_3 = r0_tm_0 + tiles * 16 * 3; - float* r0_tm_4 = r0_tm_0 + tiles * 16 * 4; - float* r0_tm_5 = r0_tm_0 + tiles * 16 * 5; - - for (int m = 0; m < 6; m++) - { - __m512 _tmp00 = _mm512_load_ps(tmp[m][0]); - __m512 _tmp01 = _mm512_load_ps(tmp[m][1]); - __m512 _tmp02 = _mm512_load_ps(tmp[m][2]); - __m512 _tmp03 = _mm512_load_ps(tmp[m][3]); - __m512 _tmp04 = _mm512_load_ps(tmp[m][4]); - __m512 _tmp05 = _mm512_load_ps(tmp[m][5]); - - __m512 _r0tm0 = _mm512_fmadd_ps(_mm512_set1_ps(-5.f), _tmp02, _mm512_fmadd_ps(_mm512_set1_ps(4.f), _tmp00, _tmp04)); - __m512 _r0tm1 = _mm512_fmadd_ps(_mm512_set1_ps(-4.f), _mm512_add_ps(_tmp01, _tmp02), _mm512_add_ps(_tmp04, _tmp03)); - __m512 _r0tm2 = _mm512_fmadd_ps(_mm512_set1_ps(4.f), _mm512_sub_ps(_tmp01, _tmp02), _mm512_sub_ps(_tmp04, 
_tmp03)); - __m512 _r0tm3 = _mm512_fmadd_ps(_mm512_set1_ps(-2.f), _mm512_sub_ps(_tmp01, _tmp03), _mm512_sub_ps(_tmp04, _tmp02)); - __m512 _r0tm4 = _mm512_fmadd_ps(_mm512_set1_ps(2.f), _mm512_sub_ps(_tmp01, _tmp03), _mm512_sub_ps(_tmp04, _tmp02)); - __m512 _r0tm5 = _mm512_fmadd_ps(_mm512_set1_ps(-5.f), _tmp03, _mm512_fmadd_ps(_mm512_set1_ps(4.f), _tmp01, _tmp05)); - - _mm512_store_ps(r0_tm_0, _r0tm0); - _mm512_store_ps(r0_tm_1, _r0tm1); - _mm512_store_ps(r0_tm_2, _r0tm2); - _mm512_store_ps(r0_tm_3, _r0tm3); - _mm512_store_ps(r0_tm_4, _r0tm4); - _mm512_store_ps(r0_tm_5, _r0tm5); - - r0_tm_0 += tiles * 96; - r0_tm_1 += tiles * 96; - r0_tm_2 += tiles * 96; - r0_tm_3 += tiles * 96; - r0_tm_4 += tiles * 96; - r0_tm_5 += tiles * 96; - } - } - } - } + bottom_blob_tm.create(tiles, 36, inch, elemsize, elempack, opt.workspace_allocator); + conv3x3s1_winograd42_transform_input_pack16_avx512(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -1697,118 +1301,7 @@ static void conv3x3s1_winograd42_pack16_avx512(const Mat& bottom_blob, Mat& top_ top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator); } { - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - const int tiles = w_tm / 6 * h_tm / 6; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - __m512 _bias0 = bias ? 
_mm512_loadu_ps((const float*)bias + p * 16) : _mm512_setzero_ps(); - -#ifdef _MSC_VER - __declspec(align(64)) -#else - __attribute__((aligned(64))) -#endif - float tmp[4][6][16]; - - // tile - for (int i = 0; i < outh / 4; i++) - { - for (int j = 0; j < outw / 4; j++) - { - // top_blob_tm.create(tiles, 36, outch, elemsize, elempack); - - const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 6 + j) * 16; - const float* output0_tm_1 = output0_tm_0 + tiles * 16; - const float* output0_tm_2 = output0_tm_0 + tiles * 16 * 2; - const float* output0_tm_3 = output0_tm_0 + tiles * 16 * 3; - const float* output0_tm_4 = output0_tm_0 + tiles * 16 * 4; - const float* output0_tm_5 = output0_tm_0 + tiles * 16 * 5; - - float* output0 = out0.row(i * 4) + (j * 4) * 16; - - // TODO msa optimize - for (int m = 0; m < 6; m++) - { - __m512 _out0tm0 = _mm512_load_ps(output0_tm_0); - __m512 _out0tm1 = _mm512_load_ps(output0_tm_1); - __m512 _out0tm2 = _mm512_load_ps(output0_tm_2); - __m512 _out0tm3 = _mm512_load_ps(output0_tm_3); - __m512 _out0tm4 = _mm512_load_ps(output0_tm_4); - __m512 _out0tm5 = _mm512_load_ps(output0_tm_5); - - __m512 _tmp02a = _mm512_add_ps(_out0tm1, _out0tm2); - __m512 _tmp13a = _mm512_sub_ps(_out0tm1, _out0tm2); - - __m512 _tmp02b = _mm512_add_ps(_out0tm3, _out0tm4); - __m512 _tmp13b = _mm512_sub_ps(_out0tm3, _out0tm4); - - __m512 _tmp0m = _mm512_add_ps(_mm512_add_ps(_out0tm0, _tmp02a), _tmp02b); - __m512 _tmp1m = _mm512_fmadd_ps(_mm512_set1_ps(2.f), _tmp13b, _tmp13a); - __m512 _tmp2m = _mm512_fmadd_ps(_mm512_set1_ps(4.f), _tmp02b, _tmp02a); - __m512 _tmp3m = _mm512_fmadd_ps(_mm512_set1_ps(8.f), _tmp13b, _mm512_add_ps(_out0tm5, _tmp13a)); - - _mm512_store_ps(tmp[0][m], _tmp0m); - _mm512_store_ps(tmp[1][m], _tmp1m); - _mm512_store_ps(tmp[2][m], _tmp2m); - _mm512_store_ps(tmp[3][m], _tmp3m); - - output0_tm_0 += tiles * 96; - output0_tm_1 += tiles * 96; - output0_tm_2 += tiles * 96; - output0_tm_3 += tiles * 96; - output0_tm_4 += tiles * 96; - output0_tm_5 += tiles * 96; - } - - for (int m = 0; m < 4; m++) - { - __m512 _tmp00 = _mm512_load_ps(tmp[m][0]); - __m512 _tmp01 = _mm512_load_ps(tmp[m][1]); - __m512 _tmp02 = _mm512_load_ps(tmp[m][2]); - __m512 _tmp03 = _mm512_load_ps(tmp[m][3]); - __m512 _tmp04 = _mm512_load_ps(tmp[m][4]); - __m512 _tmp05 = _mm512_load_ps(tmp[m][5]); - - __m512 _tmp02a = _mm512_add_ps(_tmp01, _tmp02); - __m512 _tmp13a = _mm512_sub_ps(_tmp01, _tmp02); - - __m512 _tmp02b = _mm512_add_ps(_tmp03, _tmp04); - __m512 _tmp13b = _mm512_sub_ps(_tmp03, _tmp04); - - __m512 _out00 = _mm512_add_ps(_bias0, _mm512_add_ps(_mm512_add_ps(_tmp00, _tmp02a), _tmp02b)); - __m512 _out01 = _mm512_add_ps(_bias0, _mm512_fmadd_ps(_mm512_set1_ps(2.f), _tmp13b, _tmp13a)); - __m512 _out02 = _mm512_add_ps(_bias0, _mm512_fmadd_ps(_mm512_set1_ps(4.f), _tmp02b, _tmp02a)); - __m512 _out03 = _mm512_add_ps(_bias0, _mm512_fmadd_ps(_mm512_set1_ps(8.f), _tmp13b, _mm512_add_ps(_tmp05, _tmp13a))); - - _mm512_store_ps(output0, _out00); - _mm512_store_ps(output0 + 16, _out01); - _mm512_store_ps(output0 + 16 * 2, _out02); - _mm512_store_ps(output0 + 16 * 3, _out03); - - output0 += outw * 16; - } - } - } - } + conv3x3s1_winograd42_transform_output_pack16_avx512(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output diff --git a/src/layer/x86/convolution_3x3_pack4.h b/src/layer/x86/convolution_3x3_pack4.h index 96922c42d3e..4a21cac4a34 100644 --- a/src/layer/x86/convolution_3x3_pack4.h +++ b/src/layer/x86/convolution_3x3_pack4.h @@ -93,7 +93,7 @@ static void 
conv3x3s1_winograd64_transform_kernel_pack4_sse(const Mat& kernel, M } } -static void conv3x3s1_winograd64_pack4_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd64_pack4_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -115,182 +115,15 @@ static void conv3x3s1_winograd64_pack4_sse(const Mat& bottom_blob, Mat& top_blob h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const float* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - - const int tiles = w_tm / 8 * h_tm / 8; + int w_tiles = outw / 6; + int h_tiles = outh / 6; + int tiles = w_tiles * h_tiles; - // bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); - bottom_blob_tm.create(tiles, 64, inch, 4u * elempack, elempack, opt.workspace_allocator); - - // const float itm[8][8] = { - // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, - // - // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, - // - // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, - // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, - // - // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, - // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, - // - // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} - // }; - - // 0 = r00 - r06 + (r04 - r02) * 5.25 - // 7 = r07 - r01 + (r03 - r05) * 5.25 - - // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) - // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) - - // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) - // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) - - // reuse r04 * 1.25 - // reuse r03 * 2.5 - // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) - // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - -#ifdef _MSC_VER - __declspec(align(16)) -#else - __attribute__((aligned(16))) -#endif - float tmp[8][8][4]; - - __m128 _v5_25 = _mm_set1_ps(5.25f); - __m128 _vm4_25 = _mm_set1_ps(-4.25f); - __m128 _vm1_25 = _mm_set1_ps(-1.25f); - __m128 _v0_25 = _mm_set1_ps(0.25f); - __m128 _vm2_5 = _mm_set1_ps(-2.5f); - __m128 _v0_5 = _mm_set1_ps(0.5f); - __m128 _v2 = _mm_set1_ps(2.f); - __m128 _v4 = _mm_set1_ps(4.f); - - // tile - for (int i = 0; i < h_tm / 8; i++) - { - for (int j = 0; j < w_tm / 8; j++) - { - const float* r0 = img0.row(i * 6) + (j * 6) * 4; - - for (int m = 0; m < 8; m++) - { - __m128 _r00 = _mm_load_ps(r0); - __m128 _r01 = _mm_load_ps(r0 + 4); - __m128 _r02 = _mm_load_ps(r0 + 4 * 2); - __m128 _r03 = _mm_load_ps(r0 + 4 * 3); - __m128 _r04 = _mm_load_ps(r0 + 4 * 4); - __m128 _r05 = _mm_load_ps(r0 + 4 * 5); - __m128 _r06 = _mm_load_ps(r0 + 4 * 6); - __m128 _r07 = _mm_load_ps(r0 + 4 * 7); - - __m128 _tmp0m = _mm_comp_fmadd_ps(_v5_25, _mm_sub_ps(_r04, _r02), _mm_sub_ps(_r00, _r06)); - __m128 _tmp7m = _mm_comp_fmadd_ps(_v5_25, _mm_sub_ps(_r03, _r05), _mm_sub_ps(_r07, _r01)); - _mm_store_ps(tmp[0][m], _tmp0m); - _mm_store_ps(tmp[7][m], _tmp7m); - 
- __m128 _tmp12a = _mm_comp_fmadd_ps(_vm4_25, _r04, _mm_add_ps(_r02, _r06)); - __m128 _tmp12b = _mm_comp_fmadd_ps(_vm4_25, _r03, _mm_add_ps(_r01, _r05)); - - __m128 _tmp1m = _mm_add_ps(_tmp12a, _tmp12b); - __m128 _tmp2m = _mm_sub_ps(_tmp12a, _tmp12b); - _mm_store_ps(tmp[1][m], _tmp1m); - _mm_store_ps(tmp[2][m], _tmp2m); - - __m128 _tmp34a = _mm_comp_fmadd_ps(_vm1_25, _r04, _mm_comp_fmadd_ps(_v0_25, _r02, _r06)); - __m128 _tmp34b = _mm_comp_fmadd_ps(_v2, _r05, _mm_comp_fmadd_ps(_vm2_5, _r03, _mm_mul_ps(_r01, _v0_5))); - - __m128 _tmp3m = _mm_add_ps(_tmp34a, _tmp34b); - __m128 _tmp4m = _mm_sub_ps(_tmp34a, _tmp34b); - _mm_store_ps(tmp[3][m], _tmp3m); - _mm_store_ps(tmp[4][m], _tmp4m); - - __m128 _tmp56a = _mm_comp_fmadd_ps(_v4, _mm_comp_fmadd_ps(_vm1_25, _r04, _r02), _r06); - __m128 _tmp56b = _mm_comp_fmadd_ps(_v0_5, _r05, _mm_comp_fmadd_ps(_vm2_5, _r03, _mm_mul_ps(_r01, _v2))); - - __m128 _tmp5m = _mm_add_ps(_tmp56a, _tmp56b); - __m128 _tmp6m = _mm_sub_ps(_tmp56a, _tmp56b); - _mm_store_ps(tmp[5][m], _tmp5m); - _mm_store_ps(tmp[6][m], _tmp6m); - - r0 += w * 4; - } - - float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 8 + j) * 4; - float* r0_tm_1 = r0_tm_0 + tiles * 4; - float* r0_tm_2 = r0_tm_0 + tiles * 4 * 2; - float* r0_tm_3 = r0_tm_0 + tiles * 4 * 3; - float* r0_tm_4 = r0_tm_0 + tiles * 4 * 4; - float* r0_tm_5 = r0_tm_0 + tiles * 4 * 5; - float* r0_tm_6 = r0_tm_0 + tiles * 4 * 6; - float* r0_tm_7 = r0_tm_0 + tiles * 4 * 7; - - for (int m = 0; m < 8; m++) - { - __m128 _tmp00 = _mm_load_ps(tmp[m][0]); - __m128 _tmp01 = _mm_load_ps(tmp[m][1]); - __m128 _tmp02 = _mm_load_ps(tmp[m][2]); - __m128 _tmp03 = _mm_load_ps(tmp[m][3]); - __m128 _tmp04 = _mm_load_ps(tmp[m][4]); - __m128 _tmp05 = _mm_load_ps(tmp[m][5]); - __m128 _tmp06 = _mm_load_ps(tmp[m][6]); - __m128 _tmp07 = _mm_load_ps(tmp[m][7]); - - __m128 _r0tm0 = _mm_comp_fmadd_ps(_v5_25, _mm_sub_ps(_tmp04, _tmp02), _mm_sub_ps(_tmp00, _tmp06)); - __m128 _r0tm7 = _mm_comp_fmadd_ps(_v5_25, _mm_sub_ps(_tmp03, _tmp05), _mm_sub_ps(_tmp07, _tmp01)); - - __m128 _tmp12a = _mm_comp_fmadd_ps(_vm4_25, _tmp04, _mm_add_ps(_tmp02, _tmp06)); - __m128 _tmp12b = _mm_comp_fmadd_ps(_vm4_25, _tmp03, _mm_add_ps(_tmp01, _tmp05)); - - __m128 _r0tm1 = _mm_add_ps(_tmp12a, _tmp12b); - __m128 _r0tm2 = _mm_sub_ps(_tmp12a, _tmp12b); - - __m128 _tmp34a = _mm_comp_fmadd_ps(_vm1_25, _tmp04, _mm_comp_fmadd_ps(_v0_25, _tmp02, _tmp06)); - __m128 _tmp34b = _mm_comp_fmadd_ps(_v2, _tmp05, _mm_comp_fmadd_ps(_vm2_5, _tmp03, _mm_mul_ps(_tmp01, _v0_5))); - - __m128 _r0tm3 = _mm_add_ps(_tmp34a, _tmp34b); - __m128 _r0tm4 = _mm_sub_ps(_tmp34a, _tmp34b); - - __m128 _tmp56a = _mm_comp_fmadd_ps(_v4, _mm_comp_fmadd_ps(_vm1_25, _tmp04, _tmp02), _tmp06); - __m128 _tmp56b = _mm_comp_fmadd_ps(_v0_5, _tmp05, _mm_comp_fmadd_ps(_vm2_5, _tmp03, _mm_mul_ps(_tmp01, _v2))); - - __m128 _r0tm5 = _mm_add_ps(_tmp56a, _tmp56b); - __m128 _r0tm6 = _mm_sub_ps(_tmp56a, _tmp56b); - - _mm_store_ps(r0_tm_0, _r0tm0); - _mm_store_ps(r0_tm_1, _r0tm1); - _mm_store_ps(r0_tm_2, _r0tm2); - _mm_store_ps(r0_tm_3, _r0tm3); - _mm_store_ps(r0_tm_4, _r0tm4); - _mm_store_ps(r0_tm_5, _r0tm5); - _mm_store_ps(r0_tm_6, _r0tm6); - _mm_store_ps(r0_tm_7, _r0tm7); - - r0_tm_0 += tiles * 4 * 8; - r0_tm_1 += tiles * 4 * 8; - r0_tm_2 += tiles * 4 * 8; - r0_tm_3 += tiles * 4 * 8; - r0_tm_4 += tiles * 4 * 8; - r0_tm_5 += tiles * 4 * 8; - r0_tm_6 += tiles * 4 * 8; - r0_tm_7 += tiles * 4 * 8; - } - } - } - } + bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); + 
conv3x3s1_winograd64_transform_input_pack4_sse(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -720,150 +553,7 @@ static void conv3x3s1_winograd64_pack4_sse(const Mat& bottom_blob, Mat& top_blob top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator); } { - // const float otm[6][8] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} - // }; - - // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 - // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 - // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 - // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 - // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 - // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) - - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - const int tiles = w_tm / 8 * h_tm / 8; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - __m128 _bias0 = bias ? _mm_loadu_ps((const float*)bias + p * 4) : _mm_setzero_ps(); - -#ifdef _MSC_VER - __declspec(align(16)) -#else - __attribute__((aligned(16))) -#endif - float tmp[6][8][4]; - - __m128 _v32 = _mm_set1_ps(32.f); - __m128 _v16 = _mm_set1_ps(16.f); - __m128 _v8 = _mm_set1_ps(8.f); - __m128 _v4 = _mm_set1_ps(4.f); - __m128 _v2 = _mm_set1_ps(2.f); - - // tile - for (int i = 0; i < outh / 6; i++) - { - for (int j = 0; j < outw / 6; j++) - { - // top_blob_tm.create(tiles, 64, outch, elemsize, elempack); - - const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 8 + j) * 4; - const float* output0_tm_1 = output0_tm_0 + tiles * 4; - const float* output0_tm_2 = output0_tm_0 + tiles * 4 * 2; - const float* output0_tm_3 = output0_tm_0 + tiles * 4 * 3; - const float* output0_tm_4 = output0_tm_0 + tiles * 4 * 4; - const float* output0_tm_5 = output0_tm_0 + tiles * 4 * 5; - const float* output0_tm_6 = output0_tm_0 + tiles * 4 * 6; - const float* output0_tm_7 = output0_tm_0 + tiles * 4 * 7; - - float* output0 = out0.row(i * 6) + (j * 6) * 4; - - // TODO msa optimize - for (int m = 0; m < 8; m++) - { - __m128 _out0tm0 = _mm_load_ps(output0_tm_0); - __m128 _out0tm1 = _mm_load_ps(output0_tm_1); - __m128 _out0tm2 = _mm_load_ps(output0_tm_2); - __m128 _out0tm3 = _mm_load_ps(output0_tm_3); - __m128 _out0tm4 = _mm_load_ps(output0_tm_4); - __m128 _out0tm5 = _mm_load_ps(output0_tm_5); - __m128 _out0tm6 = _mm_load_ps(output0_tm_6); - __m128 _out0tm7 = _mm_load_ps(output0_tm_7); - - __m128 _tmp024a = _mm_add_ps(_out0tm1, _out0tm2); - __m128 _tmp135a = _mm_sub_ps(_out0tm1, _out0tm2); - - __m128 _tmp024b = _mm_add_ps(_out0tm3, _out0tm4); - __m128 _tmp135b = _mm_sub_ps(_out0tm3, _out0tm4); - - __m128 _tmp024c = _mm_add_ps(_out0tm5, _out0tm6); - __m128 _tmp135c = _mm_sub_ps(_out0tm5, _out0tm6); - - __m128 _tmp0m = _mm_add_ps(_mm_add_ps(_out0tm0, _tmp024a), _mm_comp_fmadd_ps(_v32, _tmp024c, _tmp024b)); - __m128 _tmp2m = _mm_comp_fmadd_ps(_v8, _tmp024c, _mm_comp_fmadd_ps(_v4, _tmp024b, _tmp024a)); - __m128 _tmp4m = _mm_comp_fmadd_ps(_v2, _tmp024c, _mm_comp_fmadd_ps(_v16, _tmp024b, _tmp024a)); - _mm_store_ps(tmp[0][m], _tmp0m); - 
_mm_store_ps(tmp[2][m], _tmp2m); - _mm_store_ps(tmp[4][m], _tmp4m); - - __m128 _tmp1m = _mm_comp_fmadd_ps(_v16, _tmp135c, _mm_comp_fmadd_ps(_v2, _tmp135b, _tmp135a)); - __m128 _tmp3m = _mm_comp_fmadd_ps(_v4, _tmp135c, _mm_comp_fmadd_ps(_v8, _tmp135b, _tmp135a)); - __m128 _tmp5m = _mm_add_ps(_mm_add_ps(_out0tm7, _tmp135a), _mm_comp_fmadd_ps(_v32, _tmp135b, _tmp135c)); - _mm_store_ps(tmp[1][m], _tmp1m); - _mm_store_ps(tmp[3][m], _tmp3m); - _mm_store_ps(tmp[5][m], _tmp5m); - - output0_tm_0 += tiles * 4 * 8; - output0_tm_1 += tiles * 4 * 8; - output0_tm_2 += tiles * 4 * 8; - output0_tm_3 += tiles * 4 * 8; - output0_tm_4 += tiles * 4 * 8; - output0_tm_5 += tiles * 4 * 8; - output0_tm_6 += tiles * 4 * 8; - output0_tm_7 += tiles * 4 * 8; - } - - for (int m = 0; m < 6; m++) - { - __m128 _tmp00 = _mm_load_ps(tmp[m][0]); - __m128 _tmp01 = _mm_load_ps(tmp[m][1]); - __m128 _tmp02 = _mm_load_ps(tmp[m][2]); - __m128 _tmp03 = _mm_load_ps(tmp[m][3]); - __m128 _tmp04 = _mm_load_ps(tmp[m][4]); - __m128 _tmp05 = _mm_load_ps(tmp[m][5]); - __m128 _tmp06 = _mm_load_ps(tmp[m][6]); - __m128 _tmp07 = _mm_load_ps(tmp[m][7]); - - __m128 _tmp024a = _mm_add_ps(_tmp01, _tmp02); - __m128 _tmp135a = _mm_sub_ps(_tmp01, _tmp02); - - __m128 _tmp024b = _mm_add_ps(_tmp03, _tmp04); - __m128 _tmp135b = _mm_sub_ps(_tmp03, _tmp04); - - __m128 _tmp024c = _mm_add_ps(_tmp05, _tmp06); - __m128 _tmp135c = _mm_sub_ps(_tmp05, _tmp06); - - __m128 _out00 = _mm_add_ps(_bias0, _mm_add_ps(_mm_add_ps(_tmp00, _tmp024a), _mm_comp_fmadd_ps(_v32, _tmp024c, _tmp024b))); - __m128 _out02 = _mm_add_ps(_bias0, _mm_comp_fmadd_ps(_v8, _tmp024c, _mm_comp_fmadd_ps(_v4, _tmp024b, _tmp024a))); - __m128 _out04 = _mm_add_ps(_bias0, _mm_comp_fmadd_ps(_v2, _tmp024c, _mm_comp_fmadd_ps(_v16, _tmp024b, _tmp024a))); - _mm_store_ps(output0, _out00); - _mm_store_ps(output0 + 4 * 2, _out02); - _mm_store_ps(output0 + 4 * 4, _out04); - - __m128 _out01 = _mm_add_ps(_bias0, _mm_comp_fmadd_ps(_v16, _tmp135c, _mm_comp_fmadd_ps(_v2, _tmp135b, _tmp135a))); - __m128 _out03 = _mm_add_ps(_bias0, _mm_comp_fmadd_ps(_v4, _tmp135c, _mm_comp_fmadd_ps(_v8, _tmp135b, _tmp135a))); - __m128 _out05 = _mm_add_ps(_bias0, _mm_add_ps(_mm_add_ps(_tmp07, _tmp135a), _mm_comp_fmadd_ps(_v32, _tmp135b, _tmp135c))); - _mm_store_ps(output0 + 4, _out01); - _mm_store_ps(output0 + 4 * 3, _out03); - _mm_store_ps(output0 + 4 * 5, _out05); - - output0 += outw * 4; - } - } - } - } + conv3x3s1_winograd64_transform_output_pack4_sse(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output @@ -949,7 +639,7 @@ static void conv3x3s1_winograd42_transform_kernel_pack4_sse(const Mat& kernel, M } } -static void conv3x3s1_winograd42_pack4_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd42_pack4_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -971,126 +661,15 @@ static void conv3x3s1_winograd42_pack4_sse(const Mat& bottom_blob, Mat& top_blob h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const float* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - - const int tiles = w_tm / 6 * h_tm / 6; - - bottom_blob_tm.create(tiles, 36, inch, 4u * elempack, elempack, opt.workspace_allocator); - - // const float itm[6][6] = { - // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, - // 
{0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, - // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, - // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, - // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} - // }; - - // 0 = 4 * r00 - 5 * r02 + r04 - // 1 = -4 * (r01 + r02) + r04 + r03 - // 2 = 4 * (r01 - r02) + r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - -#ifdef _MSC_VER - __declspec(align(16)) -#else - __attribute__((aligned(16))) -#endif - float tmp[6][6][4]; - - __m128 _vm5 = _mm_set1_ps(-5.f); - __m128 _vm4 = _mm_set1_ps(-4.f); - __m128 _v4 = _mm_set1_ps(4.f); - __m128 _vm2 = _mm_set1_ps(-2.f); - __m128 _v2 = _mm_set1_ps(2.f); - - // tile - for (int i = 0; i < h_tm / 6; i++) - { - for (int j = 0; j < w_tm / 6; j++) - { - const float* r0 = img0.row(i * 4) + (j * 4) * 4; - - for (int m = 0; m < 6; m++) - { - __m128 _r00 = _mm_load_ps(r0); - __m128 _r01 = _mm_load_ps(r0 + 4); - __m128 _r02 = _mm_load_ps(r0 + 4 * 2); - __m128 _r03 = _mm_load_ps(r0 + 4 * 3); - __m128 _r04 = _mm_load_ps(r0 + 4 * 4); - __m128 _r05 = _mm_load_ps(r0 + 4 * 5); - - __m128 _tmp0m = _mm_comp_fmadd_ps(_vm5, _r02, _mm_comp_fmadd_ps(_v4, _r00, _r04)); - __m128 _tmp1m = _mm_comp_fmadd_ps(_vm4, _mm_add_ps(_r01, _r02), _mm_add_ps(_r04, _r03)); - __m128 _tmp2m = _mm_comp_fmadd_ps(_v4, _mm_sub_ps(_r01, _r02), _mm_sub_ps(_r04, _r03)); - __m128 _tmp3m = _mm_comp_fmadd_ps(_vm2, _mm_sub_ps(_r01, _r03), _mm_sub_ps(_r04, _r02)); - __m128 _tmp4m = _mm_comp_fmadd_ps(_v2, _mm_sub_ps(_r01, _r03), _mm_sub_ps(_r04, _r02)); - __m128 _tmp5m = _mm_comp_fmadd_ps(_vm5, _r03, _mm_comp_fmadd_ps(_v4, _r01, _r05)); - - _mm_store_ps(tmp[0][m], _tmp0m); - _mm_store_ps(tmp[1][m], _tmp1m); - _mm_store_ps(tmp[2][m], _tmp2m); - _mm_store_ps(tmp[3][m], _tmp3m); - _mm_store_ps(tmp[4][m], _tmp4m); - _mm_store_ps(tmp[5][m], _tmp5m); - - r0 += w * 4; - } - - float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 6 + j) * 4; - float* r0_tm_1 = r0_tm_0 + tiles * 4; - float* r0_tm_2 = r0_tm_0 + tiles * 4 * 2; - float* r0_tm_3 = r0_tm_0 + tiles * 4 * 3; - float* r0_tm_4 = r0_tm_0 + tiles * 4 * 4; - float* r0_tm_5 = r0_tm_0 + tiles * 4 * 5; + int w_tiles = outw / 4; + int h_tiles = outh / 4; + int tiles = w_tiles * h_tiles; - for (int m = 0; m < 6; m++) - { - __m128 _tmp00 = _mm_load_ps(tmp[m][0]); - __m128 _tmp01 = _mm_load_ps(tmp[m][1]); - __m128 _tmp02 = _mm_load_ps(tmp[m][2]); - __m128 _tmp03 = _mm_load_ps(tmp[m][3]); - __m128 _tmp04 = _mm_load_ps(tmp[m][4]); - __m128 _tmp05 = _mm_load_ps(tmp[m][5]); - - __m128 _r0tm0 = _mm_comp_fmadd_ps(_vm5, _tmp02, _mm_comp_fmadd_ps(_v4, _tmp00, _tmp04)); - __m128 _r0tm1 = _mm_comp_fmadd_ps(_vm4, _mm_add_ps(_tmp01, _tmp02), _mm_add_ps(_tmp04, _tmp03)); - __m128 _r0tm2 = _mm_comp_fmadd_ps(_v4, _mm_sub_ps(_tmp01, _tmp02), _mm_sub_ps(_tmp04, _tmp03)); - __m128 _r0tm3 = _mm_comp_fmadd_ps(_vm2, _mm_sub_ps(_tmp01, _tmp03), _mm_sub_ps(_tmp04, _tmp02)); - __m128 _r0tm4 = _mm_comp_fmadd_ps(_v2, _mm_sub_ps(_tmp01, _tmp03), _mm_sub_ps(_tmp04, _tmp02)); - __m128 _r0tm5 = _mm_comp_fmadd_ps(_vm5, _tmp03, _mm_comp_fmadd_ps(_v4, _tmp01, _tmp05)); - - _mm_store_ps(r0_tm_0, _r0tm0); - _mm_store_ps(r0_tm_1, _r0tm1); - _mm_store_ps(r0_tm_2, _r0tm2); - _mm_store_ps(r0_tm_3, _r0tm3); - _mm_store_ps(r0_tm_4, _r0tm4); - _mm_store_ps(r0_tm_5, _r0tm5); - - r0_tm_0 += tiles * 4 * 6; - 
r0_tm_1 += tiles * 4 * 6; - r0_tm_2 += tiles * 4 * 6; - r0_tm_3 += tiles * 4 * 6; - r0_tm_4 += tiles * 4 * 6; - r0_tm_5 += tiles * 4 * 6; - } - } - } - } + bottom_blob_tm.create(tiles, 36, inch, elemsize, elempack, opt.workspace_allocator); + conv3x3s1_winograd42_transform_input_pack4_sse(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -1520,122 +1099,7 @@ static void conv3x3s1_winograd42_pack4_sse(const Mat& bottom_blob, Mat& top_blob top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator); } { - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - const int tiles = w_tm / 6 * h_tm / 6; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - __m128 _bias0 = bias ? _mm_loadu_ps((const float*)bias + p * 4) : _mm_setzero_ps(); - -#ifdef _MSC_VER - __declspec(align(16)) -#else - __attribute__((aligned(16))) -#endif - float tmp[4][6][4]; - - __m128 _v2 = _mm_set1_ps(2.f); - __m128 _v4 = _mm_set1_ps(4.f); - __m128 _v8 = _mm_set1_ps(8.f); - - // tile - for (int i = 0; i < outh / 4; i++) - { - for (int j = 0; j < outw / 4; j++) - { - // top_blob_tm.create(tiles, 36, outch, elemsize, elempack); - - const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 6 + j) * 4; - const float* output0_tm_1 = output0_tm_0 + tiles * 4; - const float* output0_tm_2 = output0_tm_0 + tiles * 4 * 2; - const float* output0_tm_3 = output0_tm_0 + tiles * 4 * 3; - const float* output0_tm_4 = output0_tm_0 + tiles * 4 * 4; - const float* output0_tm_5 = output0_tm_0 + tiles * 4 * 5; - - float* output0 = out0.row(i * 4) + (j * 4) * 4; - - // TODO msa optimize - for (int m = 0; m < 6; m++) - { - __m128 _out0tm0 = _mm_load_ps(output0_tm_0); - __m128 _out0tm1 = _mm_load_ps(output0_tm_1); - __m128 _out0tm2 = _mm_load_ps(output0_tm_2); - __m128 _out0tm3 = _mm_load_ps(output0_tm_3); - __m128 _out0tm4 = _mm_load_ps(output0_tm_4); - __m128 _out0tm5 = _mm_load_ps(output0_tm_5); - - __m128 _tmp02a = _mm_add_ps(_out0tm1, _out0tm2); - __m128 _tmp13a = _mm_sub_ps(_out0tm1, _out0tm2); - - __m128 _tmp02b = _mm_add_ps(_out0tm3, _out0tm4); - __m128 _tmp13b = _mm_sub_ps(_out0tm3, _out0tm4); - - __m128 _tmp0m = _mm_add_ps(_mm_add_ps(_out0tm0, _tmp02a), _tmp02b); - __m128 _tmp1m = _mm_comp_fmadd_ps(_v2, _tmp13b, _tmp13a); - __m128 _tmp2m = _mm_comp_fmadd_ps(_v4, _tmp02b, _tmp02a); - __m128 _tmp3m = _mm_comp_fmadd_ps(_v8, _tmp13b, _mm_add_ps(_out0tm5, _tmp13a)); - - _mm_store_ps(tmp[0][m], _tmp0m); - _mm_store_ps(tmp[1][m], _tmp1m); - _mm_store_ps(tmp[2][m], _tmp2m); - _mm_store_ps(tmp[3][m], _tmp3m); - - output0_tm_0 += tiles * 4 * 6; - output0_tm_1 += tiles * 4 * 6; - output0_tm_2 += tiles * 4 * 6; - output0_tm_3 += tiles * 4 * 6; - output0_tm_4 += tiles * 4 * 6; - output0_tm_5 += tiles * 4 * 6; - } - - for (int m = 0; m < 4; m++) - { - __m128 _tmp00 = _mm_load_ps(tmp[m][0]); - __m128 _tmp01 = _mm_load_ps(tmp[m][1]); - __m128 _tmp02 = _mm_load_ps(tmp[m][2]); - __m128 _tmp03 = _mm_load_ps(tmp[m][3]); 
- __m128 _tmp04 = _mm_load_ps(tmp[m][4]); - __m128 _tmp05 = _mm_load_ps(tmp[m][5]); - - __m128 _tmp02a = _mm_add_ps(_tmp01, _tmp02); - __m128 _tmp13a = _mm_sub_ps(_tmp01, _tmp02); - - __m128 _tmp02b = _mm_add_ps(_tmp03, _tmp04); - __m128 _tmp13b = _mm_sub_ps(_tmp03, _tmp04); - - __m128 _out00 = _mm_add_ps(_bias0, _mm_add_ps(_mm_add_ps(_tmp00, _tmp02a), _tmp02b)); - __m128 _out01 = _mm_add_ps(_bias0, _mm_comp_fmadd_ps(_v2, _tmp13b, _tmp13a)); - __m128 _out02 = _mm_add_ps(_bias0, _mm_comp_fmadd_ps(_v4, _tmp02b, _tmp02a)); - __m128 _out03 = _mm_add_ps(_bias0, _mm_comp_fmadd_ps(_v8, _tmp13b, _mm_add_ps(_tmp05, _tmp13a))); - - _mm_store_ps(output0, _out00); - _mm_store_ps(output0 + 4, _out01); - _mm_store_ps(output0 + 4 * 2, _out02); - _mm_store_ps(output0 + 4 * 3, _out03); - - output0 += outw * 4; - } - } - } - } + conv3x3s1_winograd42_transform_output_pack4_sse(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output diff --git a/src/layer/x86/convolution_3x3_pack4to1.h b/src/layer/x86/convolution_3x3_pack4to1.h index 46cbb576f3b..62bf8eb9a7b 100644 --- a/src/layer/x86/convolution_3x3_pack4to1.h +++ b/src/layer/x86/convolution_3x3_pack4to1.h @@ -157,7 +157,7 @@ static void conv3x3s1_winograd64_transform_kernel_pack4to1_sse(const Mat& kernel } } -static void conv3x3s1_winograd64_pack4to1_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd64_pack4to1_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -179,8 +179,6 @@ static void conv3x3s1_winograd64_pack4to1_sse(const Mat& bottom_blob, Mat& top_b h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const float* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { @@ -190,170 +188,7 @@ static void conv3x3s1_winograd64_pack4to1_sse(const Mat& bottom_blob, Mat& top_b const int tiles = w_tm / 8 * h_tm / 8; bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); - - // const float itm[8][8] = { - // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, - // - // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, - // - // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, - // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, - // - // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, - // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, - // - // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} - // }; - - // 0 = r00 - r06 + (r04 - r02) * 5.25 - // 7 = r07 - r01 + (r03 - r05) * 5.25 - - // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) - // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) - - // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) - // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) - - // reuse r04 * 1.25 - // reuse r03 * 2.5 - // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) - // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - -#ifdef _MSC_VER - __declspec(align(16)) -#else - __attribute__((aligned(16))) -#endif - float 
tmp[8][8][4]; - - __m128 _v5_25 = _mm_set1_ps(5.25f); - __m128 _vm4_25 = _mm_set1_ps(-4.25f); - __m128 _vm1_25 = _mm_set1_ps(-1.25f); - __m128 _v0_25 = _mm_set1_ps(0.25f); - __m128 _vm2_5 = _mm_set1_ps(-2.5f); - __m128 _v0_5 = _mm_set1_ps(0.5f); - __m128 _v2 = _mm_set1_ps(2.f); - __m128 _v4 = _mm_set1_ps(4.f); - - // tile - for (int i = 0; i < h_tm / 8; i++) - { - for (int j = 0; j < w_tm / 8; j++) - { - const float* r0 = img0.row(i * 6) + (j * 6) * 4; - - for (int m = 0; m < 8; m++) - { - __m128 _r00 = _mm_load_ps(r0); - __m128 _r01 = _mm_load_ps(r0 + 4); - __m128 _r02 = _mm_load_ps(r0 + 4 * 2); - __m128 _r03 = _mm_load_ps(r0 + 4 * 3); - __m128 _r04 = _mm_load_ps(r0 + 4 * 4); - __m128 _r05 = _mm_load_ps(r0 + 4 * 5); - __m128 _r06 = _mm_load_ps(r0 + 4 * 6); - __m128 _r07 = _mm_load_ps(r0 + 4 * 7); - - __m128 _tmp0m = _mm_comp_fmadd_ps(_v5_25, _mm_sub_ps(_r04, _r02), _mm_sub_ps(_r00, _r06)); - __m128 _tmp7m = _mm_comp_fmadd_ps(_v5_25, _mm_sub_ps(_r03, _r05), _mm_sub_ps(_r07, _r01)); - _mm_store_ps(tmp[0][m], _tmp0m); - _mm_store_ps(tmp[7][m], _tmp7m); - - __m128 _tmp12a = _mm_comp_fmadd_ps(_vm4_25, _r04, _mm_add_ps(_r02, _r06)); - __m128 _tmp12b = _mm_comp_fmadd_ps(_vm4_25, _r03, _mm_add_ps(_r01, _r05)); - - __m128 _tmp1m = _mm_add_ps(_tmp12a, _tmp12b); - __m128 _tmp2m = _mm_sub_ps(_tmp12a, _tmp12b); - _mm_store_ps(tmp[1][m], _tmp1m); - _mm_store_ps(tmp[2][m], _tmp2m); - - __m128 _tmp34a = _mm_comp_fmadd_ps(_vm1_25, _r04, _mm_comp_fmadd_ps(_v0_25, _r02, _r06)); - __m128 _tmp34b = _mm_comp_fmadd_ps(_v2, _r05, _mm_comp_fmadd_ps(_vm2_5, _r03, _mm_mul_ps(_r01, _v0_5))); - - __m128 _tmp3m = _mm_add_ps(_tmp34a, _tmp34b); - __m128 _tmp4m = _mm_sub_ps(_tmp34a, _tmp34b); - _mm_store_ps(tmp[3][m], _tmp3m); - _mm_store_ps(tmp[4][m], _tmp4m); - - __m128 _tmp56a = _mm_comp_fmadd_ps(_v4, _mm_comp_fmadd_ps(_vm1_25, _r04, _r02), _r06); - __m128 _tmp56b = _mm_comp_fmadd_ps(_v0_5, _r05, _mm_comp_fmadd_ps(_vm2_5, _r03, _mm_mul_ps(_r01, _v2))); - - __m128 _tmp5m = _mm_add_ps(_tmp56a, _tmp56b); - __m128 _tmp6m = _mm_sub_ps(_tmp56a, _tmp56b); - _mm_store_ps(tmp[5][m], _tmp5m); - _mm_store_ps(tmp[6][m], _tmp6m); - - r0 += w * 4; - } - - float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 8 + j) * 4; - float* r0_tm_1 = r0_tm_0 + tiles * 4; - float* r0_tm_2 = r0_tm_0 + tiles * 4 * 2; - float* r0_tm_3 = r0_tm_0 + tiles * 4 * 3; - float* r0_tm_4 = r0_tm_0 + tiles * 4 * 4; - float* r0_tm_5 = r0_tm_0 + tiles * 4 * 5; - float* r0_tm_6 = r0_tm_0 + tiles * 4 * 6; - float* r0_tm_7 = r0_tm_0 + tiles * 4 * 7; - - for (int m = 0; m < 8; m++) - { - __m128 _tmp00 = _mm_load_ps(tmp[m][0]); - __m128 _tmp01 = _mm_load_ps(tmp[m][1]); - __m128 _tmp02 = _mm_load_ps(tmp[m][2]); - __m128 _tmp03 = _mm_load_ps(tmp[m][3]); - __m128 _tmp04 = _mm_load_ps(tmp[m][4]); - __m128 _tmp05 = _mm_load_ps(tmp[m][5]); - __m128 _tmp06 = _mm_load_ps(tmp[m][6]); - __m128 _tmp07 = _mm_load_ps(tmp[m][7]); - - __m128 _r0tm0 = _mm_comp_fmadd_ps(_v5_25, _mm_sub_ps(_tmp04, _tmp02), _mm_sub_ps(_tmp00, _tmp06)); - __m128 _r0tm7 = _mm_comp_fmadd_ps(_v5_25, _mm_sub_ps(_tmp03, _tmp05), _mm_sub_ps(_tmp07, _tmp01)); - - __m128 _tmp12a = _mm_comp_fmadd_ps(_vm4_25, _tmp04, _mm_add_ps(_tmp02, _tmp06)); - __m128 _tmp12b = _mm_comp_fmadd_ps(_vm4_25, _tmp03, _mm_add_ps(_tmp01, _tmp05)); - - __m128 _r0tm1 = _mm_add_ps(_tmp12a, _tmp12b); - __m128 _r0tm2 = _mm_sub_ps(_tmp12a, _tmp12b); - - __m128 _tmp34a = _mm_comp_fmadd_ps(_vm1_25, _tmp04, _mm_comp_fmadd_ps(_v0_25, _tmp02, _tmp06)); - __m128 _tmp34b = _mm_comp_fmadd_ps(_v2, _tmp05, _mm_comp_fmadd_ps(_vm2_5, _tmp03, 
_mm_mul_ps(_tmp01, _v0_5))); - - __m128 _r0tm3 = _mm_add_ps(_tmp34a, _tmp34b); - __m128 _r0tm4 = _mm_sub_ps(_tmp34a, _tmp34b); - - __m128 _tmp56a = _mm_comp_fmadd_ps(_v4, _mm_comp_fmadd_ps(_vm1_25, _tmp04, _tmp02), _tmp06); - __m128 _tmp56b = _mm_comp_fmadd_ps(_v0_5, _tmp05, _mm_comp_fmadd_ps(_vm2_5, _tmp03, _mm_mul_ps(_tmp01, _v2))); - - __m128 _r0tm5 = _mm_add_ps(_tmp56a, _tmp56b); - __m128 _r0tm6 = _mm_sub_ps(_tmp56a, _tmp56b); - - _mm_store_ps(r0_tm_0, _r0tm0); - _mm_store_ps(r0_tm_1, _r0tm1); - _mm_store_ps(r0_tm_2, _r0tm2); - _mm_store_ps(r0_tm_3, _r0tm3); - _mm_store_ps(r0_tm_4, _r0tm4); - _mm_store_ps(r0_tm_5, _r0tm5); - _mm_store_ps(r0_tm_6, _r0tm6); - _mm_store_ps(r0_tm_7, _r0tm7); - - r0_tm_0 += tiles * 4 * 8; - r0_tm_1 += tiles * 4 * 8; - r0_tm_2 += tiles * 4 * 8; - r0_tm_3 += tiles * 4 * 8; - r0_tm_4 += tiles * 4 * 8; - r0_tm_5 += tiles * 4 * 8; - r0_tm_6 += tiles * 4 * 8; - r0_tm_7 += tiles * 4 * 8; - } - } - } - } + conv3x3s1_winograd64_transform_input_pack4_sse(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -727,110 +562,7 @@ static void conv3x3s1_winograd64_pack4to1_sse(const Mat& bottom_blob, Mat& top_b top_blob_bordered.create(outw, outh, outch, 4u, 1, opt.workspace_allocator); } { - // const float otm[6][8] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} - // }; - - // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 - // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 - // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 - // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 - // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 - // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) - - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - const int tiles = w_tm / 8 * h_tm / 8; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - const float bias0 = bias ? 
bias[p] : 0.f; - - float tmp[6][8]; - - // tile - for (int i = 0; i < outh / 6; i++) - { - for (int j = 0; j < outw / 6; j++) - { - // top_blob_tm.create(tiles, 64, outch, 4u, 1, opt.workspace_allocator); - - const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 8 + j) * 1; - const float* output0_tm_1 = output0_tm_0 + tiles * 1; - const float* output0_tm_2 = output0_tm_0 + tiles * 2; - const float* output0_tm_3 = output0_tm_0 + tiles * 3; - const float* output0_tm_4 = output0_tm_0 + tiles * 4; - const float* output0_tm_5 = output0_tm_0 + tiles * 5; - const float* output0_tm_6 = output0_tm_0 + tiles * 6; - const float* output0_tm_7 = output0_tm_0 + tiles * 7; - - // TODO sse optimize - for (int m = 0; m < 8; m++) - { - float tmp024a = output0_tm_1[0] + output0_tm_2[0]; - float tmp135a = output0_tm_1[0] - output0_tm_2[0]; - - float tmp024b = output0_tm_3[0] + output0_tm_4[0]; - float tmp135b = output0_tm_3[0] - output0_tm_4[0]; - - float tmp024c = output0_tm_5[0] + output0_tm_6[0]; - float tmp135c = output0_tm_5[0] - output0_tm_6[0]; - - tmp[0][m] = output0_tm_0[0] + tmp024a + tmp024b + tmp024c * 32; - tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8; - tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16; - tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4; - tmp[5][m] = output0_tm_7[0] + tmp135a + tmp135b * 32 + tmp135c; - - output0_tm_0 += tiles * 8; - output0_tm_1 += tiles * 8; - output0_tm_2 += tiles * 8; - output0_tm_3 += tiles * 8; - output0_tm_4 += tiles * 8; - output0_tm_5 += tiles * 8; - output0_tm_6 += tiles * 8; - output0_tm_7 += tiles * 8; - } - - float* output0 = out0.row(i * 6) + j * 6; - - for (int m = 0; m < 6; m++) - { - const float* tmp0 = tmp[m]; - - float tmp024a = tmp0[1] + tmp0[2]; - float tmp135a = tmp0[1] - tmp0[2]; - - float tmp024b = tmp0[3] + tmp0[4]; - float tmp135b = tmp0[3] - tmp0[4]; - - float tmp024c = tmp0[5] + tmp0[6]; - float tmp135c = tmp0[5] - tmp0[6]; - - output0[0] = bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32; - output0[2] = bias0 + tmp024a + tmp024b * 4 + tmp024c * 8; - output0[4] = bias0 + tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - output0[1] = bias0 + tmp135a + tmp135b + tmp135b + tmp135c * 16; - output0[3] = bias0 + tmp135a + tmp135b * 8 + tmp135c * 4; - output0[5] = bias0 + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c; - - output0 += outw; - } - } - } - } + conv3x3s1_winograd64_transform_output_sse(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output diff --git a/src/layer/x86/convolution_3x3_pack8.h b/src/layer/x86/convolution_3x3_pack8.h index ba2e069f9ee..340cace4466 100644 --- a/src/layer/x86/convolution_3x3_pack8.h +++ b/src/layer/x86/convolution_3x3_pack8.h @@ -943,7 +943,7 @@ static void conv3x3s1_winograd64_transform_kernel_pack8_avx(const Mat& kernel, M } } -static void conv3x3s1_winograd64_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd64_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -964,174 +964,19 @@ static void conv3x3s1_winograd64_pack8_avx(const Mat& bottom_blob, Mat& top_blob h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const float* bias = _bias; // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 
8; - - const int tiles = w_tm / 8 * h_tm / 8; + int w_tiles = outw / 6; + int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); - - // const float itm[8][8] = { - // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, - // - // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, - // - // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, - // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, - // - // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, - // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, - // - // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} - // }; - - // 0 = r00 - r06 + (r04 - r02) * 5.25 - // 7 = r07 - r01 + (r03 - r05) * 5.25 - - // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) - // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) - - // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) - // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) - - // reuse r04 * 1.25 - // reuse r03 * 2.5 - // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) - // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - -#ifdef _MSC_VER - __declspec(align(32)) -#else - __attribute__((aligned(32))) -#endif - float tmp[8][8][8]; - - // tile - for (int i = 0; i < h_tm / 8; i++) - { - for (int j = 0; j < w_tm / 8; j++) - { - const float* r0 = img0.row(i * 6) + (j * 6) * 8; - - for (int m = 0; m < 8; m++) - { - __m256 _r00 = _mm256_load_ps(r0); - __m256 _r01 = _mm256_load_ps(r0 + 8); - __m256 _r02 = _mm256_load_ps(r0 + 16); - __m256 _r03 = _mm256_load_ps(r0 + 24); - __m256 _r04 = _mm256_load_ps(r0 + 32); - __m256 _r05 = _mm256_load_ps(r0 + 40); - __m256 _r06 = _mm256_load_ps(r0 + 48); - __m256 _r07 = _mm256_load_ps(r0 + 56); - - __m256 _tmp0m = _mm256_comp_fmadd_ps(_mm256_set1_ps(5.25f), _mm256_sub_ps(_r04, _r02), _mm256_sub_ps(_r00, _r06)); - __m256 _tmp7m = _mm256_comp_fmadd_ps(_mm256_set1_ps(5.25f), _mm256_sub_ps(_r03, _r05), _mm256_sub_ps(_r07, _r01)); - _mm256_store_ps(tmp[0][m], _tmp0m); - _mm256_store_ps(tmp[7][m], _tmp7m); - - __m256 _tmp12a = _mm256_comp_fmadd_ps(_mm256_set1_ps(-4.25f), _r04, _mm256_add_ps(_r02, _r06)); - __m256 _tmp12b = _mm256_comp_fmadd_ps(_mm256_set1_ps(-4.25f), _r03, _mm256_add_ps(_r01, _r05)); - - __m256 _tmp1m = _mm256_add_ps(_tmp12a, _tmp12b); - __m256 _tmp2m = _mm256_sub_ps(_tmp12a, _tmp12b); - _mm256_store_ps(tmp[1][m], _tmp1m); - _mm256_store_ps(tmp[2][m], _tmp2m); - - __m256 _tmp34a = _mm256_comp_fmadd_ps(_mm256_set1_ps(-1.25f), _r04, _mm256_comp_fmadd_ps(_mm256_set1_ps(0.25f), _r02, _r06)); - __m256 _tmp34b = _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _r05, _mm256_comp_fmadd_ps(_mm256_set1_ps(-2.5f), _r03, _mm256_mul_ps(_r01, _mm256_set1_ps(0.5f)))); - - __m256 _tmp3m = _mm256_add_ps(_tmp34a, _tmp34b); - __m256 _tmp4m = _mm256_sub_ps(_tmp34a, _tmp34b); - _mm256_store_ps(tmp[3][m], _tmp3m); - _mm256_store_ps(tmp[4][m], _tmp4m); - - __m256 _tmp56a = _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _mm256_comp_fmadd_ps(_mm256_set1_ps(-1.25f), _r04, _r02), _r06); - __m256 _tmp56b = _mm256_comp_fmadd_ps(_mm256_set1_ps(0.5f), _r05, _mm256_comp_fmadd_ps(_mm256_set1_ps(-2.5f), _r03, 
_mm256_mul_ps(_r01, _mm256_set1_ps(2.f)))); - - __m256 _tmp5m = _mm256_add_ps(_tmp56a, _tmp56b); - __m256 _tmp6m = _mm256_sub_ps(_tmp56a, _tmp56b); - _mm256_store_ps(tmp[5][m], _tmp5m); - _mm256_store_ps(tmp[6][m], _tmp6m); - - r0 += w * 8; - } - - float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 8 + j) * 8; - float* r0_tm_1 = r0_tm_0 + tiles * 8; - float* r0_tm_2 = r0_tm_0 + tiles * 16; - float* r0_tm_3 = r0_tm_0 + tiles * 24; - float* r0_tm_4 = r0_tm_0 + tiles * 32; - float* r0_tm_5 = r0_tm_0 + tiles * 40; - float* r0_tm_6 = r0_tm_0 + tiles * 48; - float* r0_tm_7 = r0_tm_0 + tiles * 56; - - for (int m = 0; m < 8; m++) - { - __m256 _tmp00 = _mm256_load_ps(tmp[m][0]); - __m256 _tmp01 = _mm256_load_ps(tmp[m][1]); - __m256 _tmp02 = _mm256_load_ps(tmp[m][2]); - __m256 _tmp03 = _mm256_load_ps(tmp[m][3]); - __m256 _tmp04 = _mm256_load_ps(tmp[m][4]); - __m256 _tmp05 = _mm256_load_ps(tmp[m][5]); - __m256 _tmp06 = _mm256_load_ps(tmp[m][6]); - __m256 _tmp07 = _mm256_load_ps(tmp[m][7]); - - __m256 _r0tm0 = _mm256_comp_fmadd_ps(_mm256_set1_ps(5.25f), _mm256_sub_ps(_tmp04, _tmp02), _mm256_sub_ps(_tmp00, _tmp06)); - __m256 _r0tm7 = _mm256_comp_fmadd_ps(_mm256_set1_ps(5.25f), _mm256_sub_ps(_tmp03, _tmp05), _mm256_sub_ps(_tmp07, _tmp01)); - - __m256 _tmp12a = _mm256_comp_fmadd_ps(_mm256_set1_ps(-4.25f), _tmp04, _mm256_add_ps(_tmp02, _tmp06)); - __m256 _tmp12b = _mm256_comp_fmadd_ps(_mm256_set1_ps(-4.25f), _tmp03, _mm256_add_ps(_tmp01, _tmp05)); - - __m256 _r0tm1 = _mm256_add_ps(_tmp12a, _tmp12b); - __m256 _r0tm2 = _mm256_sub_ps(_tmp12a, _tmp12b); - - __m256 _tmp34a = _mm256_comp_fmadd_ps(_mm256_set1_ps(-1.25f), _tmp04, _mm256_comp_fmadd_ps(_mm256_set1_ps(0.25f), _tmp02, _tmp06)); - __m256 _tmp34b = _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _tmp05, _mm256_comp_fmadd_ps(_mm256_set1_ps(-2.5f), _tmp03, _mm256_mul_ps(_tmp01, _mm256_set1_ps(0.5f)))); - - __m256 _r0tm3 = _mm256_add_ps(_tmp34a, _tmp34b); - __m256 _r0tm4 = _mm256_sub_ps(_tmp34a, _tmp34b); - - __m256 _tmp56a = _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _mm256_comp_fmadd_ps(_mm256_set1_ps(-1.25f), _tmp04, _tmp02), _tmp06); - __m256 _tmp56b = _mm256_comp_fmadd_ps(_mm256_set1_ps(0.5f), _tmp05, _mm256_comp_fmadd_ps(_mm256_set1_ps(-2.5f), _tmp03, _mm256_mul_ps(_tmp01, _mm256_set1_ps(2.f)))); - - __m256 _r0tm5 = _mm256_add_ps(_tmp56a, _tmp56b); - __m256 _r0tm6 = _mm256_sub_ps(_tmp56a, _tmp56b); - - _mm256_store_ps(r0_tm_0, _r0tm0); - _mm256_store_ps(r0_tm_1, _r0tm1); - _mm256_store_ps(r0_tm_2, _r0tm2); - _mm256_store_ps(r0_tm_3, _r0tm3); - _mm256_store_ps(r0_tm_4, _r0tm4); - _mm256_store_ps(r0_tm_5, _r0tm5); - _mm256_store_ps(r0_tm_6, _r0tm6); - _mm256_store_ps(r0_tm_7, _r0tm7); - - r0_tm_0 += tiles * 64; - r0_tm_1 += tiles * 64; - r0_tm_2 += tiles * 64; - r0_tm_3 += tiles * 64; - r0_tm_4 += tiles * 64; - r0_tm_5 += tiles * 64; - r0_tm_6 += tiles * 64; - r0_tm_7 += tiles * 64; - } - } - } - } + conv3x3s1_winograd64_transform_input_pack8_avx(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input + // BEGIN dot Mat top_blob_tm; { @@ -1622,143 +1467,7 @@ static void conv3x3s1_winograd64_pack8_avx(const Mat& bottom_blob, Mat& top_blob top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator); } { - // const float otm[6][8] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 16.0f, 
16.0f, 2.0f, 2.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} - // }; - - // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 - // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 - // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 - // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 - // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 - // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) - - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - const int tiles = w_tm / 8 * h_tm / 8; - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - __m256 _bias0 = bias ? _mm256_loadu_ps((const float*)bias + p * 8) : _mm256_setzero_ps(); - -#ifdef _MSC_VER - __declspec(align(32)) -#else - __attribute__((aligned(32))) -#endif - float tmp[6][8][8]; - - // tile - for (int i = 0; i < outh / 6; i++) - { - for (int j = 0; j < outw / 6; j++) - { - // top_blob_tm.create(tiles, 64, outch, elemsize, elempack); - - const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 8 + j) * 8; - const float* output0_tm_1 = output0_tm_0 + tiles * 8; - const float* output0_tm_2 = output0_tm_0 + tiles * 16; - const float* output0_tm_3 = output0_tm_0 + tiles * 24; - const float* output0_tm_4 = output0_tm_0 + tiles * 32; - const float* output0_tm_5 = output0_tm_0 + tiles * 40; - const float* output0_tm_6 = output0_tm_0 + tiles * 48; - const float* output0_tm_7 = output0_tm_0 + tiles * 56; - - float* output0 = out0.row(i * 6) + (j * 6) * 8; - - // TODO neon optimize - for (int m = 0; m < 8; m++) - { - __m256 _out0tm0 = _mm256_load_ps(output0_tm_0); - __m256 _out0tm1 = _mm256_load_ps(output0_tm_1); - __m256 _out0tm2 = _mm256_load_ps(output0_tm_2); - __m256 _out0tm3 = _mm256_load_ps(output0_tm_3); - __m256 _out0tm4 = _mm256_load_ps(output0_tm_4); - __m256 _out0tm5 = _mm256_load_ps(output0_tm_5); - __m256 _out0tm6 = _mm256_load_ps(output0_tm_6); - __m256 _out0tm7 = _mm256_load_ps(output0_tm_7); - - __m256 _tmp024a = _mm256_add_ps(_out0tm1, _out0tm2); - __m256 _tmp135a = _mm256_sub_ps(_out0tm1, _out0tm2); - - __m256 _tmp024b = _mm256_add_ps(_out0tm3, _out0tm4); - __m256 _tmp135b = _mm256_sub_ps(_out0tm3, _out0tm4); - - __m256 _tmp024c = _mm256_add_ps(_out0tm5, _out0tm6); - __m256 _tmp135c = _mm256_sub_ps(_out0tm5, _out0tm6); - - __m256 _tmp0m = _mm256_add_ps(_mm256_add_ps(_out0tm0, _tmp024a), _mm256_comp_fmadd_ps(_mm256_set1_ps(32.f), _tmp024c, _tmp024b)); - __m256 _tmp2m = _mm256_comp_fmadd_ps(_mm256_set1_ps(8.f), _tmp024c, _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _tmp024b, _tmp024a)); - __m256 _tmp4m = _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _tmp024c, _mm256_comp_fmadd_ps(_mm256_set1_ps(16.f), _tmp024b, _tmp024a)); - _mm256_store_ps(tmp[0][m], _tmp0m); - _mm256_store_ps(tmp[2][m], _tmp2m); - _mm256_store_ps(tmp[4][m], _tmp4m); - - __m256 _tmp1m = _mm256_comp_fmadd_ps(_mm256_set1_ps(16.f), _tmp135c, _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _tmp135b, _tmp135a)); - __m256 _tmp3m = _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _tmp135c, _mm256_comp_fmadd_ps(_mm256_set1_ps(8.f), _tmp135b, _tmp135a)); - __m256 _tmp5m = _mm256_add_ps(_mm256_add_ps(_out0tm7, _tmp135a), _mm256_comp_fmadd_ps(_mm256_set1_ps(32.f), _tmp135b, _tmp135c)); - _mm256_store_ps(tmp[1][m], _tmp1m); - _mm256_store_ps(tmp[3][m], _tmp3m); - _mm256_store_ps(tmp[5][m], _tmp5m); - - output0_tm_0 += tiles * 64; - output0_tm_1 += tiles * 64; - output0_tm_2 += tiles * 64; - 
output0_tm_3 += tiles * 64;
-                    output0_tm_4 += tiles * 64;
-                    output0_tm_5 += tiles * 64;
-                    output0_tm_6 += tiles * 64;
-                    output0_tm_7 += tiles * 64;
-                }
-
-                for (int m = 0; m < 6; m++)
-                {
-                    __m256 _tmp00 = _mm256_load_ps(tmp[m][0]);
-                    __m256 _tmp01 = _mm256_load_ps(tmp[m][1]);
-                    __m256 _tmp02 = _mm256_load_ps(tmp[m][2]);
-                    __m256 _tmp03 = _mm256_load_ps(tmp[m][3]);
-                    __m256 _tmp04 = _mm256_load_ps(tmp[m][4]);
-                    __m256 _tmp05 = _mm256_load_ps(tmp[m][5]);
-                    __m256 _tmp06 = _mm256_load_ps(tmp[m][6]);
-                    __m256 _tmp07 = _mm256_load_ps(tmp[m][7]);
-
-                    __m256 _tmp024a = _mm256_add_ps(_tmp01, _tmp02);
-                    __m256 _tmp135a = _mm256_sub_ps(_tmp01, _tmp02);
-
-                    __m256 _tmp024b = _mm256_add_ps(_tmp03, _tmp04);
-                    __m256 _tmp135b = _mm256_sub_ps(_tmp03, _tmp04);
-
-                    __m256 _tmp024c = _mm256_add_ps(_tmp05, _tmp06);
-                    __m256 _tmp135c = _mm256_sub_ps(_tmp05, _tmp06);
-
-                    __m256 _out00 = _mm256_add_ps(_bias0, _mm256_add_ps(_mm256_add_ps(_tmp00, _tmp024a), _mm256_comp_fmadd_ps(_mm256_set1_ps(32.f), _tmp024c, _tmp024b)));
-                    __m256 _out02 = _mm256_add_ps(_bias0, _mm256_comp_fmadd_ps(_mm256_set1_ps(8.f), _tmp024c, _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _tmp024b, _tmp024a)));
-                    __m256 _out04 = _mm256_add_ps(_bias0, _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _tmp024c, _mm256_comp_fmadd_ps(_mm256_set1_ps(16.f), _tmp024b, _tmp024a)));
-                    _mm256_store_ps(output0, _out00);
-                    _mm256_store_ps(output0 + 16, _out02);
-                    _mm256_store_ps(output0 + 32, _out04);
-
-                    __m256 _out01 = _mm256_add_ps(_bias0, _mm256_comp_fmadd_ps(_mm256_set1_ps(16.f), _tmp135c, _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _tmp135b, _tmp135a)));
-                    __m256 _out03 = _mm256_add_ps(_bias0, _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _tmp135c, _mm256_comp_fmadd_ps(_mm256_set1_ps(8.f), _tmp135b, _tmp135a)));
-                    __m256 _out05 = _mm256_add_ps(_bias0, _mm256_add_ps(_mm256_add_ps(_tmp07, _tmp135a), _mm256_comp_fmadd_ps(_mm256_set1_ps(32.f), _tmp135b, _tmp135c)));
-                    _mm256_store_ps(output0 + 8, _out01);
-                    _mm256_store_ps(output0 + 24, _out03);
-                    _mm256_store_ps(output0 + 40, _out05);
-
-                    output0 += outw * 8;
-                }
-            }
-        }
-    }
+        conv3x3s1_winograd64_transform_output_pack8_avx(top_blob_tm, top_blob_bordered, bias, opt);
     }
     // END transform output
@@ -1844,7 +1553,7 @@ static void conv3x3s1_winograd42_transform_kernel_pack8_avx(const Mat& kernel, M
     }
 }

-static void conv3x3s1_winograd42_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt)
+static void conv3x3s1_winograd42_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt)
 {
     int w = bottom_blob.w;
     int h = bottom_blob.h;
@@ -1866,120 +1575,15 @@ static void conv3x3s1_winograd42_pack8_avx(const Mat& bottom_blob, Mat& top_blob
     h = outh + 2;
     copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt);

-    const float* bias = _bias;
-
     // BEGIN transform input
     Mat bottom_blob_tm;
     {
-        int w_tm = outw / 4 * 6;
-        int h_tm = outh / 4 * 6;
-
-        const int tiles = w_tm / 6 * h_tm / 6;
-
-        bottom_blob_tm.create(tiles, 36, inch, 4u * elempack, elempack, opt.workspace_allocator);
-
-        // const float itm[6][6] = {
-        // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f},
-        // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f},
-        // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f},
-        // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f},
-        // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f},
-        // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f}
-        // };
-
-        // 0 = 4 * r00 - 5 * r02 + r04
-        // 1 = -4 * (r01 + r02) + r04 + r03
-        // 2 = 4 * (r01 - r02) +
r04 - r03 - // 3 = -2 * (r01 - r03) + r04 - r02 - // 4 = 2 * (r01 - r03) + r04 - r02 - // 5 = 4 * r01 - 5 * r03 + r05 + int w_tiles = outw / 4; + int h_tiles = outh / 4; + const int tiles = w_tiles * h_tiles; - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - Mat img0_tm = bottom_blob_tm.channel(q); - -#ifdef _MSC_VER - __declspec(align(32)) -#else - __attribute__((aligned(32))) -#endif - float tmp[6][6][8]; - - // tile - for (int i = 0; i < h_tm / 6; i++) - { - for (int j = 0; j < w_tm / 6; j++) - { - const float* r0 = img0.row(i * 4) + (j * 4) * 8; - - for (int m = 0; m < 6; m++) - { - __m256 _r00 = _mm256_load_ps(r0); - __m256 _r01 = _mm256_load_ps(r0 + 8); - __m256 _r02 = _mm256_load_ps(r0 + 8 * 2); - __m256 _r03 = _mm256_load_ps(r0 + 8 * 3); - __m256 _r04 = _mm256_load_ps(r0 + 8 * 4); - __m256 _r05 = _mm256_load_ps(r0 + 8 * 5); - - __m256 _tmp0m = _mm256_comp_fmadd_ps(_mm256_set1_ps(-5.f), _r02, _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _r00, _r04)); - __m256 _tmp1m = _mm256_comp_fmadd_ps(_mm256_set1_ps(-4.f), _mm256_add_ps(_r01, _r02), _mm256_add_ps(_r04, _r03)); - __m256 _tmp2m = _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _mm256_sub_ps(_r01, _r02), _mm256_sub_ps(_r04, _r03)); - __m256 _tmp3m = _mm256_comp_fmadd_ps(_mm256_set1_ps(-2.f), _mm256_sub_ps(_r01, _r03), _mm256_sub_ps(_r04, _r02)); - __m256 _tmp4m = _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _mm256_sub_ps(_r01, _r03), _mm256_sub_ps(_r04, _r02)); - __m256 _tmp5m = _mm256_comp_fmadd_ps(_mm256_set1_ps(-5.f), _r03, _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _r01, _r05)); - - _mm256_store_ps(tmp[0][m], _tmp0m); - _mm256_store_ps(tmp[1][m], _tmp1m); - _mm256_store_ps(tmp[2][m], _tmp2m); - _mm256_store_ps(tmp[3][m], _tmp3m); - _mm256_store_ps(tmp[4][m], _tmp4m); - _mm256_store_ps(tmp[5][m], _tmp5m); - - r0 += w * 8; - } - - float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 6 + j) * 8; - float* r0_tm_1 = r0_tm_0 + tiles * 8; - float* r0_tm_2 = r0_tm_0 + tiles * 8 * 2; - float* r0_tm_3 = r0_tm_0 + tiles * 8 * 3; - float* r0_tm_4 = r0_tm_0 + tiles * 8 * 4; - float* r0_tm_5 = r0_tm_0 + tiles * 8 * 5; - - for (int m = 0; m < 6; m++) - { - __m256 _tmp00 = _mm256_load_ps(tmp[m][0]); - __m256 _tmp01 = _mm256_load_ps(tmp[m][1]); - __m256 _tmp02 = _mm256_load_ps(tmp[m][2]); - __m256 _tmp03 = _mm256_load_ps(tmp[m][3]); - __m256 _tmp04 = _mm256_load_ps(tmp[m][4]); - __m256 _tmp05 = _mm256_load_ps(tmp[m][5]); - - __m256 _r0tm0 = _mm256_comp_fmadd_ps(_mm256_set1_ps(-5.f), _tmp02, _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _tmp00, _tmp04)); - __m256 _r0tm1 = _mm256_comp_fmadd_ps(_mm256_set1_ps(-4.f), _mm256_add_ps(_tmp01, _tmp02), _mm256_add_ps(_tmp04, _tmp03)); - __m256 _r0tm2 = _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _mm256_sub_ps(_tmp01, _tmp02), _mm256_sub_ps(_tmp04, _tmp03)); - __m256 _r0tm3 = _mm256_comp_fmadd_ps(_mm256_set1_ps(-2.f), _mm256_sub_ps(_tmp01, _tmp03), _mm256_sub_ps(_tmp04, _tmp02)); - __m256 _r0tm4 = _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _mm256_sub_ps(_tmp01, _tmp03), _mm256_sub_ps(_tmp04, _tmp02)); - __m256 _r0tm5 = _mm256_comp_fmadd_ps(_mm256_set1_ps(-5.f), _tmp03, _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _tmp01, _tmp05)); - - _mm256_store_ps(r0_tm_0, _r0tm0); - _mm256_store_ps(r0_tm_1, _r0tm1); - _mm256_store_ps(r0_tm_2, _r0tm2); - _mm256_store_ps(r0_tm_3, _r0tm3); - _mm256_store_ps(r0_tm_4, _r0tm4); - _mm256_store_ps(r0_tm_5, _r0tm5); - - r0_tm_0 += tiles * 8 * 6; - r0_tm_1 += tiles * 8 * 6; - r0_tm_2 += tiles * 8 * 6; - r0_tm_3 
+= tiles * 8 * 6; - r0_tm_4 += tiles * 8 * 6; - r0_tm_5 += tiles * 8 * 6; - } - } - } - } + bottom_blob_tm.create(tiles, 36, inch, elemsize, elempack, opt.workspace_allocator); + conv3x3s1_winograd42_transform_input_pack8_avx(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -2473,118 +2077,7 @@ static void conv3x3s1_winograd42_pack8_avx(const Mat& bottom_blob, Mat& top_blob top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator); } { - // const float otm[4][6] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} - // }; - - // 0 = r00 + (r01 + r02) + (r03 + r04) - // 1 = (r01 - r02) + (r03 - r04) * 2 - // 2 = (r01 + r02) + (r03 + r04) * 4 - // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 - - int w_tm = outw / 4 * 6; - int h_tm = outh / 4 * 6; - const int tiles = w_tm / 6 * h_tm / 6; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - // const float bias0 = bias ? bias[p] : 0.f; - __m256 _bias0 = bias ? _mm256_loadu_ps((const float*)bias + p * 8) : _mm256_setzero_ps(); - -#ifdef _MSC_VER - __declspec(align(32)) -#else - __attribute__((aligned(32))) -#endif - float tmp[4][6][8]; - - // tile - for (int i = 0; i < outh / 4; i++) - { - for (int j = 0; j < outw / 4; j++) - { - // top_blob_tm.create(tiles, 36, outch, elemsize, elempack); - - const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 6 + j) * 8; - const float* output0_tm_1 = output0_tm_0 + tiles * 8; - const float* output0_tm_2 = output0_tm_0 + tiles * 8 * 2; - const float* output0_tm_3 = output0_tm_0 + tiles * 8 * 3; - const float* output0_tm_4 = output0_tm_0 + tiles * 8 * 4; - const float* output0_tm_5 = output0_tm_0 + tiles * 8 * 5; - - float* output0 = out0.row(i * 4) + (j * 4) * 8; - - // TODO msa optimize - for (int m = 0; m < 6; m++) - { - __m256 _out0tm0 = _mm256_load_ps(output0_tm_0); - __m256 _out0tm1 = _mm256_load_ps(output0_tm_1); - __m256 _out0tm2 = _mm256_load_ps(output0_tm_2); - __m256 _out0tm3 = _mm256_load_ps(output0_tm_3); - __m256 _out0tm4 = _mm256_load_ps(output0_tm_4); - __m256 _out0tm5 = _mm256_load_ps(output0_tm_5); - - __m256 _tmp02a = _mm256_add_ps(_out0tm1, _out0tm2); - __m256 _tmp13a = _mm256_sub_ps(_out0tm1, _out0tm2); - - __m256 _tmp02b = _mm256_add_ps(_out0tm3, _out0tm4); - __m256 _tmp13b = _mm256_sub_ps(_out0tm3, _out0tm4); - - __m256 _tmp0m = _mm256_add_ps(_mm256_add_ps(_out0tm0, _tmp02a), _tmp02b); - __m256 _tmp1m = _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _tmp13b, _tmp13a); - __m256 _tmp2m = _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _tmp02b, _tmp02a); - __m256 _tmp3m = _mm256_comp_fmadd_ps(_mm256_set1_ps(8.f), _tmp13b, _mm256_add_ps(_out0tm5, _tmp13a)); - - _mm256_store_ps(tmp[0][m], _tmp0m); - _mm256_store_ps(tmp[1][m], _tmp1m); - _mm256_store_ps(tmp[2][m], _tmp2m); - _mm256_store_ps(tmp[3][m], _tmp3m); - - output0_tm_0 += tiles * 8 * 6; - output0_tm_1 += tiles * 8 * 6; - output0_tm_2 += tiles * 8 * 6; - output0_tm_3 += tiles * 8 * 6; - output0_tm_4 += tiles * 8 * 6; - output0_tm_5 += tiles * 8 * 6; - } - - for (int m = 0; m < 4; m++) - { - __m256 _tmp00 = _mm256_load_ps(tmp[m][0]); - __m256 _tmp01 = _mm256_load_ps(tmp[m][1]); - __m256 _tmp02 = _mm256_load_ps(tmp[m][2]); - __m256 _tmp03 = _mm256_load_ps(tmp[m][3]); - __m256 _tmp04 = 
_mm256_load_ps(tmp[m][4]); - __m256 _tmp05 = _mm256_load_ps(tmp[m][5]); - - __m256 _tmp02a = _mm256_add_ps(_tmp01, _tmp02); - __m256 _tmp13a = _mm256_sub_ps(_tmp01, _tmp02); - - __m256 _tmp02b = _mm256_add_ps(_tmp03, _tmp04); - __m256 _tmp13b = _mm256_sub_ps(_tmp03, _tmp04); - - __m256 _out00 = _mm256_add_ps(_bias0, _mm256_add_ps(_mm256_add_ps(_tmp00, _tmp02a), _tmp02b)); - __m256 _out01 = _mm256_add_ps(_bias0, _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _tmp13b, _tmp13a)); - __m256 _out02 = _mm256_add_ps(_bias0, _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _tmp02b, _tmp02a)); - __m256 _out03 = _mm256_add_ps(_bias0, _mm256_comp_fmadd_ps(_mm256_set1_ps(8.f), _tmp13b, _mm256_add_ps(_tmp05, _tmp13a))); - - _mm256_store_ps(output0, _out00); - _mm256_store_ps(output0 + 8, _out01); - _mm256_store_ps(output0 + 8 * 2, _out02); - _mm256_store_ps(output0 + 8 * 3, _out03); - - output0 += outw * 8; - } - } - } - } + conv3x3s1_winograd42_transform_output_pack8_avx(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output diff --git a/src/layer/x86/convolution_3x3_pack8to1.h b/src/layer/x86/convolution_3x3_pack8to1.h index a163ffa4af9..d813a0dc37d 100644 --- a/src/layer/x86/convolution_3x3_pack8to1.h +++ b/src/layer/x86/convolution_3x3_pack8to1.h @@ -204,7 +204,7 @@ static void conv3x3s1_winograd64_transform_kernel_pack8to1_avx(const Mat& kernel } } -static void conv3x3s1_winograd64_pack8to1_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt) +static void conv3x3s1_winograd64_pack8to1_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& bias, const Option& opt) { int w = bottom_blob.w; int h = bottom_blob.h; @@ -226,172 +226,15 @@ static void conv3x3s1_winograd64_pack8to1_avx(const Mat& bottom_blob, Mat& top_b h = outh + 2; copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt); - const float* bias = _bias; - // BEGIN transform input Mat bottom_blob_tm; { - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - - const int tiles = w_tm / 8 * h_tm / 8; + int w_tiles = outw / 6; + int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator); - - // const float itm[8][8] = { - // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, - // - // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, - // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, - // - // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, - // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, - // - // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, - // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, - // - // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} - // }; - - // 0 = r00 - r06 + (r04 - r02) * 5.25 - // 7 = r07 - r01 + (r03 - r05) * 5.25 - - // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) - // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) - - // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) - // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) - - // reuse r04 * 1.25 - // reuse r03 * 2.5 - // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) - // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < inch; q++) - { - const Mat img0 = bottom_blob_bordered.channel(q); - 
Mat img0_tm = bottom_blob_tm.channel(q); - -#ifdef _MSC_VER - __declspec(align(32)) -#else - __attribute__((aligned(32))) -#endif - float tmp[8][8][8]; - - // tile - for (int i = 0; i < h_tm / 8; i++) - { - for (int j = 0; j < w_tm / 8; j++) - { - const float* r0 = img0.row(i * 6) + (j * 6) * 8; - - for (int m = 0; m < 8; m++) - { - __m256 _r00 = _mm256_load_ps(r0); - __m256 _r01 = _mm256_load_ps(r0 + 8); - __m256 _r02 = _mm256_load_ps(r0 + 16); - __m256 _r03 = _mm256_load_ps(r0 + 24); - __m256 _r04 = _mm256_load_ps(r0 + 32); - __m256 _r05 = _mm256_load_ps(r0 + 40); - __m256 _r06 = _mm256_load_ps(r0 + 48); - __m256 _r07 = _mm256_load_ps(r0 + 56); - - __m256 _tmp0m = _mm256_comp_fmadd_ps(_mm256_set1_ps(5.25f), _mm256_sub_ps(_r04, _r02), _mm256_sub_ps(_r00, _r06)); - __m256 _tmp7m = _mm256_comp_fmadd_ps(_mm256_set1_ps(5.25f), _mm256_sub_ps(_r03, _r05), _mm256_sub_ps(_r07, _r01)); - _mm256_store_ps(tmp[0][m], _tmp0m); - _mm256_store_ps(tmp[7][m], _tmp7m); - - __m256 _tmp12a = _mm256_comp_fmadd_ps(_mm256_set1_ps(-4.25f), _r04, _mm256_add_ps(_r02, _r06)); - __m256 _tmp12b = _mm256_comp_fmadd_ps(_mm256_set1_ps(-4.25f), _r03, _mm256_add_ps(_r01, _r05)); - - __m256 _tmp1m = _mm256_add_ps(_tmp12a, _tmp12b); - __m256 _tmp2m = _mm256_sub_ps(_tmp12a, _tmp12b); - _mm256_store_ps(tmp[1][m], _tmp1m); - _mm256_store_ps(tmp[2][m], _tmp2m); - - __m256 _tmp34a = _mm256_comp_fmadd_ps(_mm256_set1_ps(-1.25f), _r04, _mm256_comp_fmadd_ps(_mm256_set1_ps(0.25f), _r02, _r06)); - __m256 _tmp34b = _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _r05, _mm256_comp_fmadd_ps(_mm256_set1_ps(-2.5f), _r03, _mm256_mul_ps(_r01, _mm256_set1_ps(0.5f)))); - - __m256 _tmp3m = _mm256_add_ps(_tmp34a, _tmp34b); - __m256 _tmp4m = _mm256_sub_ps(_tmp34a, _tmp34b); - _mm256_store_ps(tmp[3][m], _tmp3m); - _mm256_store_ps(tmp[4][m], _tmp4m); - - __m256 _tmp56a = _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _mm256_comp_fmadd_ps(_mm256_set1_ps(-1.25f), _r04, _r02), _r06); - __m256 _tmp56b = _mm256_comp_fmadd_ps(_mm256_set1_ps(0.5f), _r05, _mm256_comp_fmadd_ps(_mm256_set1_ps(-2.5f), _r03, _mm256_mul_ps(_r01, _mm256_set1_ps(2.f)))); - - __m256 _tmp5m = _mm256_add_ps(_tmp56a, _tmp56b); - __m256 _tmp6m = _mm256_sub_ps(_tmp56a, _tmp56b); - _mm256_store_ps(tmp[5][m], _tmp5m); - _mm256_store_ps(tmp[6][m], _tmp6m); - - r0 += w * 8; - } - - float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 8 + j) * 8; - float* r0_tm_1 = r0_tm_0 + tiles * 8; - float* r0_tm_2 = r0_tm_0 + tiles * 16; - float* r0_tm_3 = r0_tm_0 + tiles * 24; - float* r0_tm_4 = r0_tm_0 + tiles * 32; - float* r0_tm_5 = r0_tm_0 + tiles * 40; - float* r0_tm_6 = r0_tm_0 + tiles * 48; - float* r0_tm_7 = r0_tm_0 + tiles * 56; - - for (int m = 0; m < 8; m++) - { - __m256 _tmp00 = _mm256_load_ps(tmp[m][0]); - __m256 _tmp01 = _mm256_load_ps(tmp[m][1]); - __m256 _tmp02 = _mm256_load_ps(tmp[m][2]); - __m256 _tmp03 = _mm256_load_ps(tmp[m][3]); - __m256 _tmp04 = _mm256_load_ps(tmp[m][4]); - __m256 _tmp05 = _mm256_load_ps(tmp[m][5]); - __m256 _tmp06 = _mm256_load_ps(tmp[m][6]); - __m256 _tmp07 = _mm256_load_ps(tmp[m][7]); - - __m256 _r0tm0 = _mm256_comp_fmadd_ps(_mm256_set1_ps(5.25f), _mm256_sub_ps(_tmp04, _tmp02), _mm256_sub_ps(_tmp00, _tmp06)); - __m256 _r0tm7 = _mm256_comp_fmadd_ps(_mm256_set1_ps(5.25f), _mm256_sub_ps(_tmp03, _tmp05), _mm256_sub_ps(_tmp07, _tmp01)); - - __m256 _tmp12a = _mm256_comp_fmadd_ps(_mm256_set1_ps(-4.25f), _tmp04, _mm256_add_ps(_tmp02, _tmp06)); - __m256 _tmp12b = _mm256_comp_fmadd_ps(_mm256_set1_ps(-4.25f), _tmp03, _mm256_add_ps(_tmp01, _tmp05)); - - __m256 _r0tm1 = 
_mm256_add_ps(_tmp12a, _tmp12b); - __m256 _r0tm2 = _mm256_sub_ps(_tmp12a, _tmp12b); - - __m256 _tmp34a = _mm256_comp_fmadd_ps(_mm256_set1_ps(-1.25f), _tmp04, _mm256_comp_fmadd_ps(_mm256_set1_ps(0.25f), _tmp02, _tmp06)); - __m256 _tmp34b = _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _tmp05, _mm256_comp_fmadd_ps(_mm256_set1_ps(-2.5f), _tmp03, _mm256_mul_ps(_tmp01, _mm256_set1_ps(0.5f)))); - - __m256 _r0tm3 = _mm256_add_ps(_tmp34a, _tmp34b); - __m256 _r0tm4 = _mm256_sub_ps(_tmp34a, _tmp34b); - - __m256 _tmp56a = _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _mm256_comp_fmadd_ps(_mm256_set1_ps(-1.25f), _tmp04, _tmp02), _tmp06); - __m256 _tmp56b = _mm256_comp_fmadd_ps(_mm256_set1_ps(0.5f), _tmp05, _mm256_comp_fmadd_ps(_mm256_set1_ps(-2.5f), _tmp03, _mm256_mul_ps(_tmp01, _mm256_set1_ps(2.f)))); - - __m256 _r0tm5 = _mm256_add_ps(_tmp56a, _tmp56b); - __m256 _r0tm6 = _mm256_sub_ps(_tmp56a, _tmp56b); - - _mm256_store_ps(r0_tm_0, _r0tm0); - _mm256_store_ps(r0_tm_1, _r0tm1); - _mm256_store_ps(r0_tm_2, _r0tm2); - _mm256_store_ps(r0_tm_3, _r0tm3); - _mm256_store_ps(r0_tm_4, _r0tm4); - _mm256_store_ps(r0_tm_5, _r0tm5); - _mm256_store_ps(r0_tm_6, _r0tm6); - _mm256_store_ps(r0_tm_7, _r0tm7); - - r0_tm_0 += tiles * 64; - r0_tm_1 += tiles * 64; - r0_tm_2 += tiles * 64; - r0_tm_3 += tiles * 64; - r0_tm_4 += tiles * 64; - r0_tm_5 += tiles * 64; - r0_tm_6 += tiles * 64; - r0_tm_7 += tiles * 64; - } - } - } - } + conv3x3s1_winograd64_transform_input_pack8_avx(bottom_blob_bordered, bottom_blob_tm, opt); } bottom_blob_bordered = Mat(); // END transform input @@ -705,110 +548,7 @@ static void conv3x3s1_winograd64_pack8to1_avx(const Mat& bottom_blob, Mat& top_b top_blob_bordered.create(outw, outh, outch, 4u, 1, opt.workspace_allocator); } { - // const float otm[6][8] = { - // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, - // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, - // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} - // }; - - // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 - // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 - // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 - // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 - // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 - // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) - - int w_tm = outw / 6 * 8; - int h_tm = outh / 6 * 8; - const int tiles = w_tm / 8 * h_tm / 8; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int p = 0; p < outch; p++) - { - const Mat out0_tm = top_blob_tm.channel(p); - Mat out0 = top_blob_bordered.channel(p); - - const float bias0 = bias ? 
bias[p] : 0.f; - - float tmp[6][8]; - - // tile - for (int i = 0; i < outh / 6; i++) - { - for (int j = 0; j < outw / 6; j++) - { - // top_blob_tm.create(tiles, 64, outch, 4u, 1, opt.workspace_allocator); - - const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 8 + j) * 1; - const float* output0_tm_1 = output0_tm_0 + tiles * 1; - const float* output0_tm_2 = output0_tm_0 + tiles * 2; - const float* output0_tm_3 = output0_tm_0 + tiles * 3; - const float* output0_tm_4 = output0_tm_0 + tiles * 4; - const float* output0_tm_5 = output0_tm_0 + tiles * 5; - const float* output0_tm_6 = output0_tm_0 + tiles * 6; - const float* output0_tm_7 = output0_tm_0 + tiles * 7; - - // TODO sse optimize - for (int m = 0; m < 8; m++) - { - float tmp024a = output0_tm_1[0] + output0_tm_2[0]; - float tmp135a = output0_tm_1[0] - output0_tm_2[0]; - - float tmp024b = output0_tm_3[0] + output0_tm_4[0]; - float tmp135b = output0_tm_3[0] - output0_tm_4[0]; - - float tmp024c = output0_tm_5[0] + output0_tm_6[0]; - float tmp135c = output0_tm_5[0] - output0_tm_6[0]; - - tmp[0][m] = output0_tm_0[0] + tmp024a + tmp024b + tmp024c * 32; - tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8; - tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16; - tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4; - tmp[5][m] = output0_tm_7[0] + tmp135a + tmp135b * 32 + tmp135c; - - output0_tm_0 += tiles * 8; - output0_tm_1 += tiles * 8; - output0_tm_2 += tiles * 8; - output0_tm_3 += tiles * 8; - output0_tm_4 += tiles * 8; - output0_tm_5 += tiles * 8; - output0_tm_6 += tiles * 8; - output0_tm_7 += tiles * 8; - } - - float* output0 = out0.row(i * 6) + j * 6; - - for (int m = 0; m < 6; m++) - { - const float* tmp0 = tmp[m]; - - float tmp024a = tmp0[1] + tmp0[2]; - float tmp135a = tmp0[1] - tmp0[2]; - - float tmp024b = tmp0[3] + tmp0[4]; - float tmp135b = tmp0[3] - tmp0[4]; - - float tmp024c = tmp0[5] + tmp0[6]; - float tmp135c = tmp0[5] - tmp0[6]; - - output0[0] = bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32; - output0[2] = bias0 + tmp024a + tmp024b * 4 + tmp024c * 8; - output0[4] = bias0 + tmp024a + tmp024b * 16 + tmp024c + tmp024c; - - output0[1] = bias0 + tmp135a + tmp135b + tmp135b + tmp135c * 16; - output0[3] = bias0 + tmp135a + tmp135b * 8 + tmp135c * 4; - output0[5] = bias0 + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c; - - output0 += outw; - } - } - } - } + conv3x3s1_winograd64_transform_output_sse(top_blob_tm, top_blob_bordered, bias, opt); } // END transform output diff --git a/src/layer/x86/convolution_winograd_transform.h b/src/layer/x86/convolution_winograd_transform.h new file mode 100644 index 00000000000..db4c7f4afb9 --- /dev/null +++ b/src/layer/x86/convolution_winograd_transform.h @@ -0,0 +1,125 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
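The scalar output transform collected in this new header is what the pack4to1 and pack8to1 paths above now call instead of carrying their own copies. As a reading aid, here is a minimal standalone sketch of the one-dimensional 8-to-6 pass encoded by the otm[6][8] matrix commented in the function below (illustrative only; the helper name and plain-array layout are hypothetical, not part of the patch):

    // y = otm * m for one 8-element column of the F(6x6, 3x3) output transform.
    // Hypothetical reference helper operating on plain float arrays.
    static inline void winograd63_output_transform_1d(const float m[8], float y[6])
    {
        const float s12 = m[1] + m[2]; // feeds even outputs y0, y2, y4
        const float d12 = m[1] - m[2]; // feeds odd outputs y1, y3, y5
        const float s34 = m[3] + m[4];
        const float d34 = m[3] - m[4];
        const float s56 = m[5] + m[6];
        const float d56 = m[5] - m[6];

        y[0] = m[0] + s12 + s34 + s56 * 32;
        y[1] = d12 + d34 * 2 + d56 * 16;
        y[2] = s12 + s34 * 4 + s56 * 8;
        y[3] = d12 + d34 * 8 + d56 * 4;
        y[4] = s12 + s34 * 16 + s56 * 2;
        y[5] = m[7] + d12 + d34 * 32 + d56;
    }

The function below applies this same pattern twice per tile: along the eight rows into tmp[][], then along the columns, adding the per-channel bias on the second pass. The packN variants elsewhere in this patch repeat it per SIMD lane.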
+ +static void conv3x3s1_winograd64_transform_output_sse(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 6; + const int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; + + const float* biasptr = bias; + + // const float otm[6][8] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} + // }; + + // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 + // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 + // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 + // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 + // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 + // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + const float bias0 = biasptr ? biasptr[p] : 0.f; + + float tmp[6][8]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* output0_tm_0 = (const float*)out0_tm + (i * w_tiles + j) * 1; + const float* output0_tm_1 = output0_tm_0 + tiles * 1; + const float* output0_tm_2 = output0_tm_0 + tiles * 2; + const float* output0_tm_3 = output0_tm_0 + tiles * 3; + const float* output0_tm_4 = output0_tm_0 + tiles * 4; + const float* output0_tm_5 = output0_tm_0 + tiles * 5; + const float* output0_tm_6 = output0_tm_0 + tiles * 6; + const float* output0_tm_7 = output0_tm_0 + tiles * 7; + + // TODO sse optimize + for (int m = 0; m < 8; m++) + { + float tmp024a = output0_tm_1[0] + output0_tm_2[0]; + float tmp135a = output0_tm_1[0] - output0_tm_2[0]; + + float tmp024b = output0_tm_3[0] + output0_tm_4[0]; + float tmp135b = output0_tm_3[0] - output0_tm_4[0]; + + float tmp024c = output0_tm_5[0] + output0_tm_6[0]; + float tmp135c = output0_tm_5[0] - output0_tm_6[0]; + + tmp[0][m] = output0_tm_0[0] + tmp024a + tmp024b + tmp024c * 32; + tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8; + tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c; + + tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16; + tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4; + tmp[5][m] = output0_tm_7[0] + tmp135a + tmp135b * 32 + tmp135c; + + output0_tm_0 += tiles * 8; + output0_tm_1 += tiles * 8; + output0_tm_2 += tiles * 8; + output0_tm_3 += tiles * 8; + output0_tm_4 += tiles * 8; + output0_tm_5 += tiles * 8; + output0_tm_6 += tiles * 8; + output0_tm_7 += tiles * 8; + } + + float* output0 = out0.row(i * 6) + j * 6; + + for (int m = 0; m < 6; m++) + { + const float* tmp0 = tmp[m]; + + float tmp024a = tmp0[1] + tmp0[2]; + float tmp135a = tmp0[1] - tmp0[2]; + + float tmp024b = tmp0[3] + tmp0[4]; + float tmp135b = tmp0[3] - tmp0[4]; + + float tmp024c = tmp0[5] + tmp0[6]; + float tmp135c = tmp0[5] - tmp0[6]; + + output0[0] = bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32; + output0[2] = bias0 + tmp024a + tmp024b * 4 + tmp024c * 8; + output0[4] = bias0 + tmp024a + tmp024b * 16 + tmp024c + tmp024c; + + output0[1] = bias0 + tmp135a + tmp135b + tmp135b + tmp135c * 16; + output0[3] = bias0 + tmp135a + tmp135b * 8 + tmp135c * 4; + output0[5] = bias0 + tmp0[7] + tmp135a 
+ tmp135b * 32 + tmp135c; + + output0 += outw; + } + } + } + } +} diff --git a/src/layer/x86/convolution_winograd_transform_pack16.h b/src/layer/x86/convolution_winograd_transform_pack16.h new file mode 100644 index 00000000000..d055f9eda6e --- /dev/null +++ b/src/layer/x86/convolution_winograd_transform_pack16.h @@ -0,0 +1,555 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void conv3x3s1_winograd64_transform_input_pack16_avx512(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) +{ + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + + const int w_tiles = (w - 2) / 6; + const int h_tiles = (h - 2) / 6; + const int tiles = w_tiles * h_tiles; + + // const float itm[8][8] = { + // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, + // + // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, + // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, + // + // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, + // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, + // + // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, + // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, + // + // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} + // }; + + // 0 = r00 - r06 + (r04 - r02) * 5.25 + // 7 = r07 - r01 + (r03 - r05) * 5.25 + + // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) + // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) + + // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) + // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) + + // reuse r04 * 1.25 + // reuse r03 * 2.5 + // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) + // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < inch; q++) + { + const Mat img0 = bottom_blob.channel(q); + Mat img0_tm = bottom_blob_tm.channel(q); + +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + float tmp[8][8][16]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* r0 = img0.row(i * 6) + (j * 6) * 16; + + for (int m = 0; m < 8; m++) + { + __m512 _r00 = _mm512_load_ps(r0); + __m512 _r01 = _mm512_load_ps(r0 + 16); + __m512 _r02 = _mm512_load_ps(r0 + 16 * 2); + __m512 _r03 = _mm512_load_ps(r0 + 16 * 3); + __m512 _r04 = _mm512_load_ps(r0 + 16 * 4); + __m512 _r05 = _mm512_load_ps(r0 + 16 * 5); + __m512 _r06 = _mm512_load_ps(r0 + 16 * 6); + __m512 _r07 = _mm512_load_ps(r0 + 16 * 7); + + __m512 _tmp0m = _mm512_fmadd_ps(_mm512_set1_ps(5.25f), _mm512_sub_ps(_r04, _r02), _mm512_sub_ps(_r00, _r06)); + __m512 _tmp7m = _mm512_fmadd_ps(_mm512_set1_ps(5.25f), _mm512_sub_ps(_r03, _r05), _mm512_sub_ps(_r07, _r01)); + 
_mm512_store_ps(tmp[0][m], _tmp0m); + _mm512_store_ps(tmp[7][m], _tmp7m); + + __m512 _tmp12a = _mm512_fmadd_ps(_mm512_set1_ps(-4.25f), _r04, _mm512_add_ps(_r02, _r06)); + __m512 _tmp12b = _mm512_fmadd_ps(_mm512_set1_ps(-4.25f), _r03, _mm512_add_ps(_r01, _r05)); + + __m512 _tmp1m = _mm512_add_ps(_tmp12a, _tmp12b); + __m512 _tmp2m = _mm512_sub_ps(_tmp12a, _tmp12b); + _mm512_store_ps(tmp[1][m], _tmp1m); + _mm512_store_ps(tmp[2][m], _tmp2m); + + __m512 _tmp34a = _mm512_fmadd_ps(_mm512_set1_ps(-1.25f), _r04, _mm512_fmadd_ps(_mm512_set1_ps(0.25f), _r02, _r06)); + __m512 _tmp34b = _mm512_fmadd_ps(_mm512_set1_ps(2.f), _r05, _mm512_fmadd_ps(_mm512_set1_ps(-2.5f), _r03, _mm512_mul_ps(_r01, _mm512_set1_ps(0.5f)))); + + __m512 _tmp3m = _mm512_add_ps(_tmp34a, _tmp34b); + __m512 _tmp4m = _mm512_sub_ps(_tmp34a, _tmp34b); + _mm512_store_ps(tmp[3][m], _tmp3m); + _mm512_store_ps(tmp[4][m], _tmp4m); + + __m512 _tmp56a = _mm512_fmadd_ps(_mm512_set1_ps(4.f), _mm512_fmadd_ps(_mm512_set1_ps(-1.25f), _r04, _r02), _r06); + __m512 _tmp56b = _mm512_fmadd_ps(_mm512_set1_ps(0.5f), _r05, _mm512_fmadd_ps(_mm512_set1_ps(-2.5f), _r03, _mm512_mul_ps(_r01, _mm512_set1_ps(2.f)))); + + __m512 _tmp5m = _mm512_add_ps(_tmp56a, _tmp56b); + __m512 _tmp6m = _mm512_sub_ps(_tmp56a, _tmp56b); + _mm512_store_ps(tmp[5][m], _tmp5m); + _mm512_store_ps(tmp[6][m], _tmp6m); + + r0 += w * 16; + } + + float* r0_tm_0 = (float*)img0_tm + (i * w_tiles + j) * 16; + float* r0_tm_1 = r0_tm_0 + tiles * 16; + float* r0_tm_2 = r0_tm_0 + tiles * 16 * 2; + float* r0_tm_3 = r0_tm_0 + tiles * 16 * 3; + float* r0_tm_4 = r0_tm_0 + tiles * 16 * 4; + float* r0_tm_5 = r0_tm_0 + tiles * 16 * 5; + float* r0_tm_6 = r0_tm_0 + tiles * 16 * 6; + float* r0_tm_7 = r0_tm_0 + tiles * 16 * 7; + + for (int m = 0; m < 8; m++) + { + __m512 _tmp00 = _mm512_load_ps(tmp[m][0]); + __m512 _tmp01 = _mm512_load_ps(tmp[m][1]); + __m512 _tmp02 = _mm512_load_ps(tmp[m][2]); + __m512 _tmp03 = _mm512_load_ps(tmp[m][3]); + __m512 _tmp04 = _mm512_load_ps(tmp[m][4]); + __m512 _tmp05 = _mm512_load_ps(tmp[m][5]); + __m512 _tmp06 = _mm512_load_ps(tmp[m][6]); + __m512 _tmp07 = _mm512_load_ps(tmp[m][7]); + + __m512 _r0tm0 = _mm512_fmadd_ps(_mm512_set1_ps(5.25f), _mm512_sub_ps(_tmp04, _tmp02), _mm512_sub_ps(_tmp00, _tmp06)); + __m512 _r0tm7 = _mm512_fmadd_ps(_mm512_set1_ps(5.25f), _mm512_sub_ps(_tmp03, _tmp05), _mm512_sub_ps(_tmp07, _tmp01)); + + __m512 _tmp12a = _mm512_fmadd_ps(_mm512_set1_ps(-4.25f), _tmp04, _mm512_add_ps(_tmp02, _tmp06)); + __m512 _tmp12b = _mm512_fmadd_ps(_mm512_set1_ps(-4.25f), _tmp03, _mm512_add_ps(_tmp01, _tmp05)); + + __m512 _r0tm1 = _mm512_add_ps(_tmp12a, _tmp12b); + __m512 _r0tm2 = _mm512_sub_ps(_tmp12a, _tmp12b); + + __m512 _tmp34a = _mm512_fmadd_ps(_mm512_set1_ps(-1.25f), _tmp04, _mm512_fmadd_ps(_mm512_set1_ps(0.25f), _tmp02, _tmp06)); + __m512 _tmp34b = _mm512_fmadd_ps(_mm512_set1_ps(2.f), _tmp05, _mm512_fmadd_ps(_mm512_set1_ps(-2.5f), _tmp03, _mm512_mul_ps(_tmp01, _mm512_set1_ps(0.5f)))); + + __m512 _r0tm3 = _mm512_add_ps(_tmp34a, _tmp34b); + __m512 _r0tm4 = _mm512_sub_ps(_tmp34a, _tmp34b); + + __m512 _tmp56a = _mm512_fmadd_ps(_mm512_set1_ps(4.f), _mm512_fmadd_ps(_mm512_set1_ps(-1.25f), _tmp04, _tmp02), _tmp06); + __m512 _tmp56b = _mm512_fmadd_ps(_mm512_set1_ps(0.5f), _tmp05, _mm512_fmadd_ps(_mm512_set1_ps(-2.5f), _tmp03, _mm512_mul_ps(_tmp01, _mm512_set1_ps(2.f)))); + + __m512 _r0tm5 = _mm512_add_ps(_tmp56a, _tmp56b); + __m512 _r0tm6 = _mm512_sub_ps(_tmp56a, _tmp56b); + + _mm512_store_ps(r0_tm_0, _r0tm0); + _mm512_store_ps(r0_tm_1, _r0tm1); + 
_mm512_store_ps(r0_tm_2, _r0tm2); + _mm512_store_ps(r0_tm_3, _r0tm3); + _mm512_store_ps(r0_tm_4, _r0tm4); + _mm512_store_ps(r0_tm_5, _r0tm5); + _mm512_store_ps(r0_tm_6, _r0tm6); + _mm512_store_ps(r0_tm_7, _r0tm7); + + r0_tm_0 += tiles * 128; + r0_tm_1 += tiles * 128; + r0_tm_2 += tiles * 128; + r0_tm_3 += tiles * 128; + r0_tm_4 += tiles * 128; + r0_tm_5 += tiles * 128; + r0_tm_6 += tiles * 128; + r0_tm_7 += tiles * 128; + } + } + } + } +} + +static void conv3x3s1_winograd64_transform_output_pack16_avx512(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 6; + const int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; + + const float* biasptr = bias; + + // const float otm[6][8] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} + // }; + + // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 + // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 + // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 + // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 + // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 + // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + __m512 _bias0 = biasptr ? _mm512_loadu_ps(biasptr + p * 16) : _mm512_setzero_ps(); + +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + float tmp[6][8][16]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* output0_tm_0 = (const float*)out0_tm + (i * w_tiles + j) * 16; + const float* output0_tm_1 = output0_tm_0 + tiles * 16; + const float* output0_tm_2 = output0_tm_0 + tiles * 16 * 2; + const float* output0_tm_3 = output0_tm_0 + tiles * 16 * 3; + const float* output0_tm_4 = output0_tm_0 + tiles * 16 * 4; + const float* output0_tm_5 = output0_tm_0 + tiles * 16 * 5; + const float* output0_tm_6 = output0_tm_0 + tiles * 16 * 6; + const float* output0_tm_7 = output0_tm_0 + tiles * 16 * 7; + + float* output0 = out0.row(i * 6) + (j * 6) * 16; + + for (int m = 0; m < 8; m++) + { + __m512 _out0tm0 = _mm512_load_ps(output0_tm_0); + __m512 _out0tm1 = _mm512_load_ps(output0_tm_1); + __m512 _out0tm2 = _mm512_load_ps(output0_tm_2); + __m512 _out0tm3 = _mm512_load_ps(output0_tm_3); + __m512 _out0tm4 = _mm512_load_ps(output0_tm_4); + __m512 _out0tm5 = _mm512_load_ps(output0_tm_5); + __m512 _out0tm6 = _mm512_load_ps(output0_tm_6); + __m512 _out0tm7 = _mm512_load_ps(output0_tm_7); + + __m512 _tmp024a = _mm512_add_ps(_out0tm1, _out0tm2); + __m512 _tmp135a = _mm512_sub_ps(_out0tm1, _out0tm2); + + __m512 _tmp024b = _mm512_add_ps(_out0tm3, _out0tm4); + __m512 _tmp135b = _mm512_sub_ps(_out0tm3, _out0tm4); + + __m512 _tmp024c = _mm512_add_ps(_out0tm5, _out0tm6); + __m512 _tmp135c = _mm512_sub_ps(_out0tm5, _out0tm6); + + __m512 _tmp0m = _mm512_add_ps(_mm512_add_ps(_out0tm0, _tmp024a), _mm512_fmadd_ps(_mm512_set1_ps(32.f), _tmp024c, _tmp024b)); + __m512 _tmp2m = _mm512_fmadd_ps(_mm512_set1_ps(8.f), _tmp024c, _mm512_fmadd_ps(_mm512_set1_ps(4.f), _tmp024b, 
_tmp024a));
+                    __m512 _tmp4m = _mm512_fmadd_ps(_mm512_set1_ps(2.f), _tmp024c, _mm512_fmadd_ps(_mm512_set1_ps(16.f), _tmp024b, _tmp024a));
+                    _mm512_store_ps(tmp[0][m], _tmp0m);
+                    _mm512_store_ps(tmp[2][m], _tmp2m);
+                    _mm512_store_ps(tmp[4][m], _tmp4m);
+
+                    __m512 _tmp1m = _mm512_fmadd_ps(_mm512_set1_ps(16.f), _tmp135c, _mm512_fmadd_ps(_mm512_set1_ps(2.f), _tmp135b, _tmp135a));
+                    __m512 _tmp3m = _mm512_fmadd_ps(_mm512_set1_ps(4.f), _tmp135c, _mm512_fmadd_ps(_mm512_set1_ps(8.f), _tmp135b, _tmp135a));
+                    __m512 _tmp5m = _mm512_add_ps(_mm512_add_ps(_out0tm7, _tmp135a), _mm512_fmadd_ps(_mm512_set1_ps(32.f), _tmp135b, _tmp135c));
+                    _mm512_store_ps(tmp[1][m], _tmp1m);
+                    _mm512_store_ps(tmp[3][m], _tmp3m);
+                    _mm512_store_ps(tmp[5][m], _tmp5m);
+
+                    output0_tm_0 += tiles * 128;
+                    output0_tm_1 += tiles * 128;
+                    output0_tm_2 += tiles * 128;
+                    output0_tm_3 += tiles * 128;
+                    output0_tm_4 += tiles * 128;
+                    output0_tm_5 += tiles * 128;
+                    output0_tm_6 += tiles * 128;
+                    output0_tm_7 += tiles * 128;
+                }
+
+                for (int m = 0; m < 6; m++)
+                {
+                    __m512 _tmp00 = _mm512_load_ps(tmp[m][0]);
+                    __m512 _tmp01 = _mm512_load_ps(tmp[m][1]);
+                    __m512 _tmp02 = _mm512_load_ps(tmp[m][2]);
+                    __m512 _tmp03 = _mm512_load_ps(tmp[m][3]);
+                    __m512 _tmp04 = _mm512_load_ps(tmp[m][4]);
+                    __m512 _tmp05 = _mm512_load_ps(tmp[m][5]);
+                    __m512 _tmp06 = _mm512_load_ps(tmp[m][6]);
+                    __m512 _tmp07 = _mm512_load_ps(tmp[m][7]);
+
+                    __m512 _tmp024a = _mm512_add_ps(_tmp01, _tmp02);
+                    __m512 _tmp135a = _mm512_sub_ps(_tmp01, _tmp02);
+
+                    __m512 _tmp024b = _mm512_add_ps(_tmp03, _tmp04);
+                    __m512 _tmp135b = _mm512_sub_ps(_tmp03, _tmp04);
+
+                    __m512 _tmp024c = _mm512_add_ps(_tmp05, _tmp06);
+                    __m512 _tmp135c = _mm512_sub_ps(_tmp05, _tmp06);
+
+                    __m512 _out00 = _mm512_add_ps(_bias0, _mm512_add_ps(_mm512_add_ps(_tmp00, _tmp024a), _mm512_fmadd_ps(_mm512_set1_ps(32.f), _tmp024c, _tmp024b)));
+                    __m512 _out02 = _mm512_add_ps(_bias0, _mm512_fmadd_ps(_mm512_set1_ps(8.f), _tmp024c, _mm512_fmadd_ps(_mm512_set1_ps(4.f), _tmp024b, _tmp024a)));
+                    __m512 _out04 = _mm512_add_ps(_bias0, _mm512_fmadd_ps(_mm512_set1_ps(2.f), _tmp024c, _mm512_fmadd_ps(_mm512_set1_ps(16.f), _tmp024b, _tmp024a)));
+                    _mm512_store_ps(output0, _out00);
+                    _mm512_store_ps(output0 + 32, _out02);
+                    _mm512_store_ps(output0 + 64, _out04);
+
+                    __m512 _out01 = _mm512_add_ps(_bias0, _mm512_fmadd_ps(_mm512_set1_ps(16.f), _tmp135c, _mm512_fmadd_ps(_mm512_set1_ps(2.f), _tmp135b, _tmp135a)));
+                    __m512 _out03 = _mm512_add_ps(_bias0, _mm512_fmadd_ps(_mm512_set1_ps(4.f), _tmp135c, _mm512_fmadd_ps(_mm512_set1_ps(8.f), _tmp135b, _tmp135a)));
+                    __m512 _out05 = _mm512_add_ps(_bias0, _mm512_add_ps(_mm512_add_ps(_tmp07, _tmp135a), _mm512_fmadd_ps(_mm512_set1_ps(32.f), _tmp135b, _tmp135c)));
+                    _mm512_store_ps(output0 + 16, _out01);
+                    _mm512_store_ps(output0 + 48, _out03);
+                    _mm512_store_ps(output0 + 80, _out05);
+
+                    output0 += outw * 16;
+                }
+            }
+        }
+    }
+}
+
+static void conv3x3s1_winograd42_transform_input_pack16_avx512(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt)
+{
+    const int w = bottom_blob.w;
+    const int h = bottom_blob.h;
+    const int inch = bottom_blob.c;
+
+    const int w_tiles = (w - 2) / 4;
+    const int h_tiles = (h - 2) / 4;
+    const int tiles = w_tiles * h_tiles;
+
+    // const float itm[6][6] = {
+    //     {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f},
+    //     {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f},
+    //     {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f},
+    //     {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f},
+    //     {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f},
+    //     {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f}
+    // };
+
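+    // For reference, a plain scalar sketch of the same 1-D transform
+    // (winograd42_input_column is a hypothetical helper shown for illustration
+    // only; it is not part of this patch). Each out[i] is row i of itm dotted
+    // with r, matching the expressions below:
+    //
+    // static void winograd42_input_column(const float r[6], float out[6])
+    // {
+    //     out[0] = 4 * r[0] - 5 * r[2] + r[4];
+    //     out[1] = -4 * (r[1] + r[2]) + r[4] + r[3];
+    //     out[2] = 4 * (r[1] - r[2]) + r[4] - r[3];
+    //     out[3] = -2 * (r[1] - r[3]) + r[4] - r[2];
+    //     out[4] = 2 * (r[1] - r[3]) + r[4] - r[2];
+    //     out[5] = 4 * r[1] - 5 * r[3] + r[5];
+    // }
+
+    // 0 = 4 * r00 - 5 * r02 + 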
r04 + // 1 = -4 * (r01 + r02) + r04 + r03 + // 2 = 4 * (r01 - r02) + r04 - r03 + // 3 = -2 * (r01 - r03) + r04 - r02 + // 4 = 2 * (r01 - r03) + r04 - r02 + // 5 = 4 * r01 - 5 * r03 + r05 + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < inch; q++) + { + const Mat img0 = bottom_blob.channel(q); + Mat img0_tm = bottom_blob_tm.channel(q); + +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + float tmp[6][6][16]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* r0 = img0.row(i * 4) + (j * 4) * 16; + + for (int m = 0; m < 6; m++) + { + __m512 _r00 = _mm512_load_ps(r0); + __m512 _r01 = _mm512_load_ps(r0 + 16); + __m512 _r02 = _mm512_load_ps(r0 + 16 * 2); + __m512 _r03 = _mm512_load_ps(r0 + 16 * 3); + __m512 _r04 = _mm512_load_ps(r0 + 16 * 4); + __m512 _r05 = _mm512_load_ps(r0 + 16 * 5); + + __m512 _tmp0m = _mm512_fmadd_ps(_mm512_set1_ps(-5.f), _r02, _mm512_fmadd_ps(_mm512_set1_ps(4.f), _r00, _r04)); + __m512 _tmp1m = _mm512_fmadd_ps(_mm512_set1_ps(-4.f), _mm512_add_ps(_r01, _r02), _mm512_add_ps(_r04, _r03)); + __m512 _tmp2m = _mm512_fmadd_ps(_mm512_set1_ps(4.f), _mm512_sub_ps(_r01, _r02), _mm512_sub_ps(_r04, _r03)); + __m512 _tmp3m = _mm512_fmadd_ps(_mm512_set1_ps(-2.f), _mm512_sub_ps(_r01, _r03), _mm512_sub_ps(_r04, _r02)); + __m512 _tmp4m = _mm512_fmadd_ps(_mm512_set1_ps(2.f), _mm512_sub_ps(_r01, _r03), _mm512_sub_ps(_r04, _r02)); + __m512 _tmp5m = _mm512_fmadd_ps(_mm512_set1_ps(-5.f), _r03, _mm512_fmadd_ps(_mm512_set1_ps(4.f), _r01, _r05)); + + _mm512_store_ps(tmp[0][m], _tmp0m); + _mm512_store_ps(tmp[1][m], _tmp1m); + _mm512_store_ps(tmp[2][m], _tmp2m); + _mm512_store_ps(tmp[3][m], _tmp3m); + _mm512_store_ps(tmp[4][m], _tmp4m); + _mm512_store_ps(tmp[5][m], _tmp5m); + + r0 += w * 16; + } + + float* r0_tm_0 = (float*)img0_tm + (i * w_tiles + j) * 16; + float* r0_tm_1 = r0_tm_0 + tiles * 16; + float* r0_tm_2 = r0_tm_0 + tiles * 16 * 2; + float* r0_tm_3 = r0_tm_0 + tiles * 16 * 3; + float* r0_tm_4 = r0_tm_0 + tiles * 16 * 4; + float* r0_tm_5 = r0_tm_0 + tiles * 16 * 5; + + for (int m = 0; m < 6; m++) + { + __m512 _tmp00 = _mm512_load_ps(tmp[m][0]); + __m512 _tmp01 = _mm512_load_ps(tmp[m][1]); + __m512 _tmp02 = _mm512_load_ps(tmp[m][2]); + __m512 _tmp03 = _mm512_load_ps(tmp[m][3]); + __m512 _tmp04 = _mm512_load_ps(tmp[m][4]); + __m512 _tmp05 = _mm512_load_ps(tmp[m][5]); + + __m512 _r0tm0 = _mm512_fmadd_ps(_mm512_set1_ps(-5.f), _tmp02, _mm512_fmadd_ps(_mm512_set1_ps(4.f), _tmp00, _tmp04)); + __m512 _r0tm1 = _mm512_fmadd_ps(_mm512_set1_ps(-4.f), _mm512_add_ps(_tmp01, _tmp02), _mm512_add_ps(_tmp04, _tmp03)); + __m512 _r0tm2 = _mm512_fmadd_ps(_mm512_set1_ps(4.f), _mm512_sub_ps(_tmp01, _tmp02), _mm512_sub_ps(_tmp04, _tmp03)); + __m512 _r0tm3 = _mm512_fmadd_ps(_mm512_set1_ps(-2.f), _mm512_sub_ps(_tmp01, _tmp03), _mm512_sub_ps(_tmp04, _tmp02)); + __m512 _r0tm4 = _mm512_fmadd_ps(_mm512_set1_ps(2.f), _mm512_sub_ps(_tmp01, _tmp03), _mm512_sub_ps(_tmp04, _tmp02)); + __m512 _r0tm5 = _mm512_fmadd_ps(_mm512_set1_ps(-5.f), _tmp03, _mm512_fmadd_ps(_mm512_set1_ps(4.f), _tmp01, _tmp05)); + + _mm512_store_ps(r0_tm_0, _r0tm0); + _mm512_store_ps(r0_tm_1, _r0tm1); + _mm512_store_ps(r0_tm_2, _r0tm2); + _mm512_store_ps(r0_tm_3, _r0tm3); + _mm512_store_ps(r0_tm_4, _r0tm4); + _mm512_store_ps(r0_tm_5, _r0tm5); + + r0_tm_0 += tiles * 96; + r0_tm_1 += tiles * 96; + r0_tm_2 += tiles * 96; + r0_tm_3 += tiles * 96; + r0_tm_4 += tiles * 96; + r0_tm_5 += tiles * 96; + } + } + } + } +} + +static void 
conv3x3s1_winograd42_transform_output_pack16_avx512(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 4; + const int h_tiles = outh / 4; + const int tiles = w_tiles * h_tiles; + + const float* biasptr = bias; + + // const float otm[4][6] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} + // }; + + // 0 = r00 + (r01 + r02) + (r03 + r04) + // 1 = (r01 - r02) + (r03 - r04) * 2 + // 2 = (r01 + r02) + (r03 + r04) * 4 + // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + __m512 _bias0 = biasptr ? _mm512_loadu_ps(biasptr + p * 16) : _mm512_setzero_ps(); + +#ifdef _MSC_VER + __declspec(align(64)) +#else + __attribute__((aligned(64))) +#endif + float tmp[4][6][16]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* output0_tm_0 = (const float*)out0_tm + (i * w_tiles + j) * 16; + const float* output0_tm_1 = output0_tm_0 + tiles * 16; + const float* output0_tm_2 = output0_tm_0 + tiles * 16 * 2; + const float* output0_tm_3 = output0_tm_0 + tiles * 16 * 3; + const float* output0_tm_4 = output0_tm_0 + tiles * 16 * 4; + const float* output0_tm_5 = output0_tm_0 + tiles * 16 * 5; + + float* output0 = out0.row(i * 4) + (j * 4) * 16; + + for (int m = 0; m < 6; m++) + { + __m512 _out0tm0 = _mm512_load_ps(output0_tm_0); + __m512 _out0tm1 = _mm512_load_ps(output0_tm_1); + __m512 _out0tm2 = _mm512_load_ps(output0_tm_2); + __m512 _out0tm3 = _mm512_load_ps(output0_tm_3); + __m512 _out0tm4 = _mm512_load_ps(output0_tm_4); + __m512 _out0tm5 = _mm512_load_ps(output0_tm_5); + + __m512 _tmp02a = _mm512_add_ps(_out0tm1, _out0tm2); + __m512 _tmp13a = _mm512_sub_ps(_out0tm1, _out0tm2); + + __m512 _tmp02b = _mm512_add_ps(_out0tm3, _out0tm4); + __m512 _tmp13b = _mm512_sub_ps(_out0tm3, _out0tm4); + + __m512 _tmp0m = _mm512_add_ps(_mm512_add_ps(_out0tm0, _tmp02a), _tmp02b); + __m512 _tmp1m = _mm512_fmadd_ps(_mm512_set1_ps(2.f), _tmp13b, _tmp13a); + __m512 _tmp2m = _mm512_fmadd_ps(_mm512_set1_ps(4.f), _tmp02b, _tmp02a); + __m512 _tmp3m = _mm512_fmadd_ps(_mm512_set1_ps(8.f), _tmp13b, _mm512_add_ps(_out0tm5, _tmp13a)); + + _mm512_store_ps(tmp[0][m], _tmp0m); + _mm512_store_ps(tmp[1][m], _tmp1m); + _mm512_store_ps(tmp[2][m], _tmp2m); + _mm512_store_ps(tmp[3][m], _tmp3m); + + output0_tm_0 += tiles * 96; + output0_tm_1 += tiles * 96; + output0_tm_2 += tiles * 96; + output0_tm_3 += tiles * 96; + output0_tm_4 += tiles * 96; + output0_tm_5 += tiles * 96; + } + + for (int m = 0; m < 4; m++) + { + __m512 _tmp00 = _mm512_load_ps(tmp[m][0]); + __m512 _tmp01 = _mm512_load_ps(tmp[m][1]); + __m512 _tmp02 = _mm512_load_ps(tmp[m][2]); + __m512 _tmp03 = _mm512_load_ps(tmp[m][3]); + __m512 _tmp04 = _mm512_load_ps(tmp[m][4]); + __m512 _tmp05 = _mm512_load_ps(tmp[m][5]); + + __m512 _tmp02a = _mm512_add_ps(_tmp01, _tmp02); + __m512 _tmp13a = _mm512_sub_ps(_tmp01, _tmp02); + + __m512 _tmp02b = _mm512_add_ps(_tmp03, _tmp04); + __m512 _tmp13b = _mm512_sub_ps(_tmp03, _tmp04); + + __m512 _out00 = _mm512_add_ps(_bias0, _mm512_add_ps(_mm512_add_ps(_tmp00, _tmp02a), _tmp02b)); + __m512 _out01 = _mm512_add_ps(_bias0, _mm512_fmadd_ps(_mm512_set1_ps(2.f), 
_tmp13b, _tmp13a)); + __m512 _out02 = _mm512_add_ps(_bias0, _mm512_fmadd_ps(_mm512_set1_ps(4.f), _tmp02b, _tmp02a)); + __m512 _out03 = _mm512_add_ps(_bias0, _mm512_fmadd_ps(_mm512_set1_ps(8.f), _tmp13b, _mm512_add_ps(_tmp05, _tmp13a))); + + _mm512_store_ps(output0, _out00); + _mm512_store_ps(output0 + 16, _out01); + _mm512_store_ps(output0 + 16 * 2, _out02); + _mm512_store_ps(output0 + 16 * 3, _out03); + + output0 += outw * 16; + } + } + } + } +} diff --git a/src/layer/x86/convolution_winograd_transform_pack4.h b/src/layer/x86/convolution_winograd_transform_pack4.h new file mode 100644 index 00000000000..96ecf1904cb --- /dev/null +++ b/src/layer/x86/convolution_winograd_transform_pack4.h @@ -0,0 +1,580 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void conv3x3s1_winograd64_transform_input_pack4_sse(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) +{ + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + + const int w_tiles = (w - 2) / 6; + const int h_tiles = (h - 2) / 6; + const int tiles = w_tiles * h_tiles; + + // const float itm[8][8] = { + // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, + // + // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, + // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, + // + // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, + // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, + // + // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, + // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, + // + // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} + // }; + + // 0 = r00 - r06 + (r04 - r02) * 5.25 + // 7 = r07 - r01 + (r03 - r05) * 5.25 + + // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) + // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) + + // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) + // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) + + // reuse r04 * 1.25 + // reuse r03 * 2.5 + // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) + // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < inch; q++) + { + const Mat img0 = bottom_blob.channel(q); + Mat img0_tm = bottom_blob_tm.channel(q); + +#ifdef _MSC_VER + __declspec(align(16)) +#else + __attribute__((aligned(16))) +#endif + float tmp[8][8][4]; + + __m128 _v5_25 = _mm_set1_ps(5.25f); + __m128 _vm4_25 = _mm_set1_ps(-4.25f); + __m128 _vm1_25 = _mm_set1_ps(-1.25f); + __m128 _v0_25 = _mm_set1_ps(0.25f); + __m128 _vm2_5 = _mm_set1_ps(-2.5f); + __m128 _v0_5 = _mm_set1_ps(0.5f); + __m128 _v2 = _mm_set1_ps(2.f); + __m128 _v4 = _mm_set1_ps(4.f); + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < 
w_tiles; j++) + { + const float* r0 = img0.row(i * 6) + (j * 6) * 4; + + for (int m = 0; m < 8; m++) + { + __m128 _r00 = _mm_load_ps(r0); + __m128 _r01 = _mm_load_ps(r0 + 4); + __m128 _r02 = _mm_load_ps(r0 + 4 * 2); + __m128 _r03 = _mm_load_ps(r0 + 4 * 3); + __m128 _r04 = _mm_load_ps(r0 + 4 * 4); + __m128 _r05 = _mm_load_ps(r0 + 4 * 5); + __m128 _r06 = _mm_load_ps(r0 + 4 * 6); + __m128 _r07 = _mm_load_ps(r0 + 4 * 7); + + __m128 _tmp0m = _mm_comp_fmadd_ps(_v5_25, _mm_sub_ps(_r04, _r02), _mm_sub_ps(_r00, _r06)); + __m128 _tmp7m = _mm_comp_fmadd_ps(_v5_25, _mm_sub_ps(_r03, _r05), _mm_sub_ps(_r07, _r01)); + _mm_store_ps(tmp[0][m], _tmp0m); + _mm_store_ps(tmp[7][m], _tmp7m); + + __m128 _tmp12a = _mm_comp_fmadd_ps(_vm4_25, _r04, _mm_add_ps(_r02, _r06)); + __m128 _tmp12b = _mm_comp_fmadd_ps(_vm4_25, _r03, _mm_add_ps(_r01, _r05)); + + __m128 _tmp1m = _mm_add_ps(_tmp12a, _tmp12b); + __m128 _tmp2m = _mm_sub_ps(_tmp12a, _tmp12b); + _mm_store_ps(tmp[1][m], _tmp1m); + _mm_store_ps(tmp[2][m], _tmp2m); + + __m128 _tmp34a = _mm_comp_fmadd_ps(_vm1_25, _r04, _mm_comp_fmadd_ps(_v0_25, _r02, _r06)); + __m128 _tmp34b = _mm_comp_fmadd_ps(_v2, _r05, _mm_comp_fmadd_ps(_vm2_5, _r03, _mm_mul_ps(_r01, _v0_5))); + + __m128 _tmp3m = _mm_add_ps(_tmp34a, _tmp34b); + __m128 _tmp4m = _mm_sub_ps(_tmp34a, _tmp34b); + _mm_store_ps(tmp[3][m], _tmp3m); + _mm_store_ps(tmp[4][m], _tmp4m); + + __m128 _tmp56a = _mm_comp_fmadd_ps(_v4, _mm_comp_fmadd_ps(_vm1_25, _r04, _r02), _r06); + __m128 _tmp56b = _mm_comp_fmadd_ps(_v0_5, _r05, _mm_comp_fmadd_ps(_vm2_5, _r03, _mm_mul_ps(_r01, _v2))); + + __m128 _tmp5m = _mm_add_ps(_tmp56a, _tmp56b); + __m128 _tmp6m = _mm_sub_ps(_tmp56a, _tmp56b); + _mm_store_ps(tmp[5][m], _tmp5m); + _mm_store_ps(tmp[6][m], _tmp6m); + + r0 += w * 4; + } + + float* r0_tm_0 = (float*)img0_tm + (i * w_tiles + j) * 4; + float* r0_tm_1 = r0_tm_0 + tiles * 4; + float* r0_tm_2 = r0_tm_0 + tiles * 4 * 2; + float* r0_tm_3 = r0_tm_0 + tiles * 4 * 3; + float* r0_tm_4 = r0_tm_0 + tiles * 4 * 4; + float* r0_tm_5 = r0_tm_0 + tiles * 4 * 5; + float* r0_tm_6 = r0_tm_0 + tiles * 4 * 6; + float* r0_tm_7 = r0_tm_0 + tiles * 4 * 7; + + for (int m = 0; m < 8; m++) + { + __m128 _tmp00 = _mm_load_ps(tmp[m][0]); + __m128 _tmp01 = _mm_load_ps(tmp[m][1]); + __m128 _tmp02 = _mm_load_ps(tmp[m][2]); + __m128 _tmp03 = _mm_load_ps(tmp[m][3]); + __m128 _tmp04 = _mm_load_ps(tmp[m][4]); + __m128 _tmp05 = _mm_load_ps(tmp[m][5]); + __m128 _tmp06 = _mm_load_ps(tmp[m][6]); + __m128 _tmp07 = _mm_load_ps(tmp[m][7]); + + __m128 _r0tm0 = _mm_comp_fmadd_ps(_v5_25, _mm_sub_ps(_tmp04, _tmp02), _mm_sub_ps(_tmp00, _tmp06)); + __m128 _r0tm7 = _mm_comp_fmadd_ps(_v5_25, _mm_sub_ps(_tmp03, _tmp05), _mm_sub_ps(_tmp07, _tmp01)); + + __m128 _tmp12a = _mm_comp_fmadd_ps(_vm4_25, _tmp04, _mm_add_ps(_tmp02, _tmp06)); + __m128 _tmp12b = _mm_comp_fmadd_ps(_vm4_25, _tmp03, _mm_add_ps(_tmp01, _tmp05)); + + __m128 _r0tm1 = _mm_add_ps(_tmp12a, _tmp12b); + __m128 _r0tm2 = _mm_sub_ps(_tmp12a, _tmp12b); + + __m128 _tmp34a = _mm_comp_fmadd_ps(_vm1_25, _tmp04, _mm_comp_fmadd_ps(_v0_25, _tmp02, _tmp06)); + __m128 _tmp34b = _mm_comp_fmadd_ps(_v2, _tmp05, _mm_comp_fmadd_ps(_vm2_5, _tmp03, _mm_mul_ps(_tmp01, _v0_5))); + + __m128 _r0tm3 = _mm_add_ps(_tmp34a, _tmp34b); + __m128 _r0tm4 = _mm_sub_ps(_tmp34a, _tmp34b); + + __m128 _tmp56a = _mm_comp_fmadd_ps(_v4, _mm_comp_fmadd_ps(_vm1_25, _tmp04, _tmp02), _tmp06); + __m128 _tmp56b = _mm_comp_fmadd_ps(_v0_5, _tmp05, _mm_comp_fmadd_ps(_vm2_5, _tmp03, _mm_mul_ps(_tmp01, _v2))); + + __m128 _r0tm5 = _mm_add_ps(_tmp56a, _tmp56b); + __m128 _r0tm6 
= _mm_sub_ps(_tmp56a, _tmp56b); + + _mm_store_ps(r0_tm_0, _r0tm0); + _mm_store_ps(r0_tm_1, _r0tm1); + _mm_store_ps(r0_tm_2, _r0tm2); + _mm_store_ps(r0_tm_3, _r0tm3); + _mm_store_ps(r0_tm_4, _r0tm4); + _mm_store_ps(r0_tm_5, _r0tm5); + _mm_store_ps(r0_tm_6, _r0tm6); + _mm_store_ps(r0_tm_7, _r0tm7); + + r0_tm_0 += tiles * 4 * 8; + r0_tm_1 += tiles * 4 * 8; + r0_tm_2 += tiles * 4 * 8; + r0_tm_3 += tiles * 4 * 8; + r0_tm_4 += tiles * 4 * 8; + r0_tm_5 += tiles * 4 * 8; + r0_tm_6 += tiles * 4 * 8; + r0_tm_7 += tiles * 4 * 8; + } + } + } + } +} + +static void conv3x3s1_winograd64_transform_output_pack4_sse(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 6; + const int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; + + const float* biasptr = bias; + + // const float otm[6][8] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} + // }; + + // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 + // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 + // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 + // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 + // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 + // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + __m128 _bias0 = biasptr ? 
_mm_loadu_ps(biasptr + p * 4) : _mm_setzero_ps(); + +#ifdef _MSC_VER + __declspec(align(16)) +#else + __attribute__((aligned(16))) +#endif + float tmp[6][8][4]; + + __m128 _v32 = _mm_set1_ps(32.f); + __m128 _v16 = _mm_set1_ps(16.f); + __m128 _v8 = _mm_set1_ps(8.f); + __m128 _v4 = _mm_set1_ps(4.f); + __m128 _v2 = _mm_set1_ps(2.f); + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* output0_tm_0 = (const float*)out0_tm + (i * w_tiles + j) * 4; + const float* output0_tm_1 = output0_tm_0 + tiles * 4; + const float* output0_tm_2 = output0_tm_0 + tiles * 4 * 2; + const float* output0_tm_3 = output0_tm_0 + tiles * 4 * 3; + const float* output0_tm_4 = output0_tm_0 + tiles * 4 * 4; + const float* output0_tm_5 = output0_tm_0 + tiles * 4 * 5; + const float* output0_tm_6 = output0_tm_0 + tiles * 4 * 6; + const float* output0_tm_7 = output0_tm_0 + tiles * 4 * 7; + + float* output0 = out0.row(i * 6) + (j * 6) * 4; + + for (int m = 0; m < 8; m++) + { + __m128 _out0tm0 = _mm_load_ps(output0_tm_0); + __m128 _out0tm1 = _mm_load_ps(output0_tm_1); + __m128 _out0tm2 = _mm_load_ps(output0_tm_2); + __m128 _out0tm3 = _mm_load_ps(output0_tm_3); + __m128 _out0tm4 = _mm_load_ps(output0_tm_4); + __m128 _out0tm5 = _mm_load_ps(output0_tm_5); + __m128 _out0tm6 = _mm_load_ps(output0_tm_6); + __m128 _out0tm7 = _mm_load_ps(output0_tm_7); + + __m128 _tmp024a = _mm_add_ps(_out0tm1, _out0tm2); + __m128 _tmp135a = _mm_sub_ps(_out0tm1, _out0tm2); + + __m128 _tmp024b = _mm_add_ps(_out0tm3, _out0tm4); + __m128 _tmp135b = _mm_sub_ps(_out0tm3, _out0tm4); + + __m128 _tmp024c = _mm_add_ps(_out0tm5, _out0tm6); + __m128 _tmp135c = _mm_sub_ps(_out0tm5, _out0tm6); + + __m128 _tmp0m = _mm_add_ps(_mm_add_ps(_out0tm0, _tmp024a), _mm_comp_fmadd_ps(_v32, _tmp024c, _tmp024b)); + __m128 _tmp2m = _mm_comp_fmadd_ps(_v8, _tmp024c, _mm_comp_fmadd_ps(_v4, _tmp024b, _tmp024a)); + __m128 _tmp4m = _mm_comp_fmadd_ps(_v2, _tmp024c, _mm_comp_fmadd_ps(_v16, _tmp024b, _tmp024a)); + _mm_store_ps(tmp[0][m], _tmp0m); + _mm_store_ps(tmp[2][m], _tmp2m); + _mm_store_ps(tmp[4][m], _tmp4m); + + __m128 _tmp1m = _mm_comp_fmadd_ps(_v16, _tmp135c, _mm_comp_fmadd_ps(_v2, _tmp135b, _tmp135a)); + __m128 _tmp3m = _mm_comp_fmadd_ps(_v4, _tmp135c, _mm_comp_fmadd_ps(_v8, _tmp135b, _tmp135a)); + __m128 _tmp5m = _mm_add_ps(_mm_add_ps(_out0tm7, _tmp135a), _mm_comp_fmadd_ps(_v32, _tmp135b, _tmp135c)); + _mm_store_ps(tmp[1][m], _tmp1m); + _mm_store_ps(tmp[3][m], _tmp3m); + _mm_store_ps(tmp[5][m], _tmp5m); + + output0_tm_0 += tiles * 4 * 8; + output0_tm_1 += tiles * 4 * 8; + output0_tm_2 += tiles * 4 * 8; + output0_tm_3 += tiles * 4 * 8; + output0_tm_4 += tiles * 4 * 8; + output0_tm_5 += tiles * 4 * 8; + output0_tm_6 += tiles * 4 * 8; + output0_tm_7 += tiles * 4 * 8; + } + + for (int m = 0; m < 6; m++) + { + __m128 _tmp00 = _mm_load_ps(tmp[m][0]); + __m128 _tmp01 = _mm_load_ps(tmp[m][1]); + __m128 _tmp02 = _mm_load_ps(tmp[m][2]); + __m128 _tmp03 = _mm_load_ps(tmp[m][3]); + __m128 _tmp04 = _mm_load_ps(tmp[m][4]); + __m128 _tmp05 = _mm_load_ps(tmp[m][5]); + __m128 _tmp06 = _mm_load_ps(tmp[m][6]); + __m128 _tmp07 = _mm_load_ps(tmp[m][7]); + + __m128 _tmp024a = _mm_add_ps(_tmp01, _tmp02); + __m128 _tmp135a = _mm_sub_ps(_tmp01, _tmp02); + + __m128 _tmp024b = _mm_add_ps(_tmp03, _tmp04); + __m128 _tmp135b = _mm_sub_ps(_tmp03, _tmp04); + + __m128 _tmp024c = _mm_add_ps(_tmp05, _tmp06); + __m128 _tmp135c = _mm_sub_ps(_tmp05, _tmp06); + + __m128 _out00 = _mm_add_ps(_bias0, _mm_add_ps(_mm_add_ps(_tmp00, _tmp024a), _mm_comp_fmadd_ps(_v32, 
_tmp024c, _tmp024b))); + __m128 _out02 = _mm_add_ps(_bias0, _mm_comp_fmadd_ps(_v8, _tmp024c, _mm_comp_fmadd_ps(_v4, _tmp024b, _tmp024a))); + __m128 _out04 = _mm_add_ps(_bias0, _mm_comp_fmadd_ps(_v2, _tmp024c, _mm_comp_fmadd_ps(_v16, _tmp024b, _tmp024a))); + _mm_store_ps(output0, _out00); + _mm_store_ps(output0 + 4 * 2, _out02); + _mm_store_ps(output0 + 4 * 4, _out04); + + __m128 _out01 = _mm_add_ps(_bias0, _mm_comp_fmadd_ps(_v16, _tmp135c, _mm_comp_fmadd_ps(_v2, _tmp135b, _tmp135a))); + __m128 _out03 = _mm_add_ps(_bias0, _mm_comp_fmadd_ps(_v4, _tmp135c, _mm_comp_fmadd_ps(_v8, _tmp135b, _tmp135a))); + __m128 _out05 = _mm_add_ps(_bias0, _mm_add_ps(_mm_add_ps(_tmp07, _tmp135a), _mm_comp_fmadd_ps(_v32, _tmp135b, _tmp135c))); + _mm_store_ps(output0 + 4, _out01); + _mm_store_ps(output0 + 4 * 3, _out03); + _mm_store_ps(output0 + 4 * 5, _out05); + + output0 += outw * 4; + } + } + } + } +} + +static void conv3x3s1_winograd42_transform_input_pack4_sse(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) +{ + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + + const int w_tiles = (w - 2) / 4; + const int h_tiles = (h - 2) / 4; + const int tiles = w_tiles * h_tiles; + + // const float itm[6][6] = { + // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f}, + // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f}, + // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f}, + // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f}, + // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f}, + // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f} + // }; + + // 0 = 4 * r00 - 5 * r02 + r04 + // 1 = -4 * (r01 + r02) + r04 + r03 + // 2 = 4 * (r01 - r02) + r04 - r03 + // 3 = -2 * (r01 - r03) + r04 - r02 + // 4 = 2 * (r01 - r03) + r04 - r02 + // 5 = 4 * r01 - 5 * r03 + r05 + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < inch; q++) + { + const Mat img0 = bottom_blob.channel(q); + Mat img0_tm = bottom_blob_tm.channel(q); + +#ifdef _MSC_VER + __declspec(align(16)) +#else + __attribute__((aligned(16))) +#endif + float tmp[6][6][4]; + + __m128 _vm5 = _mm_set1_ps(-5.f); + __m128 _vm4 = _mm_set1_ps(-4.f); + __m128 _v4 = _mm_set1_ps(4.f); + __m128 _vm2 = _mm_set1_ps(-2.f); + __m128 _v2 = _mm_set1_ps(2.f); + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* r0 = img0.row(i * 4) + (j * 4) * 4; + + for (int m = 0; m < 6; m++) + { + __m128 _r00 = _mm_load_ps(r0); + __m128 _r01 = _mm_load_ps(r0 + 4); + __m128 _r02 = _mm_load_ps(r0 + 4 * 2); + __m128 _r03 = _mm_load_ps(r0 + 4 * 3); + __m128 _r04 = _mm_load_ps(r0 + 4 * 4); + __m128 _r05 = _mm_load_ps(r0 + 4 * 5); + + __m128 _tmp0m = _mm_comp_fmadd_ps(_vm5, _r02, _mm_comp_fmadd_ps(_v4, _r00, _r04)); + __m128 _tmp1m = _mm_comp_fmadd_ps(_vm4, _mm_add_ps(_r01, _r02), _mm_add_ps(_r04, _r03)); + __m128 _tmp2m = _mm_comp_fmadd_ps(_v4, _mm_sub_ps(_r01, _r02), _mm_sub_ps(_r04, _r03)); + __m128 _tmp3m = _mm_comp_fmadd_ps(_vm2, _mm_sub_ps(_r01, _r03), _mm_sub_ps(_r04, _r02)); + __m128 _tmp4m = _mm_comp_fmadd_ps(_v2, _mm_sub_ps(_r01, _r03), _mm_sub_ps(_r04, _r02)); + __m128 _tmp5m = _mm_comp_fmadd_ps(_vm5, _r03, _mm_comp_fmadd_ps(_v4, _r01, _r05)); + + _mm_store_ps(tmp[0][m], _tmp0m); + _mm_store_ps(tmp[1][m], _tmp1m); + _mm_store_ps(tmp[2][m], _tmp2m); + _mm_store_ps(tmp[3][m], _tmp3m); + _mm_store_ps(tmp[4][m], _tmp4m); + _mm_store_ps(tmp[5][m], _tmp5m); + + r0 += w * 4; + } + + float* r0_tm_0 = (float*)img0_tm + (i * w_tiles + j) * 4; + float* r0_tm_1 = r0_tm_0 + tiles * 4; + float* r0_tm_2 = r0_tm_0 + tiles * 4 * 2; + float* 
r0_tm_3 = r0_tm_0 + tiles * 4 * 3; + float* r0_tm_4 = r0_tm_0 + tiles * 4 * 4; + float* r0_tm_5 = r0_tm_0 + tiles * 4 * 5; + + for (int m = 0; m < 6; m++) + { + __m128 _tmp00 = _mm_load_ps(tmp[m][0]); + __m128 _tmp01 = _mm_load_ps(tmp[m][1]); + __m128 _tmp02 = _mm_load_ps(tmp[m][2]); + __m128 _tmp03 = _mm_load_ps(tmp[m][3]); + __m128 _tmp04 = _mm_load_ps(tmp[m][4]); + __m128 _tmp05 = _mm_load_ps(tmp[m][5]); + + __m128 _r0tm0 = _mm_comp_fmadd_ps(_vm5, _tmp02, _mm_comp_fmadd_ps(_v4, _tmp00, _tmp04)); + __m128 _r0tm1 = _mm_comp_fmadd_ps(_vm4, _mm_add_ps(_tmp01, _tmp02), _mm_add_ps(_tmp04, _tmp03)); + __m128 _r0tm2 = _mm_comp_fmadd_ps(_v4, _mm_sub_ps(_tmp01, _tmp02), _mm_sub_ps(_tmp04, _tmp03)); + __m128 _r0tm3 = _mm_comp_fmadd_ps(_vm2, _mm_sub_ps(_tmp01, _tmp03), _mm_sub_ps(_tmp04, _tmp02)); + __m128 _r0tm4 = _mm_comp_fmadd_ps(_v2, _mm_sub_ps(_tmp01, _tmp03), _mm_sub_ps(_tmp04, _tmp02)); + __m128 _r0tm5 = _mm_comp_fmadd_ps(_vm5, _tmp03, _mm_comp_fmadd_ps(_v4, _tmp01, _tmp05)); + + _mm_store_ps(r0_tm_0, _r0tm0); + _mm_store_ps(r0_tm_1, _r0tm1); + _mm_store_ps(r0_tm_2, _r0tm2); + _mm_store_ps(r0_tm_3, _r0tm3); + _mm_store_ps(r0_tm_4, _r0tm4); + _mm_store_ps(r0_tm_5, _r0tm5); + + r0_tm_0 += tiles * 4 * 6; + r0_tm_1 += tiles * 4 * 6; + r0_tm_2 += tiles * 4 * 6; + r0_tm_3 += tiles * 4 * 6; + r0_tm_4 += tiles * 4 * 6; + r0_tm_5 += tiles * 4 * 6; + } + } + } + } +} + +static void conv3x3s1_winograd42_transform_output_pack4_sse(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 4; + const int h_tiles = outh / 4; + const int tiles = w_tiles * h_tiles; + + const float* biasptr = bias; + + // const float otm[4][6] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f} + // }; + + // 0 = r00 + (r01 + r02) + (r03 + r04) + // 1 = (r01 - r02) + (r03 - r04) * 2 + // 2 = (r01 + r02) + (r03 + r04) * 4 + // 3 = r05 + (r01 - r02) + (r03 - r04) * 8 + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + __m128 _bias0 = biasptr ? 
_mm_loadu_ps(biasptr + p * 4) : _mm_setzero_ps(); + +#ifdef _MSC_VER + __declspec(align(16)) +#else + __attribute__((aligned(16))) +#endif + float tmp[4][6][4]; + + __m128 _v2 = _mm_set1_ps(2.f); + __m128 _v4 = _mm_set1_ps(4.f); + __m128 _v8 = _mm_set1_ps(8.f); + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* output0_tm_0 = (const float*)out0_tm + (i * w_tiles + j) * 4; + const float* output0_tm_1 = output0_tm_0 + tiles * 4; + const float* output0_tm_2 = output0_tm_0 + tiles * 4 * 2; + const float* output0_tm_3 = output0_tm_0 + tiles * 4 * 3; + const float* output0_tm_4 = output0_tm_0 + tiles * 4 * 4; + const float* output0_tm_5 = output0_tm_0 + tiles * 4 * 5; + + float* output0 = out0.row(i * 4) + (j * 4) * 4; + + for (int m = 0; m < 6; m++) + { + __m128 _out0tm0 = _mm_load_ps(output0_tm_0); + __m128 _out0tm1 = _mm_load_ps(output0_tm_1); + __m128 _out0tm2 = _mm_load_ps(output0_tm_2); + __m128 _out0tm3 = _mm_load_ps(output0_tm_3); + __m128 _out0tm4 = _mm_load_ps(output0_tm_4); + __m128 _out0tm5 = _mm_load_ps(output0_tm_5); + + __m128 _tmp02a = _mm_add_ps(_out0tm1, _out0tm2); + __m128 _tmp13a = _mm_sub_ps(_out0tm1, _out0tm2); + + __m128 _tmp02b = _mm_add_ps(_out0tm3, _out0tm4); + __m128 _tmp13b = _mm_sub_ps(_out0tm3, _out0tm4); + + __m128 _tmp0m = _mm_add_ps(_mm_add_ps(_out0tm0, _tmp02a), _tmp02b); + __m128 _tmp1m = _mm_comp_fmadd_ps(_v2, _tmp13b, _tmp13a); + __m128 _tmp2m = _mm_comp_fmadd_ps(_v4, _tmp02b, _tmp02a); + __m128 _tmp3m = _mm_comp_fmadd_ps(_v8, _tmp13b, _mm_add_ps(_out0tm5, _tmp13a)); + + _mm_store_ps(tmp[0][m], _tmp0m); + _mm_store_ps(tmp[1][m], _tmp1m); + _mm_store_ps(tmp[2][m], _tmp2m); + _mm_store_ps(tmp[3][m], _tmp3m); + + output0_tm_0 += tiles * 4 * 6; + output0_tm_1 += tiles * 4 * 6; + output0_tm_2 += tiles * 4 * 6; + output0_tm_3 += tiles * 4 * 6; + output0_tm_4 += tiles * 4 * 6; + output0_tm_5 += tiles * 4 * 6; + } + + for (int m = 0; m < 4; m++) + { + __m128 _tmp00 = _mm_load_ps(tmp[m][0]); + __m128 _tmp01 = _mm_load_ps(tmp[m][1]); + __m128 _tmp02 = _mm_load_ps(tmp[m][2]); + __m128 _tmp03 = _mm_load_ps(tmp[m][3]); + __m128 _tmp04 = _mm_load_ps(tmp[m][4]); + __m128 _tmp05 = _mm_load_ps(tmp[m][5]); + + __m128 _tmp02a = _mm_add_ps(_tmp01, _tmp02); + __m128 _tmp13a = _mm_sub_ps(_tmp01, _tmp02); + + __m128 _tmp02b = _mm_add_ps(_tmp03, _tmp04); + __m128 _tmp13b = _mm_sub_ps(_tmp03, _tmp04); + + __m128 _out00 = _mm_add_ps(_bias0, _mm_add_ps(_mm_add_ps(_tmp00, _tmp02a), _tmp02b)); + __m128 _out01 = _mm_add_ps(_bias0, _mm_comp_fmadd_ps(_v2, _tmp13b, _tmp13a)); + __m128 _out02 = _mm_add_ps(_bias0, _mm_comp_fmadd_ps(_v4, _tmp02b, _tmp02a)); + __m128 _out03 = _mm_add_ps(_bias0, _mm_comp_fmadd_ps(_v8, _tmp13b, _mm_add_ps(_tmp05, _tmp13a))); + + _mm_store_ps(output0, _out00); + _mm_store_ps(output0 + 4, _out01); + _mm_store_ps(output0 + 4 * 2, _out02); + _mm_store_ps(output0 + 4 * 3, _out03); + + output0 += outw * 4; + } + } + } + } +} diff --git a/src/layer/x86/convolution_winograd_transform_pack8.h b/src/layer/x86/convolution_winograd_transform_pack8.h new file mode 100644 index 00000000000..bce0f1e7562 --- /dev/null +++ b/src/layer/x86/convolution_winograd_transform_pack8.h @@ -0,0 +1,555 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void conv3x3s1_winograd64_transform_input_pack8_avx(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt) +{ + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + + const int w_tiles = (w - 2) / 6; + const int h_tiles = (h - 2) / 6; + const int tiles = w_tiles * h_tiles; + + // const float itm[8][8] = { + // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f}, + // + // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f}, + // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f}, + // + // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f}, + // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f}, + // + // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f}, + // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f}, + // + // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f} + // }; + + // 0 = r00 - r06 + (r04 - r02) * 5.25 + // 7 = r07 - r01 + (r03 - r05) * 5.25 + + // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05) + // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05) + + // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2) + // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2) + + // reuse r04 * 1.25 + // reuse r03 * 2.5 + // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5) + // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < inch; q++) + { + const Mat img0 = bottom_blob.channel(q); + Mat img0_tm = bottom_blob_tm.channel(q); + +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + float tmp[8][8][8]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* r0 = img0.row(i * 6) + (j * 6) * 8; + + for (int m = 0; m < 8; m++) + { + __m256 _r00 = _mm256_load_ps(r0); + __m256 _r01 = _mm256_load_ps(r0 + 8); + __m256 _r02 = _mm256_load_ps(r0 + 16); + __m256 _r03 = _mm256_load_ps(r0 + 24); + __m256 _r04 = _mm256_load_ps(r0 + 32); + __m256 _r05 = _mm256_load_ps(r0 + 40); + __m256 _r06 = _mm256_load_ps(r0 + 48); + __m256 _r07 = _mm256_load_ps(r0 + 56); + + __m256 _tmp0m = _mm256_comp_fmadd_ps(_mm256_set1_ps(5.25f), _mm256_sub_ps(_r04, _r02), _mm256_sub_ps(_r00, _r06)); + __m256 _tmp7m = _mm256_comp_fmadd_ps(_mm256_set1_ps(5.25f), _mm256_sub_ps(_r03, _r05), _mm256_sub_ps(_r07, _r01)); + _mm256_store_ps(tmp[0][m], _tmp0m); + _mm256_store_ps(tmp[7][m], _tmp7m); + + __m256 _tmp12a = _mm256_comp_fmadd_ps(_mm256_set1_ps(-4.25f), _r04, _mm256_add_ps(_r02, _r06)); + __m256 _tmp12b = _mm256_comp_fmadd_ps(_mm256_set1_ps(-4.25f), _r03, _mm256_add_ps(_r01, _r05)); + + __m256 _tmp1m = _mm256_add_ps(_tmp12a, _tmp12b); + __m256 _tmp2m = _mm256_sub_ps(_tmp12a, _tmp12b); + _mm256_store_ps(tmp[1][m], _tmp1m); + _mm256_store_ps(tmp[2][m], _tmp2m); + + __m256 _tmp34a = _mm256_comp_fmadd_ps(_mm256_set1_ps(-1.25f), _r04, _mm256_comp_fmadd_ps(_mm256_set1_ps(0.25f), _r02, _r06)); + __m256 _tmp34b = _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _r05, 
_mm256_comp_fmadd_ps(_mm256_set1_ps(-2.5f), _r03, _mm256_mul_ps(_r01, _mm256_set1_ps(0.5f)))); + + __m256 _tmp3m = _mm256_add_ps(_tmp34a, _tmp34b); + __m256 _tmp4m = _mm256_sub_ps(_tmp34a, _tmp34b); + _mm256_store_ps(tmp[3][m], _tmp3m); + _mm256_store_ps(tmp[4][m], _tmp4m); + + __m256 _tmp56a = _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _mm256_comp_fmadd_ps(_mm256_set1_ps(-1.25f), _r04, _r02), _r06); + __m256 _tmp56b = _mm256_comp_fmadd_ps(_mm256_set1_ps(0.5f), _r05, _mm256_comp_fmadd_ps(_mm256_set1_ps(-2.5f), _r03, _mm256_mul_ps(_r01, _mm256_set1_ps(2.f)))); + + __m256 _tmp5m = _mm256_add_ps(_tmp56a, _tmp56b); + __m256 _tmp6m = _mm256_sub_ps(_tmp56a, _tmp56b); + _mm256_store_ps(tmp[5][m], _tmp5m); + _mm256_store_ps(tmp[6][m], _tmp6m); + + r0 += w * 8; + } + + float* r0_tm_0 = (float*)img0_tm + (i * w_tiles + j) * 8; + float* r0_tm_1 = r0_tm_0 + tiles * 8; + float* r0_tm_2 = r0_tm_0 + tiles * 16; + float* r0_tm_3 = r0_tm_0 + tiles * 24; + float* r0_tm_4 = r0_tm_0 + tiles * 32; + float* r0_tm_5 = r0_tm_0 + tiles * 40; + float* r0_tm_6 = r0_tm_0 + tiles * 48; + float* r0_tm_7 = r0_tm_0 + tiles * 56; + + for (int m = 0; m < 8; m++) + { + __m256 _tmp00 = _mm256_load_ps(tmp[m][0]); + __m256 _tmp01 = _mm256_load_ps(tmp[m][1]); + __m256 _tmp02 = _mm256_load_ps(tmp[m][2]); + __m256 _tmp03 = _mm256_load_ps(tmp[m][3]); + __m256 _tmp04 = _mm256_load_ps(tmp[m][4]); + __m256 _tmp05 = _mm256_load_ps(tmp[m][5]); + __m256 _tmp06 = _mm256_load_ps(tmp[m][6]); + __m256 _tmp07 = _mm256_load_ps(tmp[m][7]); + + __m256 _r0tm0 = _mm256_comp_fmadd_ps(_mm256_set1_ps(5.25f), _mm256_sub_ps(_tmp04, _tmp02), _mm256_sub_ps(_tmp00, _tmp06)); + __m256 _r0tm7 = _mm256_comp_fmadd_ps(_mm256_set1_ps(5.25f), _mm256_sub_ps(_tmp03, _tmp05), _mm256_sub_ps(_tmp07, _tmp01)); + + __m256 _tmp12a = _mm256_comp_fmadd_ps(_mm256_set1_ps(-4.25f), _tmp04, _mm256_add_ps(_tmp02, _tmp06)); + __m256 _tmp12b = _mm256_comp_fmadd_ps(_mm256_set1_ps(-4.25f), _tmp03, _mm256_add_ps(_tmp01, _tmp05)); + + __m256 _r0tm1 = _mm256_add_ps(_tmp12a, _tmp12b); + __m256 _r0tm2 = _mm256_sub_ps(_tmp12a, _tmp12b); + + __m256 _tmp34a = _mm256_comp_fmadd_ps(_mm256_set1_ps(-1.25f), _tmp04, _mm256_comp_fmadd_ps(_mm256_set1_ps(0.25f), _tmp02, _tmp06)); + __m256 _tmp34b = _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _tmp05, _mm256_comp_fmadd_ps(_mm256_set1_ps(-2.5f), _tmp03, _mm256_mul_ps(_tmp01, _mm256_set1_ps(0.5f)))); + + __m256 _r0tm3 = _mm256_add_ps(_tmp34a, _tmp34b); + __m256 _r0tm4 = _mm256_sub_ps(_tmp34a, _tmp34b); + + __m256 _tmp56a = _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _mm256_comp_fmadd_ps(_mm256_set1_ps(-1.25f), _tmp04, _tmp02), _tmp06); + __m256 _tmp56b = _mm256_comp_fmadd_ps(_mm256_set1_ps(0.5f), _tmp05, _mm256_comp_fmadd_ps(_mm256_set1_ps(-2.5f), _tmp03, _mm256_mul_ps(_tmp01, _mm256_set1_ps(2.f)))); + + __m256 _r0tm5 = _mm256_add_ps(_tmp56a, _tmp56b); + __m256 _r0tm6 = _mm256_sub_ps(_tmp56a, _tmp56b); + + _mm256_store_ps(r0_tm_0, _r0tm0); + _mm256_store_ps(r0_tm_1, _r0tm1); + _mm256_store_ps(r0_tm_2, _r0tm2); + _mm256_store_ps(r0_tm_3, _r0tm3); + _mm256_store_ps(r0_tm_4, _r0tm4); + _mm256_store_ps(r0_tm_5, _r0tm5); + _mm256_store_ps(r0_tm_6, _r0tm6); + _mm256_store_ps(r0_tm_7, _r0tm7); + + r0_tm_0 += tiles * 64; + r0_tm_1 += tiles * 64; + r0_tm_2 += tiles * 64; + r0_tm_3 += tiles * 64; + r0_tm_4 += tiles * 64; + r0_tm_5 += tiles * 64; + r0_tm_6 += tiles * 64; + r0_tm_7 += tiles * 64; + } + } + } + } +} + +static void conv3x3s1_winograd64_transform_output_pack8_avx(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt) +{ + const int 
outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + + const int w_tiles = outw / 6; + const int h_tiles = outh / 6; + const int tiles = w_tiles * h_tiles; + + const float* biasptr = bias; + + // const float otm[6][8] = { + // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f}, + // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f}, + // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f} + // }; + + // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32 + // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16 + // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8 + // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4 + // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2 + // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6) + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + const Mat out0_tm = top_blob_tm.channel(p); + Mat out0 = top_blob.channel(p); + + __m256 _bias0 = biasptr ? _mm256_loadu_ps(biasptr + p * 8) : _mm256_setzero_ps(); + +#ifdef _MSC_VER + __declspec(align(32)) +#else + __attribute__((aligned(32))) +#endif + float tmp[6][8][8]; + + // tile + for (int i = 0; i < h_tiles; i++) + { + for (int j = 0; j < w_tiles; j++) + { + const float* output0_tm_0 = (const float*)out0_tm + (i * w_tiles + j) * 8; + const float* output0_tm_1 = output0_tm_0 + tiles * 8; + const float* output0_tm_2 = output0_tm_0 + tiles * 16; + const float* output0_tm_3 = output0_tm_0 + tiles * 24; + const float* output0_tm_4 = output0_tm_0 + tiles * 32; + const float* output0_tm_5 = output0_tm_0 + tiles * 40; + const float* output0_tm_6 = output0_tm_0 + tiles * 48; + const float* output0_tm_7 = output0_tm_0 + tiles * 56; + + float* output0 = out0.row(i * 6) + (j * 6) * 8; + + for (int m = 0; m < 8; m++) + { + __m256 _out0tm0 = _mm256_load_ps(output0_tm_0); + __m256 _out0tm1 = _mm256_load_ps(output0_tm_1); + __m256 _out0tm2 = _mm256_load_ps(output0_tm_2); + __m256 _out0tm3 = _mm256_load_ps(output0_tm_3); + __m256 _out0tm4 = _mm256_load_ps(output0_tm_4); + __m256 _out0tm5 = _mm256_load_ps(output0_tm_5); + __m256 _out0tm6 = _mm256_load_ps(output0_tm_6); + __m256 _out0tm7 = _mm256_load_ps(output0_tm_7); + + __m256 _tmp024a = _mm256_add_ps(_out0tm1, _out0tm2); + __m256 _tmp135a = _mm256_sub_ps(_out0tm1, _out0tm2); + + __m256 _tmp024b = _mm256_add_ps(_out0tm3, _out0tm4); + __m256 _tmp135b = _mm256_sub_ps(_out0tm3, _out0tm4); + + __m256 _tmp024c = _mm256_add_ps(_out0tm5, _out0tm6); + __m256 _tmp135c = _mm256_sub_ps(_out0tm5, _out0tm6); + + __m256 _tmp0m = _mm256_add_ps(_mm256_add_ps(_out0tm0, _tmp024a), _mm256_comp_fmadd_ps(_mm256_set1_ps(32.f), _tmp024c, _tmp024b)); + __m256 _tmp2m = _mm256_comp_fmadd_ps(_mm256_set1_ps(8.f), _tmp024c, _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _tmp024b, _tmp024a)); + __m256 _tmp4m = _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _tmp024c, _mm256_comp_fmadd_ps(_mm256_set1_ps(16.f), _tmp024b, _tmp024a)); + _mm256_store_ps(tmp[0][m], _tmp0m); + _mm256_store_ps(tmp[2][m], _tmp2m); + _mm256_store_ps(tmp[4][m], _tmp4m); + + __m256 _tmp1m = _mm256_comp_fmadd_ps(_mm256_set1_ps(16.f), _tmp135c, _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _tmp135b, _tmp135a)); + __m256 _tmp3m = _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _tmp135c, _mm256_comp_fmadd_ps(_mm256_set1_ps(8.f), _tmp135b, _tmp135a)); + __m256 _tmp5m = _mm256_add_ps(_mm256_add_ps(_out0tm7, _tmp135a), 
_mm256_comp_fmadd_ps(_mm256_set1_ps(32.f), _tmp135b, _tmp135c));
+                    _mm256_store_ps(tmp[1][m], _tmp1m);
+                    _mm256_store_ps(tmp[3][m], _tmp3m);
+                    _mm256_store_ps(tmp[5][m], _tmp5m);
+
+                    output0_tm_0 += tiles * 64;
+                    output0_tm_1 += tiles * 64;
+                    output0_tm_2 += tiles * 64;
+                    output0_tm_3 += tiles * 64;
+                    output0_tm_4 += tiles * 64;
+                    output0_tm_5 += tiles * 64;
+                    output0_tm_6 += tiles * 64;
+                    output0_tm_7 += tiles * 64;
+                }
+
+                for (int m = 0; m < 6; m++)
+                {
+                    __m256 _tmp00 = _mm256_load_ps(tmp[m][0]);
+                    __m256 _tmp01 = _mm256_load_ps(tmp[m][1]);
+                    __m256 _tmp02 = _mm256_load_ps(tmp[m][2]);
+                    __m256 _tmp03 = _mm256_load_ps(tmp[m][3]);
+                    __m256 _tmp04 = _mm256_load_ps(tmp[m][4]);
+                    __m256 _tmp05 = _mm256_load_ps(tmp[m][5]);
+                    __m256 _tmp06 = _mm256_load_ps(tmp[m][6]);
+                    __m256 _tmp07 = _mm256_load_ps(tmp[m][7]);
+
+                    __m256 _tmp024a = _mm256_add_ps(_tmp01, _tmp02);
+                    __m256 _tmp135a = _mm256_sub_ps(_tmp01, _tmp02);
+
+                    __m256 _tmp024b = _mm256_add_ps(_tmp03, _tmp04);
+                    __m256 _tmp135b = _mm256_sub_ps(_tmp03, _tmp04);
+
+                    __m256 _tmp024c = _mm256_add_ps(_tmp05, _tmp06);
+                    __m256 _tmp135c = _mm256_sub_ps(_tmp05, _tmp06);
+
+                    __m256 _out00 = _mm256_add_ps(_bias0, _mm256_add_ps(_mm256_add_ps(_tmp00, _tmp024a), _mm256_comp_fmadd_ps(_mm256_set1_ps(32.f), _tmp024c, _tmp024b)));
+                    __m256 _out02 = _mm256_add_ps(_bias0, _mm256_comp_fmadd_ps(_mm256_set1_ps(8.f), _tmp024c, _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _tmp024b, _tmp024a)));
+                    __m256 _out04 = _mm256_add_ps(_bias0, _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _tmp024c, _mm256_comp_fmadd_ps(_mm256_set1_ps(16.f), _tmp024b, _tmp024a)));
+                    _mm256_store_ps(output0, _out00);
+                    _mm256_store_ps(output0 + 16, _out02);
+                    _mm256_store_ps(output0 + 32, _out04);
+
+                    __m256 _out01 = _mm256_add_ps(_bias0, _mm256_comp_fmadd_ps(_mm256_set1_ps(16.f), _tmp135c, _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _tmp135b, _tmp135a)));
+                    __m256 _out03 = _mm256_add_ps(_bias0, _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _tmp135c, _mm256_comp_fmadd_ps(_mm256_set1_ps(8.f), _tmp135b, _tmp135a)));
+                    __m256 _out05 = _mm256_add_ps(_bias0, _mm256_add_ps(_mm256_add_ps(_tmp07, _tmp135a), _mm256_comp_fmadd_ps(_mm256_set1_ps(32.f), _tmp135b, _tmp135c)));
+                    _mm256_store_ps(output0 + 8, _out01);
+                    _mm256_store_ps(output0 + 24, _out03);
+                    _mm256_store_ps(output0 + 40, _out05);
+
+                    output0 += outw * 8;
+                }
+            }
+        }
+    }
+}
+
+static void conv3x3s1_winograd42_transform_input_pack8_avx(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt)
+{
+    const int w = bottom_blob.w;
+    const int h = bottom_blob.h;
+    const int inch = bottom_blob.c;
+
+    const int w_tiles = (w - 2) / 4;
+    const int h_tiles = (h - 2) / 4;
+    const int tiles = w_tiles * h_tiles;
+
+    // const float itm[6][6] = {
+    //     {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f},
+    //     {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f},
+    //     {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f},
+    //     {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f},
+    //     {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f},
+    //     {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f}
+    // };
+
+    // 0 = 4 * r00 - 5 * r02 + r04
+    // 1 = -4 * (r01 + r02) + r04 + r03
+    // 2 = 4 * (r01 - r02) + r04 - r03
+    // 3 = -2 * (r01 - r03) + r04 - r02
+    // 4 = 2 * (r01 - r03) + r04 - r02
+    // 5 = 4 * r01 - 5 * r03 + r05
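+
+    // An index sketch of the transformed layout (illustrative only, not code
+    // added by this patch): coefficient (m, k), i.e. c = 6 * m + k, of tile t
+    // within one channel of bottom_blob_tm starts at float offset
+    //
+    //     (c * tiles + t) * 8; // 8 = pack8 elempack
+    //
+    // so all tiles of one coefficient are stored together, which is why each
+    // r0_tm_* pointer below starts at tiles * 8 * k and advances by
+    // tiles * 8 * 6 per row m.
+
+    #pragma omp parallel for num_threads(opt.num_threads)
+    for (int q = 0; q < inch; q++)
+    {
+        const Mat img0 = bottom_blob.channel(q);
+        Mat img0_tm = bottom_blob_tm.channel(q);
+
+#ifdef _MSC_VER
+        __declspec(align(32))
+#else
+        __attribute__((aligned(32)))
+#endif
+        float tmp[6][6][8];
+
+        // tile
+        for (int i = 0; i < h_tiles; i++)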
+
+static void conv3x3s1_winograd42_transform_input_pack8_avx(const Mat& bottom_blob, Mat& bottom_blob_tm, const Option& opt)
+{
+    const int w = bottom_blob.w;
+    const int h = bottom_blob.h;
+    const int inch = bottom_blob.c;
+
+    const int w_tiles = (w - 2) / 4;
+    const int h_tiles = (h - 2) / 4;
+    const int tiles = w_tiles * h_tiles;
+
+    // const float itm[6][6] = {
+    //     {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f},
+    //     {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f},
+    //     {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f},
+    //     {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f},
+    //     {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f},
+    //     {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f}
+    // };
+
+    // 0 = 4 * r00 - 5 * r02 + r04
+    // 1 = -4 * (r01 + r02) + r04 + r03
+    // 2 = 4 * (r01 - r02) + r04 - r03
+    // 3 = -2 * (r01 - r03) + r04 - r02
+    // 4 = 2 * (r01 - r03) + r04 - r02
+    // 5 = 4 * r01 - 5 * r03 + r05
+
+    #pragma omp parallel for num_threads(opt.num_threads)
+    for (int q = 0; q < inch; q++)
+    {
+        const Mat img0 = bottom_blob.channel(q);
+        Mat img0_tm = bottom_blob_tm.channel(q);
+
+#ifdef _MSC_VER
+        __declspec(align(32))
+#else
+        __attribute__((aligned(32)))
+#endif
+        float tmp[6][6][8];
+
+        // tile
+        for (int i = 0; i < h_tiles; i++)
+        {
+            for (int j = 0; j < w_tiles; j++)
+            {
+                const float* r0 = img0.row(i * 4) + (j * 4) * 8;
+
+                for (int m = 0; m < 6; m++)
+                {
+                    __m256 _r00 = _mm256_load_ps(r0);
+                    __m256 _r01 = _mm256_load_ps(r0 + 8);
+                    __m256 _r02 = _mm256_load_ps(r0 + 8 * 2);
+                    __m256 _r03 = _mm256_load_ps(r0 + 8 * 3);
+                    __m256 _r04 = _mm256_load_ps(r0 + 8 * 4);
+                    __m256 _r05 = _mm256_load_ps(r0 + 8 * 5);
+
+                    __m256 _tmp0m = _mm256_comp_fmadd_ps(_mm256_set1_ps(-5.f), _r02, _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _r00, _r04));
+                    __m256 _tmp1m = _mm256_comp_fmadd_ps(_mm256_set1_ps(-4.f), _mm256_add_ps(_r01, _r02), _mm256_add_ps(_r04, _r03));
+                    __m256 _tmp2m = _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _mm256_sub_ps(_r01, _r02), _mm256_sub_ps(_r04, _r03));
+                    __m256 _tmp3m = _mm256_comp_fmadd_ps(_mm256_set1_ps(-2.f), _mm256_sub_ps(_r01, _r03), _mm256_sub_ps(_r04, _r02));
+                    __m256 _tmp4m = _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _mm256_sub_ps(_r01, _r03), _mm256_sub_ps(_r04, _r02));
+                    __m256 _tmp5m = _mm256_comp_fmadd_ps(_mm256_set1_ps(-5.f), _r03, _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _r01, _r05));
+
+                    _mm256_store_ps(tmp[0][m], _tmp0m);
+                    _mm256_store_ps(tmp[1][m], _tmp1m);
+                    _mm256_store_ps(tmp[2][m], _tmp2m);
+                    _mm256_store_ps(tmp[3][m], _tmp3m);
+                    _mm256_store_ps(tmp[4][m], _tmp4m);
+                    _mm256_store_ps(tmp[5][m], _tmp5m);
+
+                    r0 += w * 8;
+                }
+
+                float* r0_tm_0 = (float*)img0_tm + (i * w_tiles + j) * 8;
+                float* r0_tm_1 = r0_tm_0 + tiles * 8;
+                float* r0_tm_2 = r0_tm_0 + tiles * 8 * 2;
+                float* r0_tm_3 = r0_tm_0 + tiles * 8 * 3;
+                float* r0_tm_4 = r0_tm_0 + tiles * 8 * 4;
+                float* r0_tm_5 = r0_tm_0 + tiles * 8 * 5;
+
+                for (int m = 0; m < 6; m++)
+                {
+                    __m256 _tmp00 = _mm256_load_ps(tmp[m][0]);
+                    __m256 _tmp01 = _mm256_load_ps(tmp[m][1]);
+                    __m256 _tmp02 = _mm256_load_ps(tmp[m][2]);
+                    __m256 _tmp03 = _mm256_load_ps(tmp[m][3]);
+                    __m256 _tmp04 = _mm256_load_ps(tmp[m][4]);
+                    __m256 _tmp05 = _mm256_load_ps(tmp[m][5]);
+
+                    __m256 _r0tm0 = _mm256_comp_fmadd_ps(_mm256_set1_ps(-5.f), _tmp02, _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _tmp00, _tmp04));
+                    __m256 _r0tm1 = _mm256_comp_fmadd_ps(_mm256_set1_ps(-4.f), _mm256_add_ps(_tmp01, _tmp02), _mm256_add_ps(_tmp04, _tmp03));
+                    __m256 _r0tm2 = _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _mm256_sub_ps(_tmp01, _tmp02), _mm256_sub_ps(_tmp04, _tmp03));
+                    __m256 _r0tm3 = _mm256_comp_fmadd_ps(_mm256_set1_ps(-2.f), _mm256_sub_ps(_tmp01, _tmp03), _mm256_sub_ps(_tmp04, _tmp02));
+                    __m256 _r0tm4 = _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _mm256_sub_ps(_tmp01, _tmp03), _mm256_sub_ps(_tmp04, _tmp02));
+                    __m256 _r0tm5 = _mm256_comp_fmadd_ps(_mm256_set1_ps(-5.f), _tmp03, _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _tmp01, _tmp05));
+
+                    _mm256_store_ps(r0_tm_0, _r0tm0);
+                    _mm256_store_ps(r0_tm_1, _r0tm1);
+                    _mm256_store_ps(r0_tm_2, _r0tm2);
+                    _mm256_store_ps(r0_tm_3, _r0tm3);
+                    _mm256_store_ps(r0_tm_4, _r0tm4);
+                    _mm256_store_ps(r0_tm_5, _r0tm5);
+
+                    r0_tm_0 += tiles * 8 * 6;
+                    r0_tm_1 += tiles * 8 * 6;
+                    r0_tm_2 += tiles * 8 * 6;
+                    r0_tm_3 += tiles * 8 * 6;
+                    r0_tm_4 += tiles * 8 * 6;
+                    r0_tm_5 += tiles * 8 * 6;
+                }
+            }
+        }
+    }
+}
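The input transform just added is the usual B^T d B separable pass: the same 6-element reduction is applied first along rows into tmp, then along columns into the r0_tm_* planes. A scalar sketch of one such pass, matching the itm comments above (function name hypothetical, not part of this patch):

static void winograd42_input_pass_ref(const float r[6], float t[6])
{
    t[0] = 4.f * r[0] - 5.f * r[2] + r[4];
    t[1] = -4.f * (r[1] + r[2]) + (r[3] + r[4]);
    t[2] = 4.f * (r[1] - r[2]) + (r[4] - r[3]);
    t[3] = -2.f * (r[1] - r[3]) + (r[4] - r[2]);
    t[4] = 2.f * (r[1] - r[3]) + (r[4] - r[2]);
    t[5] = 4.f * r[1] - 5.f * r[3] + r[5];
}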
+
+static void conv3x3s1_winograd42_transform_output_pack8_avx(const Mat& top_blob_tm, Mat& top_blob, const Mat& bias, const Option& opt)
+{
+    const int outw = top_blob.w;
+    const int outh = top_blob.h;
+    const int outch = top_blob.c;
+
+    const int w_tiles = outw / 4;
+    const int h_tiles = outh / 4;
+    const int tiles = w_tiles * h_tiles;
+
+    const float* biasptr = bias;
+
+    // const float otm[4][6] = {
+    //     {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f},
+    //     {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f},
+    //     {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f},
+    //     {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f}
+    // };
+
+    // 0 = r00 + (r01 + r02) + (r03 + r04)
+    // 1 = (r01 - r02) + (r03 - r04) * 2
+    // 2 = (r01 + r02) + (r03 + r04) * 4
+    // 3 = r05 + (r01 - r02) + (r03 - r04) * 8
+
+    #pragma omp parallel for num_threads(opt.num_threads)
+    for (int p = 0; p < outch; p++)
+    {
+        const Mat out0_tm = top_blob_tm.channel(p);
+        Mat out0 = top_blob.channel(p);
+
+        __m256 _bias0 = biasptr ? _mm256_loadu_ps(biasptr + p * 8) : _mm256_setzero_ps();
+
+#ifdef _MSC_VER
+        __declspec(align(32))
+#else
+        __attribute__((aligned(32)))
+#endif
+        float tmp[4][6][8];
+
+        // tile
+        for (int i = 0; i < h_tiles; i++)
+        {
+            for (int j = 0; j < w_tiles; j++)
+            {
+                const float* output0_tm_0 = (const float*)out0_tm + (i * w_tiles + j) * 8;
+                const float* output0_tm_1 = output0_tm_0 + tiles * 8;
+                const float* output0_tm_2 = output0_tm_0 + tiles * 8 * 2;
+                const float* output0_tm_3 = output0_tm_0 + tiles * 8 * 3;
+                const float* output0_tm_4 = output0_tm_0 + tiles * 8 * 4;
+                const float* output0_tm_5 = output0_tm_0 + tiles * 8 * 5;
+
+                float* output0 = out0.row(i * 4) + (j * 4) * 8;
+
+                for (int m = 0; m < 6; m++)
+                {
+                    __m256 _out0tm0 = _mm256_load_ps(output0_tm_0);
+                    __m256 _out0tm1 = _mm256_load_ps(output0_tm_1);
+                    __m256 _out0tm2 = _mm256_load_ps(output0_tm_2);
+                    __m256 _out0tm3 = _mm256_load_ps(output0_tm_3);
+                    __m256 _out0tm4 = _mm256_load_ps(output0_tm_4);
+                    __m256 _out0tm5 = _mm256_load_ps(output0_tm_5);
+
+                    __m256 _tmp02a = _mm256_add_ps(_out0tm1, _out0tm2);
+                    __m256 _tmp13a = _mm256_sub_ps(_out0tm1, _out0tm2);
+
+                    __m256 _tmp02b = _mm256_add_ps(_out0tm3, _out0tm4);
+                    __m256 _tmp13b = _mm256_sub_ps(_out0tm3, _out0tm4);
+
+                    __m256 _tmp0m = _mm256_add_ps(_mm256_add_ps(_out0tm0, _tmp02a), _tmp02b);
+                    __m256 _tmp1m = _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _tmp13b, _tmp13a);
+                    __m256 _tmp2m = _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _tmp02b, _tmp02a);
+                    __m256 _tmp3m = _mm256_comp_fmadd_ps(_mm256_set1_ps(8.f), _tmp13b, _mm256_add_ps(_out0tm5, _tmp13a));
+
+                    _mm256_store_ps(tmp[0][m], _tmp0m);
+                    _mm256_store_ps(tmp[1][m], _tmp1m);
+                    _mm256_store_ps(tmp[2][m], _tmp2m);
+                    _mm256_store_ps(tmp[3][m], _tmp3m);
+
+                    output0_tm_0 += tiles * 8 * 6;
+                    output0_tm_1 += tiles * 8 * 6;
+                    output0_tm_2 += tiles * 8 * 6;
+                    output0_tm_3 += tiles * 8 * 6;
+                    output0_tm_4 += tiles * 8 * 6;
+                    output0_tm_5 += tiles * 8 * 6;
+                }
+
+                for (int m = 0; m < 4; m++)
+                {
+                    __m256 _tmp00 = _mm256_load_ps(tmp[m][0]);
+                    __m256 _tmp01 = _mm256_load_ps(tmp[m][1]);
+                    __m256 _tmp02 = _mm256_load_ps(tmp[m][2]);
+                    __m256 _tmp03 = _mm256_load_ps(tmp[m][3]);
+                    __m256 _tmp04 = _mm256_load_ps(tmp[m][4]);
+                    __m256 _tmp05 = _mm256_load_ps(tmp[m][5]);
+
+                    __m256 _tmp02a = _mm256_add_ps(_tmp01, _tmp02);
+                    __m256 _tmp13a = _mm256_sub_ps(_tmp01, _tmp02);
+
+                    __m256 _tmp02b = _mm256_add_ps(_tmp03, _tmp04);
+                    __m256 _tmp13b = _mm256_sub_ps(_tmp03, _tmp04);
+
+                    __m256 _out00 = _mm256_add_ps(_bias0, _mm256_add_ps(_mm256_add_ps(_tmp00, _tmp02a), _tmp02b));
+                    __m256 _out01 = _mm256_add_ps(_bias0, _mm256_comp_fmadd_ps(_mm256_set1_ps(2.f), _tmp13b, _tmp13a));
+                    __m256 _out02 = _mm256_add_ps(_bias0, _mm256_comp_fmadd_ps(_mm256_set1_ps(4.f), _tmp02b, _tmp02a));
+                    __m256 _out03 = _mm256_add_ps(_bias0, _mm256_comp_fmadd_ps(_mm256_set1_ps(8.f), _tmp13b, _mm256_add_ps(_tmp05, _tmp13a)));
+
+                    _mm256_store_ps(output0, _out00);
+                    _mm256_store_ps(output0 + 8, _out01);
+                    _mm256_store_ps(output0 + 8 * 2, _out02);
+                    _mm256_store_ps(output0 + 8 * 3, _out03);
+
+                    output0 += outw * 8;
+                }
+            }
+        }
+    }
+}
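Likewise, the output transform reduces each 6-element row/column to four output pixels. A scalar sketch matching the otm comments above (function name hypothetical, not part of this patch):

static void winograd42_output_pass_ref(const float t[6], float out[4])
{
    const float a = t[1] + t[2]; // _tmp02a
    const float s = t[1] - t[2]; // _tmp13a
    const float b = t[3] + t[4]; // _tmp02b
    const float d = t[3] - t[4]; // _tmp13b

    out[0] = t[0] + a + b;
    out[1] = s + 2.f * d;
    out[2] = a + 4.f * b;
    out[3] = t[5] + s + 8.f * d;
}

The bias is folded in only on the second (column) pass, once per final pixel, as _bias0 in the m < 4 loop above.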
diff --git a/src/layer/x86/convolution_x86.cpp b/src/layer/x86/convolution_x86.cpp
index a01a95a77e0..a0c71f4831a 100644
--- a/src/layer/x86/convolution_x86.cpp
+++ b/src/layer/x86/convolution_x86.cpp
@@ -33,6 +33,7 @@ namespace ncnn {
 #include "convolution_sgemm.h"
+#include "convolution_winograd_transform.h"
 #include "convolution_1x1.h"
 #include "convolution_3x3.h"
 #include "convolution_5x5.h"
@@ -52,6 +53,7 @@ namespace ncnn {
 #include "convolution_sgemm_pack4.h"
 #include "convolution_sgemm_pack1to4.h"
 #include "convolution_sgemm_pack4to1.h"
+#include "convolution_winograd_transform_pack4.h"
 #include "convolution_1x1_pack4.h"
 #include "convolution_1x1_pack1to4.h"
 #include "convolution_1x1_pack4to1.h"
@@ -87,6 +89,7 @@ namespace ncnn {
 #include "convolution_sgemm_pack1to8.h"
 #include "convolution_sgemm_pack8to4.h"
 #include "convolution_sgemm_pack8to1.h"
+#include "convolution_winograd_transform_pack8.h"
 #include "convolution_1x1_pack8.h"
 #include "convolution_1x1_pack4to8.h"
 #include "convolution_1x1_pack1to8.h"
@@ -111,6 +114,7 @@ namespace ncnn {
 #include "convolution_sgemm_pack1to16.h"
 #include "convolution_sgemm_pack16to8.h"
 #include "convolution_sgemm_pack16to4.h"
+#include "convolution_winograd_transform_pack16.h"
 #include "convolution_1x1_pack16.h"
 #include "convolution_1x1_pack8to16.h"
 #include "convolution_1x1_pack1to16.h"
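All of the hunks above go through _mm256_comp_fmadd_ps rather than raw _mm256_fmadd_ps, so the code still builds on AVX targets without FMA. The helper's definition lives outside this patch; the following is only a sketch of its presumed behavior (the guard macro and body are assumptions, not taken from this diff):

#include <immintrin.h>

// Presumed shape of ncnn's compatibility fmadd: fuse a * b + c when
// FMA is available, otherwise fall back to a separate multiply and add.
static inline __m256 comp_fmadd_ps_sketch(__m256 _a, __m256 _b, __m256 _c)
{
#ifdef __FMA__
    return _mm256_fmadd_ps(_a, _b, _c);
#else
    return _mm256_add_ps(_mm256_mul_ps(_a, _b), _c);
#endif
}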