From 031241e3ea58fddb297a847768abab195a48bf1e Mon Sep 17 00:00:00 2001 From: sunway Date: Sun, 26 Sep 2021 19:42:37 +0800 Subject: [PATCH 1/3] support arbitrary input dims for add/mul/relu of dnnl c_src codegen --- src/relay/backend/contrib/dnnl/codegen.cc | 33 +++++++----- src/runtime/contrib/dnnl/dnnl.cc | 64 +++++++++++++++++------ src/runtime/contrib/dnnl/dnnl_kernel.h | 5 +- 3 files changed, 70 insertions(+), 32 deletions(-) diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index f0d360ae8b6d..cd1a203bed65 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -54,6 +54,15 @@ inline size_t GetShape1DSize(const Type& type) { return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); } +inline std::string GetShapeString(std::vector shape) { + std::string v = "std::vector{"; + for (auto s : shape) { + v += std::to_string(s) + ","; + } + v += "}"; + return v; +} + std::vector Conv2d(const CallNode* call) { std::vector args; const auto* conv2d_attr = call->attrs.as(); @@ -98,12 +107,7 @@ std::vector Dense(const CallNode* call) { std::vector Relu(const CallNode* call) { std::vector args; auto ishape = GetShape(call->args[0]->checked_type()); - - // Args: N, C, H, W - for (auto s : ishape) { - args.push_back(std::to_string(s)); - } - + args.push_back(GetShapeString(ishape)); return args; } @@ -126,12 +130,16 @@ std::vector BatchNorm(const CallNode* call) { std::vector Add(const CallNode* call) { std::vector args; auto ishape = GetShape(call->args[0]->checked_type()); + args.push_back("0"); + args.push_back(GetShapeString(ishape)); + return args; +} - // Args: H, W - for (auto s : ishape) { - args.push_back(std::to_string(s)); - } - +std::vector Multiply(const CallNode* call) { + std::vector args; + auto ishape = GetShape(call->args[0]->checked_type()); + args.push_back("1"); + args.push_back(GetShapeString(ishape)); return args; } @@ -243,7 +251,8 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C {"nn.dense", {"dnnl_dense", Dense}}, {"nn.relu", {"dnnl_relu", Relu}}, {"nn.batch_norm", {"dnnl_bn", BatchNorm}}, - {"add", {"dnnl_add", Add}}, + {"add", {"dnnl_binary_op", Add}}, + {"multiply", {"dnnl_binary_op", Multiply}}, }; const auto op_name = GetRef(op_node)->name; diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index 19b3f796fd33..a0af9bc7f525 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -44,6 +44,32 @@ typedef struct { void** data; } DnnlPackedArgs; +inline dnnl::memory::desc GenDNNLMemDescByShape(const dnnl::memory::dims& shape, + memory::data_type dtype) { + using tag = memory::format_tag; + + dnnl::memory::desc data_md; + + switch (shape.size()) { + case 2: + data_md = dnnl::memory::desc({shape, dtype, tag::ab}); + break; + case 3: + data_md = dnnl::memory::desc({shape, dtype, tag::abc}); + break; + case 4: + data_md = dnnl::memory::desc({shape, dtype, tag::abcd}); + break; + case 5: + data_md = dnnl::memory::desc({shape, dtype, tag::abcde}); + break; + default: + assert(true); + break; + } + return data_md; +} + // Read from memory, write to handle inline void read_from_dnnl_memory(void* handle, const memory& mem) { size_t bytes = mem.get_desc().get_size(); @@ -175,16 +201,13 @@ extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, int read_from_dnnl_memory(out, dst_memory); } -extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_) { - using tag = memory::format_tag; +extern "C" void dnnl_relu(float* data, float* out, std::vector shape) { using dt = memory::data_type; engine eng(engine::kind::cpu, 0); stream s(eng); - memory::dims data_tz = {p_N_, p_C_, p_H_, p_W_}; - - auto data_md = memory::desc{{data_tz}, dt::f32, tag::nchw}; + auto data_md = GenDNNLMemDescByShape(shape, dt::f32); auto data_memory = memory(data_md, eng, data); auto dst_memory = memory(data_md, eng); @@ -241,27 +264,34 @@ extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, flo free(weight); } -extern "C" void dnnl_add(float* data, float* weight, float* out, int p_N_, int p_C_, int p_H_, - int p_W_) { - using tag = memory::format_tag; +extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_type, + std::vector shape) { using dt = memory::data_type; engine eng(engine::kind::cpu, 0); stream s(eng); - memory::dims data_tz = {p_N_, p_C_, p_H_, p_W_}; - - auto data_md = memory::desc{{data_tz}, dt::f32, tag::nchw}; - auto weight_md = memory::desc({{data_tz}, dt::f32, tag::nchw}); - auto dst_md = memory::desc({{data_tz}, dt::f32, tag::nchw}); + auto data_md = GenDNNLMemDescByShape(shape, dt::f32); auto data_memory = memory(data_md, eng, data); - auto weight_memory = memory(weight_md, eng, weight); - auto dst_memory = memory(dst_md, eng); + auto weight_memory = memory(data_md, eng, weight); + auto dst_memory = memory(data_md, eng); - auto add_desc = binary::desc(algorithm::binary_add, data_md, weight_md, dst_md); + algorithm algo = algorithm::undef; + switch (algo_type) { + case 0: + algo = algorithm::binary_add; + break; + case 1: + algo = algorithm::binary_mul; + default: + assert(true); + break; + }; + + auto add_desc = binary::desc(algo, data_md, data_md, data_md); auto add_prim_desc = binary::primitive_desc(add_desc, eng); - assert(dst_md == add_prim_desc.dst_desc()); + assert(data_md == add_prim_desc.dst_desc()); auto add = binary(add_prim_desc); add.execute( diff --git a/src/runtime/contrib/dnnl/dnnl_kernel.h b/src/runtime/contrib/dnnl/dnnl_kernel.h index f5f28fccd8e7..54f007ce0d32 100644 --- a/src/runtime/contrib/dnnl/dnnl_kernel.h +++ b/src/runtime/contrib/dnnl/dnnl_kernel.h @@ -54,14 +54,13 @@ extern "C" TVM_DLL void dnnl_fused_conv2d_bias_relu(float* data, float* weights, extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, int p_O_); -extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_); +extern "C" TVM_DLL void dnnl_relu(float* data, float* out, std::vector shape); extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean, float* variance, float* out, float* new_mean, float* new_variance, int p_n_, int p_c_, int p_h_, int p_w_, int p_e_); -extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_, - int p_h_, int p_w_); +extern "C" TVM_DLL void dnnl_binary_op(float* data, float* weight, float* out, int binary_algo, std::vector shape); } // namespace contrib } // namespace runtime From da11a6c6c0b16b6bdce1156eee8befd4d84e06e0 Mon Sep 17 00:00:00 2001 From: sunway Date: Sun, 26 Sep 2021 21:28:43 +0800 Subject: [PATCH 2/3] fix lint --- src/relay/backend/contrib/dnnl/codegen.cc | 9 +++------ src/runtime/contrib/dnnl/dnnl.cc | 6 +++--- src/runtime/contrib/dnnl/dnnl_kernel.h | 7 +++++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index cd1a203bed65..e4a7dee9df38 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -247,12 +247,9 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C using ArgFunType = std::function(const CallNode*)>; static const std::map> op_map = { - {"nn.conv2d", {"dnnl_conv2d", Conv2d}}, - {"nn.dense", {"dnnl_dense", Dense}}, - {"nn.relu", {"dnnl_relu", Relu}}, - {"nn.batch_norm", {"dnnl_bn", BatchNorm}}, - {"add", {"dnnl_binary_op", Add}}, - {"multiply", {"dnnl_binary_op", Multiply}}, + {"nn.conv2d", {"dnnl_conv2d", Conv2d}}, {"nn.dense", {"dnnl_dense", Dense}}, + {"nn.relu", {"dnnl_relu", Relu}}, {"nn.batch_norm", {"dnnl_bn", BatchNorm}}, + {"add", {"dnnl_binary_op", Add}}, {"multiply", {"dnnl_binary_op", Multiply}}, }; const auto op_name = GetRef(op_node)->name; diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index a0af9bc7f525..0c81be255711 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -201,7 +201,7 @@ extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, int read_from_dnnl_memory(out, dst_memory); } -extern "C" void dnnl_relu(float* data, float* out, std::vector shape) { +extern "C" void dnnl_relu(float* data, float* out, std::vector shape) { using dt = memory::data_type; engine eng(engine::kind::cpu, 0); @@ -265,7 +265,7 @@ extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, flo } extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_type, - std::vector shape) { + std::vector shape) { using dt = memory::data_type; engine eng(engine::kind::cpu, 0); @@ -287,7 +287,7 @@ extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_ default: assert(true); break; - }; + } auto add_desc = binary::desc(algo, data_md, data_md, data_md); auto add_prim_desc = binary::primitive_desc(add_desc, eng); diff --git a/src/runtime/contrib/dnnl/dnnl_kernel.h b/src/runtime/contrib/dnnl/dnnl_kernel.h index 54f007ce0d32..a29d503746ae 100644 --- a/src/runtime/contrib/dnnl/dnnl_kernel.h +++ b/src/runtime/contrib/dnnl/dnnl_kernel.h @@ -27,6 +27,8 @@ #include +#include + #include "dnnl.hpp" namespace tvm { @@ -54,13 +56,14 @@ extern "C" TVM_DLL void dnnl_fused_conv2d_bias_relu(float* data, float* weights, extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, int p_O_); -extern "C" TVM_DLL void dnnl_relu(float* data, float* out, std::vector shape); +extern "C" TVM_DLL void dnnl_relu(float* data, float* out, std::vector shape); extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean, float* variance, float* out, float* new_mean, float* new_variance, int p_n_, int p_c_, int p_h_, int p_w_, int p_e_); -extern "C" TVM_DLL void dnnl_binary_op(float* data, float* weight, float* out, int binary_algo, std::vector shape); +extern "C" TVM_DLL void dnnl_binary_op(float* data, float* weight, float* out, int binary_algo, + std::vector shape); } // namespace contrib } // namespace runtime From a917fbbcbe29f0d875c0f5d99669949ddffd42b7 Mon Sep 17 00:00:00 2001 From: sunway Date: Mon, 27 Sep 2021 10:37:17 +0800 Subject: [PATCH 3/3] fix --- src/relay/backend/contrib/dnnl/codegen.cc | 11 +++++++++-- src/runtime/contrib/dnnl/dnnl.cc | 13 +++++++++---- src/runtime/contrib/dnnl/dnnl_kernel.h | 1 + 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index e4a7dee9df38..ae58c2f08e8c 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -107,6 +107,7 @@ std::vector Dense(const CallNode* call) { std::vector Relu(const CallNode* call) { std::vector args; auto ishape = GetShape(call->args[0]->checked_type()); + // Args: N, C, H, W args.push_back(GetShapeString(ishape)); return args; } @@ -127,10 +128,15 @@ std::vector BatchNorm(const CallNode* call) { return args; } +// should comply with src/runtime/contrib/dnnl/dnnl.cc +#define DNNL_BINARY_ADD 0 +#define DNNL_BINARY_MUL 1 + std::vector Add(const CallNode* call) { std::vector args; auto ishape = GetShape(call->args[0]->checked_type()); - args.push_back("0"); + args.push_back(std::to_string(DNNL_BINARY_ADD)); + // Args: H, W args.push_back(GetShapeString(ishape)); return args; } @@ -138,7 +144,8 @@ std::vector Add(const CallNode* call) { std::vector Multiply(const CallNode* call) { std::vector args; auto ishape = GetShape(call->args[0]->checked_type()); - args.push_back("1"); + args.push_back(std::to_string(DNNL_BINARY_MUL)); + // Args: H, W args.push_back(GetShapeString(ishape)); return args; } diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index 0c81be255711..d1190df91375 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -64,7 +64,7 @@ inline dnnl::memory::desc GenDNNLMemDescByShape(const dnnl::memory::dims& shape, data_md = dnnl::memory::desc({shape, dtype, tag::abcde}); break; default: - assert(true); + LOG(FATAL) << "Unsupported data shape dimension: " << shape.size(); break; } return data_md; @@ -264,6 +264,10 @@ extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, flo free(weight); } +// should comply with src/relay/backend/contrib/dnnl/codegen.cc +#define DNNL_BINARY_ADD 0 +#define DNNL_BINARY_MUL 1 + extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_type, std::vector shape) { using dt = memory::data_type; @@ -279,13 +283,14 @@ extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_ algorithm algo = algorithm::undef; switch (algo_type) { - case 0: + case DNNL_BINARY_ADD: algo = algorithm::binary_add; break; - case 1: + case DNNL_BINARY_MUL: algo = algorithm::binary_mul; + break; default: - assert(true); + LOG(FATAL) << "Unsupported dnnl algorithm: " << algo_type; break; } diff --git a/src/runtime/contrib/dnnl/dnnl_kernel.h b/src/runtime/contrib/dnnl/dnnl_kernel.h index a29d503746ae..522313ae5a64 100644 --- a/src/runtime/contrib/dnnl/dnnl_kernel.h +++ b/src/runtime/contrib/dnnl/dnnl_kernel.h @@ -26,6 +26,7 @@ #define TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ #include +#include #include