Skip to content

Commit

Permalink
[TRANSFORMATIONS] Fix Optional to match even with no inputs (openvino…
Browse files Browse the repository at this point in the history
…toolkit#23471)

[TRANSFORMATIONS] Fix Optional to match even with no inputs

### Details:
The Optional pattern type may create a wrong pattern to match if no
inputs are provided to the Optional node. If no inputs present to the
Optional type, it will not create an alternative branch(es) to check
against resulting in the incorrect matching.

Fix that by adding a check for the number of inputs being 0.

Do a minor refactoring/renaming for the readability purposes.

### Tickets:
 CSV-133523

Signed-off-by: Andrii Staikov <[email protected]>

---------

Signed-off-by: Andrii Staikov <[email protected]>
  • Loading branch information
CuriousPanCake authored and bbielawx committed Apr 12, 2024
1 parent 8817ccc commit 2542f41
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 4 deletions.
42 changes: 38 additions & 4 deletions src/core/src/pattern/op/optional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,52 @@
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

using namespace ov::pass::pattern::op;

/*
┌──────────────┐
│ Relu │
┌──────────────┐ └──────┬───────┘
│ Relu │ │
└──────┬───────┘ ┌──────┴───────┐ ┌──────────────┐
│ │WrapType<Relu>│ │ Relu │
┌──────┴───────┐ └──────┬───────┘ └───────┬──────┘
│Optional<Relu>│ Unfolds into │ │
└──────┬───────┘ └────────┐ ┌────────┘
│ │ │
┌─┴─┐ ┌┴──────┴┐
│ABS│ │ Or │
└───┘ └────┬───┘
┌─┴─┐
│ABS│
└───┘
In case there're no inputs to the Optional, there's no second branch hence no need in the
Or node and we may omit it leaving only the WrapType node with the Optional entry inside.
*/

std::vector<ov::DiscreteTypeInfo> ov::pass::pattern::op::Optional::get_optional_types() const {
return optional_types;
}

bool ov::pass::pattern::op::Optional::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) {
ov::OutputVector or_in_values = input_values();
auto wrap_node = std::make_shared<ov::pass::pattern::op::WrapType>(optional_types, m_predicate, or_in_values);
or_in_values.push_back(wrap_node);
// Turn the Optional node into WrapType node to create a case where the Optional node is present
ov::OutputVector input_values_to_optional = input_values();
size_t num_input_values_to_optional = input_values_to_optional.size();
auto wrap_node = std::make_shared<WrapType>(optional_types, m_predicate, input_values_to_optional);

// Either continue using the WrapType if there're no inputs to it or create an Or node,
// if there're other inputs to Optional creating another "branch" for matching.
// Use only the 0th input as a "data" input. (To be changed or considered when Optional
// starts supporting multiple inputs)
auto pattern = num_input_values_to_optional == 0 ? std::static_pointer_cast<Pattern>(wrap_node)
: std::static_pointer_cast<Pattern>(std::make_shared<Or>(
OutputVector{wrap_node, input_values_to_optional[0]}));

if (matcher->match_value(std::make_shared<ov::pass::pattern::op::Or>(or_in_values), graph_value)) {
if (matcher->match_value(pattern, graph_value) || num_input_values_to_optional == 0) {
auto& pattern_map = matcher->get_pattern_value_map();
if (pattern_map.count(wrap_node)) {
pattern_map[shared_from_this()] = graph_value;
Expand Down
64 changes: 64 additions & 0 deletions src/core/tests/pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/cos.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/exp.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/reduce_sum.hpp"
Expand Down Expand Up @@ -508,6 +510,68 @@ TEST(pattern, matching_optional) {
std::make_shared<op::v0::Abs>(c)));
}

TEST(pattern, optional_full_match) {
Shape shape{};
auto model_input1 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_input2 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_add = std::make_shared<op::v1::Add>(model_input1->output(0), model_input2->output(0));
auto model_relu = std::make_shared<op::v0::Relu>(model_add->output(0));

auto pattern_add = ov::pass::pattern::optional<op::v1::Add>();
auto pattern_relu = std::make_shared<op::v0::Relu>(pattern_add->output(0));

TestMatcher tm;

ASSERT_TRUE(tm.match(pattern_relu, model_relu));
}

TEST(pattern, optional_half_match) {
Shape shape{};
auto model_input1 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_input2 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_add = std::make_shared<op::v1::Add>(model_input1->output(0), model_input2->output(0));
auto model_relu = std::make_shared<op::v0::Relu>(model_add->output(0));

auto pattern_relu = ov::pass::pattern::optional<op::v0::Relu>();
auto pattern_relu1 = std::make_shared<op::v0::Relu>(pattern_relu->output(0));

TestMatcher tm;

ASSERT_TRUE(tm.match(pattern_relu1, model_relu));
}

TEST(pattern, optional_testing) {
Shape shape{};
auto model_input1 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_input2 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_add = std::make_shared<op::v1::Add>(model_input1->output(0), model_input2->output(0));
auto model_relu = std::make_shared<op::v0::Relu>(model_add->output(0));
auto model_abs = std::make_shared<op::v0::Abs>(model_add->output(0));

TestMatcher tm;

ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Exp, op::v0::Relu>(model_add), model_add));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Abs, op::v0::Relu>(model_add), model_add));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Abs, op::v0::Exp>(model_add), model_add));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Exp, op::v0::Cos>(model_add), model_add));

ASSERT_TRUE(
tm.match(ov::pass::pattern::optional<op::v0::Abs>(model_abs), std::make_shared<op::v0::Abs>(model_abs)));
ASSERT_FALSE(
tm.match(ov::pass::pattern::optional<op::v0::Abs>(model_abs), std::make_shared<op::v0::Relu>(model_abs)));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Abs, op::v0::Relu>(model_abs),
std::make_shared<op::v0::Relu>(model_abs)));

ASSERT_FALSE(tm.match(ov::pass::pattern::optional<op::v0::Exp>(model_add), model_abs));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Exp, op::v0::Abs>(model_add), model_abs));

ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Relu>(model_relu),
std::make_shared<op::v0::Relu>(std::make_shared<op::v0::Relu>(model_add))));

ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Relu>(model_relu),
std::make_shared<op::v0::Relu>(std::make_shared<op::v0::Relu>(model_add))));
}

TEST(pattern, mean) {
// construct mean
TestMatcher n;
Expand Down

0 comments on commit 2542f41

Please sign in to comment.