From 4f0eacb3cae8d63ee071a16a4730dac393906866 Mon Sep 17 00:00:00 2001 From: nihui Date: Mon, 11 Apr 2022 10:34:20 +0800 Subject: [PATCH] convolution sgemm pack4to16 --- src/layer/x86/convolution_1x1_pack4to16.h | 66 ++++ src/layer/x86/convolution_sgemm_pack4to16.h | 317 ++++++++++++++++++++ src/layer/x86/convolution_x86.cpp | 43 +++ 3 files changed, 426 insertions(+) create mode 100644 src/layer/x86/convolution_1x1_pack4to16.h create mode 100644 src/layer/x86/convolution_sgemm_pack4to16.h diff --git a/src/layer/x86/convolution_1x1_pack4to16.h b/src/layer/x86/convolution_1x1_pack4to16.h new file mode 100644 index 00000000000..28f81377484 --- /dev/null +++ b/src/layer/x86/convolution_1x1_pack4to16.h @@ -0,0 +1,66 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void conv1x1s1_sgemm_pack4to16_avx512(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt) +{ + int w = bottom_blob.w; + int h = bottom_blob.h; + const int size = w * h; + + Mat bottom_im2col = bottom_blob; + bottom_im2col.w = size; + bottom_im2col.h = 1; + + im2col_sgemm_pack4to16_avx512(bottom_im2col, top_blob, kernel, _bias, opt); +} + +static void conv1x1s2_sgemm_pack4to16_avx512(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt) +{ + int w = bottom_blob.w; + int channels = bottom_blob.c; + size_t elemsize = bottom_blob.elemsize; + int elempack = bottom_blob.elempack; + + int outw = top_blob.w; + int outh = top_blob.h; + + const int tailstep = (w - 2 * outw + w) * 4; + + Mat bottom_blob_shrinked; + bottom_blob_shrinked.create(outw, outh, channels, elemsize, elempack, opt.workspace_allocator); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < channels; p++) + { + const float* r0 = bottom_blob.channel(p); + float* outptr = bottom_blob_shrinked.channel(p); + + for (int i = 0; i < outh; i++) + { + int j = 0; + for (; j < outw; j++) + { + __m128 _v = _mm_load_ps(r0); + _mm_store_ps(outptr, _v); + + r0 += 8; + outptr += 4; + } + + r0 += tailstep; + } + } + + conv1x1s1_sgemm_pack4to16_avx512(bottom_blob_shrinked, top_blob, kernel, _bias, opt); +} diff --git a/src/layer/x86/convolution_sgemm_pack4to16.h b/src/layer/x86/convolution_sgemm_pack4to16.h new file mode 100644 index 00000000000..879f9e48dd6 --- /dev/null +++ b/src/layer/x86/convolution_sgemm_pack4to16.h @@ -0,0 +1,317 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +static void im2col_sgemm_pack4to16_avx512(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt) +{ + // Mat bottom_im2col(size, maxk, inch, 16u, 4, opt.workspace_allocator); + + const int size = bottom_im2col.w; + const int maxk = bottom_im2col.h; + const int inch = bottom_im2col.c; + + const int outch = top_blob.c; + + const float* bias = _bias; + + // permute + Mat tmp; + if (size >= 16) + tmp.create(16 * maxk, inch, size / 16 + size % 16, 16u, 4, opt.workspace_allocator); + else + tmp.create(maxk, inch, size, 16u, 4, opt.workspace_allocator); + { + int nn_size = size >> 4; + int remain_size_start = 0; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int ii = 0; ii < nn_size; ii++) + { + int i = remain_size_start + ii * 16; + + float* tmpptr = tmp.channel(i / 16); + + for (int q = 0; q < inch; q++) + { + const float* img0 = (const float*)bottom_im2col.channel(q) + i * 4; + + for (int k = 0; k < maxk; k++) + { + // transpose 4x16 + __m128 _r0 = _mm_load_ps(img0); + __m128 _r1 = _mm_load_ps(img0 + 4); + __m128 _r2 = _mm_load_ps(img0 + 4 * 2); + __m128 _r3 = _mm_load_ps(img0 + 4 * 3); + __m128 _r4 = _mm_load_ps(img0 + 4 * 4); + __m128 _r5 = _mm_load_ps(img0 + 4 * 5); + __m128 _r6 = _mm_load_ps(img0 + 4 * 6); + __m128 _r7 = _mm_load_ps(img0 + 4 * 7); + __m128 _r8 = _mm_load_ps(img0 + 4 * 8); + __m128 _r9 = _mm_load_ps(img0 + 4 * 9); + __m128 _ra = _mm_load_ps(img0 + 4 * 10); + __m128 _rb = _mm_load_ps(img0 + 4 * 11); + __m128 _rc = _mm_load_ps(img0 + 4 * 12); + __m128 _rd = _mm_load_ps(img0 + 4 * 13); + __m128 _re = _mm_load_ps(img0 + 4 * 14); + __m128 _rf = _mm_load_ps(img0 + 4 * 15); + + _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); + _MM_TRANSPOSE4_PS(_r4, _r5, _r6, _r7); + _MM_TRANSPOSE4_PS(_r8, _r9, _ra, _rb); + _MM_TRANSPOSE4_PS(_rc, _rd, _re, _rf); + + _mm_store_ps(tmpptr, _r0); + _mm_store_ps(tmpptr + 4, _r4); + _mm_store_ps(tmpptr + 4 * 2, _r8); + _mm_store_ps(tmpptr + 4 * 3, _rc); + _mm_store_ps(tmpptr + 4 * 4, _r1); + _mm_store_ps(tmpptr + 4 * 5, _r5); + _mm_store_ps(tmpptr + 4 * 6, _r9); + _mm_store_ps(tmpptr + 4 * 7, _rd); + _mm_store_ps(tmpptr + 4 * 8, _r2); + _mm_store_ps(tmpptr + 4 * 9, _r6); + _mm_store_ps(tmpptr + 4 * 10, _ra); + _mm_store_ps(tmpptr + 4 * 11, _re); + _mm_store_ps(tmpptr + 4 * 12, _r3); + _mm_store_ps(tmpptr + 4 * 13, _r7); + _mm_store_ps(tmpptr + 4 * 14, _rb); + _mm_store_ps(tmpptr + 4 * 15, _rf); + + img0 += size * 4; + tmpptr += 64; + } + } + } + + remain_size_start += nn_size << 4; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = remain_size_start; i < size; i++) + { + float* tmpptr = tmp.channel(i / 16 + i % 16); + + for (int q = 0; q < inch; q++) + { + const float* img0 = (const float*)bottom_im2col.channel(q) + i * 4; + + for (int k = 0; k < maxk; k++) + { + __m128 _val = _mm_load_ps(img0); + _mm_store_ps(tmpptr, _val); + + img0 += size * 4; + tmpptr += 4; + } + } + } + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < outch; p++) + { + float* outptr0 = top_blob.channel(p); + + const float zeros[16] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; + const float* biasptr = bias ? bias + p * 16 : zeros; + + int i = 0; + for (; i + 15 < size; i += 16) + { + float* tmpptr = tmp.channel(i / 16); + const float* kptr = kernel.channel(p); + + int nn = inch * maxk * 4; // inch always > 0 + + __m512 _sum0 = _mm512_loadu_ps(biasptr); + __m512 _sum1 = _sum0; + __m512 _sum2 = _sum0; + __m512 _sum3 = _sum0; + __m512 _sum4 = _sum0; + __m512 _sum5 = _sum0; + __m512 _sum6 = _sum0; + __m512 _sum7 = _sum0; + __m512 _sum8 = _sum0; + __m512 _sum9 = _sum0; + __m512 _suma = _sum0; + __m512 _sumb = _sum0; + __m512 _sumc = _sum0; + __m512 _sumd = _sum0; + __m512 _sume = _sum0; + __m512 _sumf = _sum0; + + for (int j = 0; j < nn; j++) + { + __m512 _w0 = _mm512_load_ps(kptr); + + __m512 _val0 = _mm512_set1_ps(tmpptr[0]); + __m512 _val1 = _mm512_set1_ps(tmpptr[1]); + _sum0 = _mm512_fmadd_ps(_val0, _w0, _sum0); + _sum1 = _mm512_fmadd_ps(_val1, _w0, _sum1); + __m512 _val2 = _mm512_set1_ps(tmpptr[2]); + __m512 _val3 = _mm512_set1_ps(tmpptr[3]); + _sum2 = _mm512_fmadd_ps(_val2, _w0, _sum2); + _sum3 = _mm512_fmadd_ps(_val3, _w0, _sum3); + __m512 _val4 = _mm512_set1_ps(tmpptr[4]); + __m512 _val5 = _mm512_set1_ps(tmpptr[5]); + _sum4 = _mm512_fmadd_ps(_val4, _w0, _sum4); + _sum5 = _mm512_fmadd_ps(_val5, _w0, _sum5); + __m512 _val6 = _mm512_set1_ps(tmpptr[6]); + __m512 _val7 = _mm512_set1_ps(tmpptr[7]); + _sum6 = _mm512_fmadd_ps(_val6, _w0, _sum6); + _sum7 = _mm512_fmadd_ps(_val7, _w0, _sum7); + __m512 _val8 = _mm512_set1_ps(tmpptr[8]); + __m512 _val9 = _mm512_set1_ps(tmpptr[9]); + _sum8 = _mm512_fmadd_ps(_val8, _w0, _sum8); + _sum9 = _mm512_fmadd_ps(_val9, _w0, _sum9); + __m512 _vala = _mm512_set1_ps(tmpptr[10]); + __m512 _valb = _mm512_set1_ps(tmpptr[11]); + _suma = _mm512_fmadd_ps(_vala, _w0, _suma); + _sumb = _mm512_fmadd_ps(_valb, _w0, _sumb); + __m512 _valc = _mm512_set1_ps(tmpptr[12]); + __m512 _vald = _mm512_set1_ps(tmpptr[13]); + _sumc = _mm512_fmadd_ps(_valc, _w0, _sumc); + _sumd = _mm512_fmadd_ps(_vald, _w0, _sumd); + __m512 _vale = _mm512_set1_ps(tmpptr[14]); + __m512 _valf = _mm512_set1_ps(tmpptr[15]); + _sume = _mm512_fmadd_ps(_vale, _w0, _sume); + _sumf = _mm512_fmadd_ps(_valf, _w0, _sumf); + + kptr += 16; + tmpptr += 16; + } + + _mm512_store_ps(outptr0, _sum0); + _mm512_store_ps(outptr0 + 16, _sum1); + _mm512_store_ps(outptr0 + 16 * 2, _sum2); + _mm512_store_ps(outptr0 + 16 * 3, _sum3); + _mm512_store_ps(outptr0 + 16 * 4, _sum4); + _mm512_store_ps(outptr0 + 16 * 5, _sum5); + _mm512_store_ps(outptr0 + 16 * 6, _sum6); + _mm512_store_ps(outptr0 + 16 * 7, _sum7); + _mm512_store_ps(outptr0 + 16 * 8, _sum8); + _mm512_store_ps(outptr0 + 16 * 9, _sum9); + _mm512_store_ps(outptr0 + 16 * 10, _suma); + _mm512_store_ps(outptr0 + 16 * 11, _sumb); + _mm512_store_ps(outptr0 + 16 * 12, _sumc); + _mm512_store_ps(outptr0 + 16 * 13, _sumd); + _mm512_store_ps(outptr0 + 16 * 14, _sume); + _mm512_store_ps(outptr0 + 16 * 15, _sumf); + + outptr0 += 16 * 16; + } + for (; i < size; i++) + { + float* tmpptr = tmp.channel(i / 16 + i % 16); + const float* kptr = kernel.channel(p); + + int nn = inch * maxk * 4; // inch always > 0 + + __m512 _sum0 = _mm512_loadu_ps(biasptr); + + for (int j = 0; j < nn; j++) + { + __m512 _w0 = _mm512_load_ps(kptr); + __m512 _val0 = _mm512_set1_ps(tmpptr[0]); + _sum0 = _mm512_fmadd_ps(_val0, _w0, _sum0); + + kptr += 16; + tmpptr += 1; + } + + _mm512_store_ps(outptr0, _sum0); + outptr0 += 16; + } + } +} + +static void convolution_im2col_sgemm_transform_kernel_pack4to16_avx512(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h) +{ + const int maxk = kernel_w * kernel_h; + + // interleave + // src = maxk-inch-outch + // dst = 16b-4a-maxk-inch/4a-outch/16b + Mat kernel = _kernel.reshape(maxk, inch, outch); + kernel_tm.create(16 * 4 * maxk, inch / 4, outch / 16, (size_t)4u); + + for (int q = 0; q + 15 < outch; q += 16) + { + float* g00 = kernel_tm.channel(q / 16); + + for (int p = 0; p + 3 < inch; p += 4) + { + for (int k = 0; k < maxk; k++) + { + for (int i = 0; i < 4; i++) + { + for (int j = 0; j < 16; j++) + { + const float* k00 = kernel.channel(q + j).row(p + i); + g00[0] = k00[k]; + g00++; + } + } + } + } + } +} + +static void convolution_im2col_sgemm_pack4to16_avx512(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt) +{ + int w = bottom_blob.w; + int inch = bottom_blob.c; + + int outw = top_blob.w; + int outh = top_blob.h; + const int size = outw * outh; + + const int maxk = kernel_w * kernel_h; + + // im2col + Mat bottom_im2col(size, maxk, inch, 16u, 4, opt.workspace_allocator); + { + const int gap = (w * stride_h - outw * stride_w) * 4; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int p = 0; p < inch; p++) + { + const Mat img = bottom_blob.channel(p); + float* ptr = bottom_im2col.channel(p); + + for (int u = 0; u < kernel_h; u++) + { + for (int v = 0; v < kernel_w; v++) + { + const float* sptr = img.row(dilation_h * u) + dilation_w * v * 4; + + for (int i = 0; i < outh; i++) + { + int j = 0; + for (; j < outw; j++) + { + __m128 _val = _mm_load_ps(sptr); + _mm_store_ps(ptr, _val); + + sptr += stride_w * 4; + ptr += 4; + } + + sptr += gap; + } + } + } + } + } + + im2col_sgemm_pack4to16_avx512(bottom_im2col, top_blob, kernel, _bias, opt); +} diff --git a/src/layer/x86/convolution_x86.cpp b/src/layer/x86/convolution_x86.cpp index 93718c567c0..50d19b1694e 100644 --- a/src/layer/x86/convolution_x86.cpp +++ b/src/layer/x86/convolution_x86.cpp @@ -111,6 +111,7 @@ namespace ncnn { #include "convolution_sgemm_pack16.h" #include "convolution_sgemm_pack8to16.h" +#include "convolution_sgemm_pack4to16.h" #include "convolution_sgemm_pack1to16.h" #include "convolution_sgemm_pack16to8.h" #include "convolution_sgemm_pack16to4.h" @@ -118,6 +119,7 @@ namespace ncnn { #include "convolution_winograd_transform_pack16.h" #include "convolution_1x1_pack16.h" #include "convolution_1x1_pack8to16.h" +#include "convolution_1x1_pack4to16.h" #include "convolution_1x1_pack1to16.h" #include "convolution_1x1_pack16to8.h" #include "convolution_1x1_pack16to4.h" @@ -328,6 +330,19 @@ int Convolution_x86::create_pipeline(const Option& opt) if (elempack == 4 && out_elempack == 16) { + if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { + convolution_im2col_sgemm_transform_kernel_pack4to16_avx512(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h); + } + else if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { + convolution_im2col_sgemm_transform_kernel_pack4to16_avx512(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h); + } + else if (opt.use_sgemm_convolution) + { + convolution_im2col_sgemm_transform_kernel_pack4to16_avx512(weight_data, weight_sgemm_data, num_input, num_output, kernel_w, kernel_h); + } + else { convolution_transform_kernel_packed_sse(weight_data, weight_data_packed, num_input, num_output, kernel_w, kernel_h, elempack, out_elempack); } @@ -878,6 +893,34 @@ int Convolution_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option if (elempack == 4 && out_elempack == 16) { + if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 1 && stride_h == 1) + { + conv1x1s1_sgemm_pack4to16_avx512(bottom_blob_bordered, top_blob, weight_sgemm_data, bias_data, opt); + + if (activation) + { + activation->forward_inplace(top_blob, opt); + } + } + else if (kernel_w == 1 && kernel_h == 1 && dilation_w == 1 && dilation_h == 1 && stride_w == 2 && stride_h == 2) + { + conv1x1s2_sgemm_pack4to16_avx512(bottom_blob_bordered, top_blob, weight_sgemm_data, bias_data, opt); + + if (activation) + { + activation->forward_inplace(top_blob, opt); + } + } + else if (opt.use_sgemm_convolution) + { + convolution_im2col_sgemm_pack4to16_avx512(bottom_blob_bordered, top_blob, weight_sgemm_data, bias_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, opt); + + if (activation) + { + activation->forward_inplace(top_blob, opt); + } + } + else { convolution_pack4to16_avx512(bottom_blob_bordered, top_blob, weight_data_packed, bias_data, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h, activation_type, activation_params, opt); }