From 5fa6f5b25909d082f7b81176ee9f870a255b6ed8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=B0=E9=98=85?= <43716063+Baiyuetribe@users.noreply.github.com> Date: Sun, 12 Jan 2025 16:21:57 +0800 Subject: [PATCH] fix 4D with d and c --- src/layer/topk.cpp | 101 +++++++++++++---------- src/layer/topk.h | 3 +- tests/test_topk.cpp | 12 +-- tools/pnnx/src/pass_ncnn/torch_topk.cpp | 7 -- tools/pnnx/tests/ncnn/test_torch_topk.py | 48 +++++++---- 5 files changed, 100 insertions(+), 71 deletions(-) diff --git a/src/layer/topk.cpp b/src/layer/topk.cpp index f7e2aa4f297..efe15e62f16 100644 --- a/src/layer/topk.cpp +++ b/src/layer/topk.cpp @@ -28,7 +28,7 @@ namespace ncnn { // }; // simplestl兼容写法 -struct TopK::CompareFunc +struct CompareFunc { bool largest; CompareFunc(bool l) @@ -43,7 +43,7 @@ struct TopK::CompareFunc } }; -void TopK::do_sort(std::vector >& vec, int k, bool sorted) const +void TopK::do_sort(std::vector >& vec) const { CompareFunc comp(largest); if (sorted) @@ -72,8 +72,6 @@ void TopK::do_sort(std::vector >& vec, int k, bool sorted) TopK::TopK() { - // one_blob_only = true; // 仅有1个输入和1个输出 - // support_inplace = true; // 是否支持原地运算,即输入和输出共享一个blob one_blob_only = false; // 只需要一个输入 blob support_inplace = false; // 是否支持原地运算 } @@ -127,7 +125,7 @@ int TopK::forward(const std::vector& bottom_blobs, std::vector& top_bl } // 根据sorted参数选择排序方式 - do_sort(vec, k, sorted); + do_sort(vec); // 保存结果 for (int i = 0; i < k; i++) @@ -144,7 +142,6 @@ int TopK::forward(const std::vector& bottom_blobs, std::vector& top_bl top_blob_values.create(w, k, elemsize, opt.blob_allocator); top_blob_indices.create(w, k, sizeof(int), opt.blob_allocator); - // #pragma omp parallel for for (int j = 0; j < w; j++) // 对每列进行处理 { std::vector > vec(h); @@ -154,7 +151,7 @@ int TopK::forward(const std::vector& bottom_blobs, std::vector& top_bl vec[i] = std::make_pair(bottom_blob.row(i)[j], i); } - do_sort(vec, k, sorted); + do_sort(vec); // 保存结果到对应列 for (int i = 0; i < k; i++) @@ -182,7 +179,7 @@ int TopK::forward(const std::vector& bottom_blobs, std::vector& top_bl vec[j] = std::make_pair(ptr[j], j); } - do_sort(vec, k, sorted); + do_sort(vec); for (int j = 0; j < k; j++) { @@ -213,7 +210,7 @@ int TopK::forward(const std::vector& bottom_blobs, std::vector& top_bl } // 排序 - do_sort(channel_values, k, sorted); + do_sort(channel_values); // 写回结果 for (int c = 0; c < k; c++) @@ -244,7 +241,7 @@ int TopK::forward(const std::vector& bottom_blobs, std::vector& top_bl } // 找到最大行的索引 - do_sort(row_scores, k, sorted); + do_sort(row_scores); // 保存该列的结果 for (int i = 0; i < k; i++) @@ -276,7 +273,7 @@ int TopK::forward(const std::vector& bottom_blobs, std::vector& top_bl vec[i] = std::make_pair(ptr[i], i); } - do_sort(vec, k, sorted); + do_sort(vec); for (int i = 0; i < k; i++) { @@ -292,53 +289,73 @@ int TopK::forward(const std::vector& bottom_blobs, std::vector& top_bl // 4D数据处理 if (axis == 0) { - // PyTorch:batch -> channel -> height -> width - // ncnn:channels -> depth -> height -> width - top_blob_values.create(w, h, k, channels, elemsize, opt.blob_allocator); - top_blob_indices.create(w, h, k, channels, sizeof(int), opt.blob_allocator); - - // 在pytorch中,假设x为torch.Size([3, 2, 6, 7]),按N维度,也就是x[0]、x[1]、x[2],对比排序,最后直接输出x[i] - // 但在ncnn中,从channels遍历后,维度d再遍历会获得2*3=6种数据。这里就卡主了,不知道怎么处理 - // need help !!! - } - else if (axis == 1) - { - // 在channel维度上进行TopK + // 在torch中d维度求topk top_blob_values.create(w, h, d, k, elemsize, opt.blob_allocator); top_blob_indices.create(w, h, d, k, sizeof(int), opt.blob_allocator); - // need help !!! + for (int z = 0; z < d; z++) + { + for (int i = 0; i < h; i++) + { + for (int j = 0; j < w; j++) + { + // 收集channel维度的值 + std::vector > channel_values(channels); + for (int c = 0; c < channels; c++) + { + const float* ptr = bottom_blob.channel(c); + int offset = z * h * w + i * w + j; + channel_values[c] = std::make_pair(ptr[offset], c); + } + + // 排序 + do_sort(channel_values); + + // 保存结果 + for (int kk = 0; kk < k; kk++) + { + float* outptr = top_blob_values.channel(kk); + int* indptr = top_blob_indices.channel(kk); + int out_offset = z * h * w + i * w + j; + outptr[out_offset] = channel_values[kk].first; + indptr[out_offset] = channel_values[kk].second; + } + } + } + } } - else if (axis == 20) + else if (axis == 1) { - // 在h维度上进行TopK - top_blob_values.create(w, k, d, channels, elemsize, opt.blob_allocator); - top_blob_indices.create(w, k, d, channels, sizeof(int), opt.blob_allocator); + // 在torch中c维度求topk + top_blob_values.create(w, h, k, channels, elemsize, opt.blob_allocator); + top_blob_indices.create(w, h, k, channels, sizeof(int), opt.blob_allocator); for (int q = 0; q < channels; q++) { const float* ptr = bottom_blob.channel(q); float* outptr = top_blob_values.channel(q); - int* indices = top_blob_indices.channel(q); + int* indptr = top_blob_indices.channel(q); - for (int z = 0; z < d; z++) + for (int i = 0; i < h; i++) { - for (int i = 0; i < w; i++) + for (int j = 0; j < w; j++) { - std::vector > row_scores(h); - for (int j = 0; j < h; j++) + // 收集当前(h,w)位置在d维度上的所有值 + std::vector > vec(d); + for (int z = 0; z < d; z++) { - int offset = (z * h + j) * w + i; - row_scores[j] = std::make_pair(ptr[offset], j); + int offset = z * h * w + i * w + j; + vec[z] = std::make_pair(ptr[offset], z); } - do_sort(row_scores, k, sorted); + do_sort(vec); - // 循环写入前 k 个值 - for (int kk = 0; kk < k; kk++) + // 保存top-k结果 + for (int z = 0; z < k; z++) { - outptr[(z * k + kk) * w + i] = row_scores[kk].first; - indices[(z * k + kk) * w + i] = row_scores[kk].second; + int offset = z * h * w + i * w + j; + outptr[offset] = vec[z].first; + indptr[offset] = vec[z].second; } } } @@ -367,7 +384,7 @@ int TopK::forward(const std::vector& bottom_blobs, std::vector& top_bl row_scores[j] = std::make_pair(ptr[offset], j); } - do_sort(row_scores, k, sorted); + do_sort(row_scores); // 写回结果 for (int kk = 0; kk < k; kk++) @@ -399,7 +416,7 @@ int TopK::forward(const std::vector& bottom_blobs, std::vector& top_bl row_values[j] = std::make_pair(ptr[j], j); } - do_sort(row_values, k, sorted); + do_sort(row_values); // 写回结果 for (int j = 0; j < k; j++) diff --git a/src/layer/topk.h b/src/layer/topk.h index a75c2959e42..a243a9a2a30 100644 --- a/src/layer/topk.h +++ b/src/layer/topk.h @@ -35,8 +35,7 @@ class TopK : public Layer int sorted; private: - struct CompareFunc; // 前向声明 - void do_sort(std::vector >& vec, int k, bool sorted) const; + void do_sort(std::vector >& vec) const; }; } // namespace ncnn diff --git a/tests/test_topk.cpp b/tests/test_topk.cpp index aa18baea3a2..351c6060e2c 100644 --- a/tests/test_topk.cpp +++ b/tests/test_topk.cpp @@ -18,10 +18,10 @@ static int test_topk(const ncnn::Mat& a, int k, int axis, int largest, int sorted) { ncnn::ParamDict pd; - pd.set(0, k); // k - pd.set(1, axis); // axis - pd.set(2, largest); // largest - pd.set(3, sorted); // sorted + pd.set(0, k); + pd.set(1, axis); + pd.set(2, largest); + pd.set(3, sorted); std::vector weights(0); @@ -40,8 +40,8 @@ static int test_topk(const ncnn::Mat& a, int k, int axis, int largest, int sorte static int test_topk_0() { return 0 - // || test_topk(RandomMat(3, 2, 6, 7), 1, 0, 1, 1) // axis=0暂未实现 - // || test_topk(RandomMat(3, 4, 2, 5), 2, 1, 0, 1) // axis=1暂未实现 + || test_topk(RandomMat(3, 2, 6, 7), 1, 0, 1, 1) + || test_topk(RandomMat(3, 4, 2, 5), 2, 1, 0, 1) || test_topk(RandomMat(3, 6, 4, 2), 2, 2, 1, 0) || test_topk(RandomMat(5, 3, 5, 3), 1, 3, 1, 1); } diff --git a/tools/pnnx/src/pass_ncnn/torch_topk.cpp b/tools/pnnx/src/pass_ncnn/torch_topk.cpp index fb0a0f08b02..9adf06cc464 100644 --- a/tools/pnnx/src/pass_ncnn/torch_topk.cpp +++ b/tools/pnnx/src/pass_ncnn/torch_topk.cpp @@ -52,13 +52,6 @@ pnnx.Output output 2 0 out indices op->params["1"] = dim; op->params["2"] = largest; op->params["3"] = sorted; - - // 未完成说明 - int input_rank = (int)op->inputs[0]->shape.size(); - if (input_rank == 4 && (dim == 0 || dim == 1)) - { - printf("error: 4D with dim = 0 or 1 is not supported yet\n"); - } } }; diff --git a/tools/pnnx/tests/ncnn/test_torch_topk.py b/tools/pnnx/tests/ncnn/test_torch_topk.py index 8bc3c68a300..d36af339b8e 100644 --- a/tools/pnnx/tests/ncnn/test_torch_topk.py +++ b/tools/pnnx/tests/ncnn/test_torch_topk.py @@ -27,27 +27,47 @@ def forward(self, x, y, z, d): y2, i2 = torch.topk(y, k=2, dim=1, largest=False) # 3D z1, i3 = torch.topk(z, k=2, dim=0) - z1, i4 = torch.topk(z, k=3, dim=1) - z1, i5 = torch.topk(z, k=1, dim=2) + z2, i4 = torch.topk(z, k=3, dim=1) + z3, i5 = torch.topk(z, k=1, dim=2) # 4D - # d0, i6 = torch.topk( - # d, - # k=2, - # dim=0, - # ) - # d1, i7 = torch.topk( - # d, - # k=2, - # dim=1, - # ) + d0, i6 = torch.topk( + d, + k=2, + dim=0, + ) + d1, i7 = torch.topk( + d, + k=2, + dim=1, + ) d2, i8 = torch.topk( d, k=2, dim=2, ) d3, i9 = torch.topk(d, k=2, dim=3, sorted=True) - # return x0, y1, y2, z1, i3, i4, i5, d0, d1, d2, d3, i6, i7, i8, i9 - return x0, y1, y2, i0, i1, i2, z1, i3, i4, i5, d2, d3, i8, i9 + return ( + x0, + i0, + y1, + i1, + y2, + i2, + z1, + i3, + z2, + i4, + z3, + i5, + d0, + i6, + d1, + i7, + d2, + i8, + d3, + i9, + ) def test():