This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[FEATURE] Integrate oneDNN binary primitive support for forward add, subtract, multiply, divide. #20713

Merged · 7 commits · Jan 18, 2022
1 change: 1 addition & 0 deletions src/operator/nn/dnnl/dnnl_base-inl.h
@@ -198,6 +198,7 @@ bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs, const NDArray& outp
bool SupportDNNLLayerNorm(const LayerNormParam& param, const std::vector<NDArray>& inputs);
bool SupportDNNLReshape(const NDArray& input, const NDArray& output);
bool SupportDNNLStack(const std::vector<NDArray>& inputs);
bool SupportDNNLBinary(const std::vector<NDArray>& inputs);
} // namespace op

static int GetTypeSize(int dtype) {
86 changes: 86 additions & 0 deletions src/operator/nn/dnnl/dnnl_binary-inl.h
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file dnnl_binary-inl.h
* \author: Adam Grabowski, [email protected]
*/

#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_BINARY_INL_H_
#define MXNET_OPERATOR_NN_DNNL_DNNL_BINARY_INL_H_

#if MXNET_USE_ONEDNN == 1
#include "./dnnl_base-inl.h"
#include "./dnnl_ops-inl.h"
#include <vector>

#include "../../tensor/elemwise_binary_broadcast_op.h"

namespace mxnet {
namespace op {

using binary_fwd_t = dnnl::binary;
using binary_fwd_pd_t = dnnl::binary::primitive_desc;

class DNNLBinaryOpFwd {
public:
template <dnnl::algorithm alg>
static DNNLBinaryOpFwd& GetBinaryOpForward(const std::vector<NDArray>& inputs,
const std::vector<NDArray>& outputs);
DNNLBinaryOpFwd(const dnnl::algorithm alg,
const std::vector<NDArray>& inputs,
const std::vector<NDArray>& outputs);

void Execute(const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);

private:
std::shared_ptr<binary_fwd_t> fwd;
std::shared_ptr<binary_fwd_pd_t> fwd_pd;
};

template <dnnl::algorithm alg>
DNNLBinaryOpFwd& DNNLBinaryOpFwd::GetBinaryOpForward(const std::vector<NDArray>& inputs,
const std::vector<NDArray>& outputs) {
using binary_op_fwd_map = std::unordered_map<OpSignature, DNNLBinaryOpFwd, OpHash>;
#if DMLC_CXX11_THREAD_LOCAL
static thread_local binary_op_fwd_map fwds;
#else
static MX_THREAD_LOCAL binary_op_fwd_map fwds;
#endif
OpSignature key;
key.AddSign(static_cast<int>(alg));
Contributor: What about "attrs" in the key? I think we should either add attrs to the key or remove it from the DNNLBinaryOpFwd constructor parameters.

Contributor: +1 to remove attrs.

Contributor (Author): attrs removed where it was possible.

key.AddSign(inputs[0]);
key.AddSign(inputs[1]);
key.AddSign(outputs[0]);

auto it = fwds.find(key);
if (it == fwds.end()) {
const DNNLBinaryOpFwd fwd(alg, inputs, outputs);
it = AddToCache(&fwds, key, fwd);
}
return it->second;
}

} // namespace op
} // namespace mxnet

#endif // MXNET_USE_ONEDNN == 1
#endif // MXNET_OPERATOR_NN_DNNL_DNNL_BINARY_INL_H_
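
GetBinaryOpForward caches one primitive per unique (algorithm, inputs, output) signature in a thread-local map, so repeated calls on identically shaped and typed tensors reuse a constructed primitive instead of rebuilding it. A minimal self-contained sketch of the same caching pattern, with hypothetical Key, KeyHash, and Fwd types standing in for MXNet's OpSignature, OpHash, and DNNLBinaryOpFwd:

#include <cstddef>
#include <functional>
#include <unordered_map>

// Hypothetical stand-ins for OpSignature / OpHash / DNNLBinaryOpFwd.
struct Key {
  int alg;
  int shape_hash;  // MXNet builds this from the AddSign(...) calls above
  bool operator==(const Key& o) const {
    return alg == o.alg && shape_hash == o.shape_hash;
  }
};
struct KeyHash {
  std::size_t operator()(const Key& k) const {
    return std::hash<int>()(k.alg) * 31u + std::hash<int>()(k.shape_hash);
  }
};
struct Fwd {
  explicit Fwd(int alg) { /* expensive primitive construction happens here */ }
};

Fwd& GetCachedFwd(int alg, int shape_hash) {
  // One cache per thread: lookups need no locking, at the cost of one
  // cached primitive per thread (mirrors the thread_local map above).
  static thread_local std::unordered_map<Key, Fwd, KeyHash> fwds;
  Key key{alg, shape_hash};
  auto it = fwds.find(key);
  if (it == fwds.end())
    it = fwds.emplace(key, Fwd(alg)).first;  // build once, reuse afterwards
  return it->second;
}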
78 changes: 78 additions & 0 deletions src/operator/nn/dnnl/dnnl_binary.cc
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file dnnl_binary.cc
* \author: Adam Grabowski, [email protected]
*/

#if MXNET_USE_ONEDNN == 1
#include "./dnnl_binary-inl.h"

namespace mxnet {
namespace op {

DNNLBinaryOpFwd::DNNLBinaryOpFwd(const dnnl::algorithm alg,
const std::vector<NDArray>& inputs,
const std::vector<NDArray>& outputs) {
auto src0_desc = inputs[0].GetDNNLData()->get_desc();
auto src1_desc = inputs[1].GetDNNLData()->get_desc();
auto dst_desc = outputs[0].GetDNNLData()->get_desc();

dnnl::binary::desc fwd_desc(alg, src0_desc, src1_desc, dst_desc);
fwd_pd = std::make_shared<binary_fwd_pd_t>(fwd_desc, mxnet::CpuEngine::Get()->get_engine());
fwd = std::make_shared<binary_fwd_t>(*fwd_pd);
}

void DNNLBinaryOpFwd::Execute(const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
auto engine = mxnet::CpuEngine::Get()->get_engine();
auto src0 = inputs[0].GetDNNLData();
auto src1 = inputs[1].GetDNNLData();
dnnl_output_t out_mem;
if (outputs[0].GetDNNLData()->get_data_handle() == inputs[1].GetDNNLData()->get_data_handle())
out_mem = CreateDNNLMem(outputs[0], fwd_pd->dst_desc(), req[0], &inputs[1]);
else
out_mem = CreateDNNLMem(outputs[0], fwd_pd->dst_desc(), req[0], &inputs[0]);

dnnl_args_map_t args = {
{DNNL_ARG_SRC_0, *src0},
{DNNL_ARG_SRC_1, *src1},
{DNNL_ARG_DST, *out_mem.second},
};

DNNLStream::Get()->RegisterPrimArgs(*fwd, args);
CommitOutput(outputs[0], out_mem);
DNNLStream::Get()->Submit();
}

bool SupportDNNLBinary(const std::vector<NDArray>& inputs) {
auto dtype = inputs[0].dtype();
auto ndim_0 = inputs[0].shape().ndim();
auto ndim_1 = inputs[1].shape().ndim();
return ndim_0 >= 1 && ndim_0 <= 6 && ndim_1 >= 1 && ndim_1 <= 6 &&
inputs[0].shape().Size() != 0 && inputs[1].shape().Size() != 0 &&
dtype == mshadow::kFloat32 && dtype == inputs[1].dtype();
Contributor: Please check whether oneDNN supports bfloat16; if it does, please create a separate PR for it.
}

} // namespace op
} // namespace mxnet

#endif // MXNET_USE_ONEDNN == 1
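
For orientation, the constructor and Execute() above follow oneDNN's standard create-then-execute flow: memory descriptors, an operation descriptor, a primitive descriptor, the primitive, and finally an argument map. A minimal standalone sketch of that flow, assuming the oneDNN v2.x API this PR is written against (dnnl::binary::desc was removed in oneDNN v3.0) and two small fixed f32 tensors:

#include <vector>
#include "dnnl.hpp"

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  // Two dense 1-D f32 tensors of 8 elements each.
  memory::desc md({8}, memory::data_type::f32, memory::format_tag::a);
  std::vector<float> a(8, 2.f), b(8, 3.f), c(8, 0.f);
  memory src0(md, eng, a.data());
  memory src1(md, eng, b.data());
  memory dst(md, eng, c.data());

  // Same three steps as the DNNLBinaryOpFwd constructor:
  // op desc -> primitive desc -> primitive.
  binary::desc d(algorithm::binary_add, md, md, md);
  binary::primitive_desc pd(d, eng);
  binary prim(pd);

  // Same argument map as Execute() above.
  prim.execute(strm, {{DNNL_ARG_SRC_0, src0},
                      {DNNL_ARG_SRC_1, src1},
                      {DNNL_ARG_DST, dst}});
  strm.wait();  // every element of c is now 5.f
  return 0;
}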
47 changes: 47 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.h
@@ -851,6 +851,53 @@ void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
}
}

#if MXNET_USE_ONEDNN == 1
inline bool NumpyBinaryBroadcastStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2);
CHECK_EQ(out_attrs->size(), 1);

return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
}

void NumpyDivideBroadcastComputeCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs);

template <typename OP>
void NumpyBinaryOperatorComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<mxnet::NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<mxnet::NDArray>& outputs) {
if (SupportDNNLBinary(inputs)) {
const dnnl::algorithm alg = DNNLAlgorithm<OP>::value;
DNNLRun(DNNLBinaryOpForward<alg>, attrs, ctx, inputs, req, outputs);
return;
}
using namespace op::mshadow_op;
std::vector<mxnet::TBlob> in_data = {inputs[0].data(), inputs[1].data()};
std::vector<mxnet::TBlob> out_data = {outputs[0].data()};
if (std::is_same<OP, plus>::value) {
NumpyBinaryBroadcastComputeWithBool<cpu, OP, mixed_plus, mixed_plus>(
attrs, ctx, in_data, req, out_data);
} else if (std::is_same<OP, minus>::value) {
NumpyBinaryBroadcastCompute<cpu, OP, mixed_minus, mixed_rminus>(
attrs, ctx, in_data, req, out_data);
} else if (std::is_same<OP, mul>::value) {
NumpyBinaryBroadcastComputeWithBool<cpu, OP, mixed_mul, mixed_mul>(
attrs, ctx, in_data, req, out_data);
} else if (std::is_same<OP, div>::value) {
NumpyDivideBroadcastComputeCPU(attrs, ctx, in_data, req, out_data);
}
}
#endif // MXNET_USE_ONEDNN
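
A design note on the dispatch above: the std::is_same chain is an ordinary runtime if, so every fallback branch is instantiated for every OP and the compiler is left to fold the constant conditions away. With if constexpr only the matching branch would be instantiated. A hypothetical self-contained sketch of that alternative, using generic tag types in place of the mshadow_op functors and assuming a C++17 toolchain:

#include <cstdio>
#include <type_traits>

// Tag types standing in for op::mshadow_op::plus / minus.
struct plus {};
struct minus {};

template <typename OP>
void Dispatch() {
  if constexpr (std::is_same<OP, plus>::value) {
    std::puts("add kernel");  // instantiated only for OP = plus
  } else if constexpr (std::is_same<OP, minus>::value) {
    std::puts("sub kernel");  // instantiated only for OP = minus
  }
}

int main() {
  Dispatch<plus>();   // prints "add kernel"
  Dispatch<minus>();  // prints "sub kernel"
  return 0;
}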

#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
4 changes: 4 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op_add.cc
@@ -33,6 +33,10 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add)
op::mshadow_op::plus,
op::mshadow_op::mixed_plus,
op::mshadow_op::mixed_plus>)
#if MXNET_USE_ONEDNN == 1
.set_attr<FComputeEx>("FComputeEx<cpu>", NumpyBinaryOperatorComputeExCPU<op::mshadow_op::plus>)
.set_attr<FInferStorageType>("FInferStorageType", NumpyBinaryBroadcastStorageType)
#endif // MXNET_USE_ONEDNN
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_add"});

NNVM_REGISTER_OP(_backward_npi_broadcast_add)
4 changes: 4 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op_mul.cc
@@ -33,6 +33,10 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
op::mshadow_op::mul,
op::mshadow_op::mixed_mul,
op::mshadow_op::mixed_mul>)
#if MXNET_USE_ONEDNN == 1
.set_attr<FComputeEx>("FComputeEx<cpu>", NumpyBinaryOperatorComputeExCPU<op::mshadow_op::mul>)
.set_attr<FInferStorageType>("FInferStorageType", NumpyBinaryBroadcastStorageType)
#endif // MXNET_USE_ONEDNN
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mul"});

NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
4 changes: 4 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op_sub.cc
@@ -33,6 +33,10 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract)
op::mshadow_op::minus,
op::mshadow_op::mixed_minus,
op::mshadow_op::mixed_rminus>)
#if MXNET_USE_ONEDNN == 1
.set_attr<FComputeEx>("FComputeEx<cpu>", NumpyBinaryOperatorComputeExCPU<op::mshadow_op::minus>)
Contributor: What about the mixed version? Does it work properly on GPU when oneDNN is enabled (the default configuration)? Could you check whether there is a test for it?

Contributor (Author): oneDNN dispatch is considered only after the dev_mask == mshadow::cpu::kDevMask condition is met, so the GPU workflow is not affected.
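(The guard referenced here is visible further down in this diff: in BinaryBroadcastMulStorageType and BinaryBroadcastAddStorageType, kFComputeEx is assigned only when dev_mask == mshadow::cpu::kDevMask and DNNLEnvSet() hold.)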

.set_attr<FInferStorageType>("FInferStorageType", NumpyBinaryBroadcastStorageType)
#endif // MXNET_USE_ONEDNN
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_sub"});

NNVM_REGISTER_OP(_backward_npi_broadcast_sub)
14 changes: 14 additions & 0 deletions src/operator/numpy/np_true_divide.cc
@@ -61,6 +61,16 @@ bool TrueDivideType(const nnvm::NodeAttrs& attrs,
return true;
}

#if MXNET_USE_ONEDNN == 1
void NumpyDivideBroadcastComputeCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
TrueDivideBroadcastCompute<cpu>(attrs, ctx, inputs, req, outputs);
}
#endif // MXNET_USE_ONEDNN

NNVM_REGISTER_OP(_npi_true_divide)
.set_num_inputs(2)
.set_num_outputs(1)
@@ -79,6 +89,10 @@ NNVM_REGISTER_OP(_npi_true_divide)
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", TrueDivideBroadcastCompute<cpu>)
#if MXNET_USE_ONEDNN == 1
.set_attr<FComputeEx>("FComputeEx<cpu>", NumpyBinaryOperatorComputeExCPU<op::mshadow_op::div>)
.set_attr<FInferStorageType>("FInferStorageType", NumpyBinaryBroadcastStorageType)
#endif // MXNET_USE_ONEDNN
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_div"})
.add_argument("lhs", "NDArray-or-Symbol", "Dividend array")
.add_argument("rhs", "NDArray-or-Symbol", "Divisor array");
41 changes: 41 additions & 0 deletions src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -91,8 +91,14 @@ inline bool BinaryBroadcastMulStorageType(const nnvm::NodeAttrs& attrs,
int& out_stype = out_attrs->at(0);
bool dispatched = false;
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
#if MXNET_USE_ONEDNN == 1
if (dev_mask == mshadow::cpu::kDevMask && DNNLEnvSet())
dispatched = storage_type_assign(
&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
#else
dispatched =
storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
#endif // MXNET_USE_ONEDNN == 1
}
if (!dispatched && lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) {
dispatched =
@@ -116,8 +122,14 @@ inline bool BinaryBroadcastAddStorageType(const nnvm::NodeAttrs& attrs,
int& out_stype = out_attrs->at(0);
bool dispatched = false;
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
#if MXNET_USE_ONEDNN == 1
if (dev_mask == mshadow::cpu::kDevMask && DNNLEnvSet())
dispatched = storage_type_assign(
&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
#else
dispatched =
storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
#endif // MXNET_USE_ONEDNN == 1
}
if (!dispatched && ((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
(lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
@@ -788,6 +800,35 @@ void BinaryBroadcastBackwardUseIn(const nnvm::NodeAttrs& attrs,
}
}

#if MXNET_USE_ONEDNN == 1
template <dnnl::algorithm alg>
void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);

// template struct converting op::mshadow_op to dnnl::algorithm
template <typename OP>
struct DNNLAlgorithm {};
template <>
struct DNNLAlgorithm<op::mshadow_op::plus> {
static const dnnl::algorithm value = dnnl::algorithm::binary_add;
};
template <>
struct DNNLAlgorithm<op::mshadow_op::minus> {
static const dnnl::algorithm value = dnnl::algorithm::binary_sub;
};
template <>
struct DNNLAlgorithm<op::mshadow_op::mul> {
static const dnnl::algorithm value = dnnl::algorithm::binary_mul;
};
template <>
struct DNNLAlgorithm<op::mshadow_op::div> {
static const dnnl::algorithm value = dnnl::algorithm::binary_div;
};
#endif // MXNET_USE_ONEDNN == 1
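
The DNNLAlgorithm trait maps each mshadow functor type to its oneDNN algorithm at compile time, which is what lets NumpyBinaryOperatorComputeExCPU instantiate DNNLBinaryOpForward<alg> once per operator. A self-contained sketch of this tag-to-enum trait pattern, with generic names standing in for the mshadow_op and dnnl types:

#include <cstdio>

enum class Alg { add, sub };

// Tag types standing in for op::mshadow_op::plus / minus.
struct plus {};
struct minus {};

// Primary template left undefined: mapping an unsupported OP is a compile error.
template <typename OP> struct AlgOf;
template <> struct AlgOf<plus>  { static constexpr Alg value = Alg::add; };
template <> struct AlgOf<minus> { static constexpr Alg value = Alg::sub; };

// One kernel instantiation per algorithm, selected at compile time:
// the same shape as DNNLRun(DNNLBinaryOpForward<alg>, ...) above.
template <Alg alg>
void Run() {
  std::printf("alg = %d\n", static_cast<int>(alg));
}

template <typename OP>
void Forward() {
  Run<AlgOf<OP>::value>();
}

int main() {
  Forward<plus>();   // alg = 0
  Forward<minus>();  // alg = 1
  return 0;
}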

#define MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(2) \