Skip to content

Commit

Permalink
[CPU] [ARM64] jit equal
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Mar 21, 2024
1 parent 12404fc commit 6cee113
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,57 @@ std::set<std::vector<element::Type>> jit_divide_emitter::get_supported_precision
return {{element::f32, element::f32}};
}

/// EQUAL ///
jit_equal_emitter::jit_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node)
: jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {
prepare_table();
}
jit_equal_emitter::jit_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc)
: jit_emitter(host, host_isa, exec_prc) {
prepare_table();
}

size_t jit_equal_emitter::get_inputs_count() const { return 2; }

size_t jit_equal_emitter::get_aux_vecs_count() const { return 1; }

size_t jit_equal_emitter::get_aux_gprs_count() const { return 1; }

std::set<std::vector<element::Type>> jit_equal_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
return {{element::f32, element::f32}};
}

void jit_equal_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
} else {
OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
}
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_equal_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
const TReg src1 = TReg(in_vec_idxs[0]);
const TReg src2 = TReg(in_vec_idxs[1]);
const TReg dst = TReg(out_vec_idxs[0]);
const TReg aux = TReg(aux_vec_idxs[0]);

h->fcmeq(dst.s, src1.s, src2.s);

h->ld1r(aux.s, table_val2("one"));
h->and_(dst.b16, dst.b16, aux.b16);
}

void jit_equal_emitter::register_table_entries() {
push_arg_entry_of("one", 0x3f800000, true);
}

/// MUL_ADD ///
jit_mul_add_emitter::jit_mul_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,33 @@ class jit_divide_emitter : public jit_emitter {
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

class jit_equal_emitter : public jit_emitter {
public:
jit_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc = ov::element::f32);

jit_equal_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& n);

size_t get_inputs_count() const override;

size_t get_aux_vecs_count() const override;

size_t get_aux_gprs_count() const override;

static std::set<std::vector<element::Type>> get_supported_precisions(
const std::shared_ptr<ov::Node>& node = nullptr);

private:
void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;

void register_table_entries() override;
};

class jit_mul_add_emitter : public jit_emitter {
public:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ bool JitEltwiseExecutor::isSupported(
Algorithm::EltwiseAdd,
Algorithm::EltwiseClamp,
Algorithm::EltwiseDivide,
Algorithm::EltwiseEqual,
Algorithm::EltwiseMultiply,
Algorithm::EltwiseMulAdd,
Algorithm::EltwisePowerStatic,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,7 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte
OV_CASE(Algorithm::EltwiseAdd, ov::intel_cpu::aarch64::jit_add_emitter),
OV_CASE(Algorithm::EltwiseClamp, ov::intel_cpu::aarch64::jit_clamp_emitter),
OV_CASE(Algorithm::EltwiseDivide, ov::intel_cpu::aarch64::jit_divide_emitter),
OV_CASE(Algorithm::EltwiseEqual, ov::intel_cpu::aarch64::jit_equal_emitter),
OV_CASE(Algorithm::EltwiseMulAdd, ov::intel_cpu::aarch64::jit_mul_add_emitter),
OV_CASE(Algorithm::EltwiseMultiply, ov::intel_cpu::aarch64::jit_multiply_emitter),
OV_CASE(Algorithm::EltwisePowerStatic, ov::intel_cpu::aarch64::jit_power_static_emitter),
Expand Down Expand Up @@ -766,6 +767,7 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
OV_CASE(Algorithm::EltwiseAdd, jit_add_emitter),
OV_CASE(Algorithm::EltwiseClamp, jit_clamp_emitter),
OV_CASE(Algorithm::EltwiseDivide, jit_divide_emitter),
OV_CASE(Algorithm::EltwiseEqual, jit_equal_emitter),
OV_CASE(Algorithm::EltwiseMulAdd, jit_mul_add_emitter),
OV_CASE(Algorithm::EltwiseMultiply, jit_multiply_emitter),
OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ std::vector<EltwiseTypes> eltwise_op_types = {
EltwiseTypes::FLOOR_MOD,
EltwiseTypes::SQUARED_DIFF,
EltwiseTypes::POWER,
EltwiseTypes::MOD
EltwiseTypes::MOD,
EltwiseTypes::EQUAL_OP
};

std::vector<EltwiseTypes> eltwise_op_types_dynamic = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ enum EltwiseTypes {
BITWISE_AND,
BITWISE_NOT,
BITWISE_OR,
BITWISE_XOR
BITWISE_XOR,
EQUAL_OP
};

enum SqueezeOpType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "openvino/op/bitwise_or.hpp"
#include "openvino/op/bitwise_xor.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/erf.hpp"
#include "openvino/op/floor_mod.hpp"
#include "openvino/op/mod.hpp"
Expand Down Expand Up @@ -51,6 +52,8 @@ std::shared_ptr<ov::Node> make_eltwise(const ov::Output<Node>& in0,
return std::make_shared<ov::op::v13::BitwiseOr>(in0, in1);
case ov::test::utils::EltwiseTypes::BITWISE_XOR:
return std::make_shared<ov::op::v13::BitwiseXor>(in0, in1);
case ov::test::utils::EltwiseTypes::EQUAL_OP:
return std::make_shared<ov::op::v1::Equal>(in0, in1);
default: {
OPENVINO_THROW("Incorrect type of Eltwise operation");
}
Expand Down
3 changes: 3 additions & 0 deletions src/tests/test_utils/common_test_utils/src/test_enums.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ std::ostream& operator<<(std::ostream& os, const ov::test::utils::EltwiseTypes t
case ov::test::utils::EltwiseTypes::BITWISE_XOR:
os << "BitwiseXor";
break;
case ov::test::utils::EltwiseTypes::EQUAL_OP:
os << "Equal";
break;
default:
throw std::runtime_error("NOT_SUPPORTED_OP_TYPE");
}
Expand Down

0 comments on commit 6cee113

Please sign in to comment.