Skip to content

Commit

Permalink
fix segfault on a wrong vector access
Browse files Browse the repository at this point in the history
  • Loading branch information
CuriousPanCake committed Mar 18, 2024
1 parent 5792f4b commit 96c4338
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 35 deletions.
56 changes: 27 additions & 29 deletions src/core/src/pattern/op/optional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,27 @@

using namespace ov::pass::pattern::op;


/*
┌──────────────┐
│ Relu │
┌──────────────┐ └──────┬───────┘
│ Relu │ │
/*
┌──────────────┐
│ Relu │
┌──────────────┐ └──────┬───────┘
│ Relu │ │
└──────┬───────┘ ┌──────┴───────┐ ┌──────────────┐
│ │Optional<Relu>│ │ Relu
┌──────┴───────┐ └─────────────┘ └─────────────┘
│Optional<Relu>│ Unfolds into
└──────┬───────┘ └────┐ ┌───
│ │ │
┌─┴─┐ ┌┴──────┴┐
│ABS│ │ Or │
└───┘ └────┬───┘
┌─┴─┐
│ABS│
└───┘
│ │WrapType<Relu>│ │ Relu
┌──────┴───────┐ └─────────────┘ └─────────────┘
│Optional<Relu>│ Unfolds into │
└──────┬───────┘ └────────┐ ┌────────┘
│ │ │
┌─┴─┐ ┌┴──────┴┐
│ABS│ │ Or │
└───┘ └────┬───┘
┌─┴─┐
│ABS│
└───┘
In case there're no inputs to the Optional, there's no second branch
hence no need in the Or node and we may use WrapType with the entry of the Optional node
In case there're no inputs to the Optional, there's no second branch hence no need in the
Or node and we may omit it leaving only the WrapType node with the Optional entry inside.
*/

std::vector<ov::DiscreteTypeInfo> ov::pass::pattern::op::Optional::get_optional_types() const {
Expand All @@ -44,16 +43,15 @@ bool ov::pass::pattern::op::Optional::match_value(Matcher* matcher,
// Turn the Optional node into WrapType node to create a case where the Optional node is present
ov::OutputVector input_values_to_optional = input_values();
size_t num_input_values_to_optional = input_values_to_optional.size();
auto wrap_node =
std::make_shared<WrapType>(optional_types, m_predicate, input_values_to_optional);
auto wrap_node = std::make_shared<WrapType>(optional_types, m_predicate, input_values_to_optional);

// Using only the 0th input as a "data" input. (To be changed or considered when Optional starts supporting multiple inputs)
// Either continue using the WrapType if there're no inputs to it or create an Or node if
// there're other inputs to Optional creating another "branch" for matching
OutputVector input_values_to_or {wrap_node, input_values_to_optional[0]};
auto pattern = num_input_values_to_optional == 0 ?
std::static_pointer_cast<Pattern>(wrap_node) :
std::static_pointer_cast<Pattern>(std::make_shared<Or>(input_values_to_or));
// Either continue using the WrapType if there're no inputs to it or create an Or node,
// if there're other inputs to Optional creating another "branch" for matching.
// Use only the 0th input as a "data" input. (To be changed or considered when Optional
// starts supporting multiple inputs)
auto pattern = num_input_values_to_optional == 0 ? std::static_pointer_cast<Pattern>(wrap_node)
: std::static_pointer_cast<Pattern>(std::make_shared<Or>(
OutputVector{wrap_node, input_values_to_optional[0]}));

// Add the newly created WrapType node to the list containing its inputs and create an Or node with the list
if (matcher->match_value(pattern, graph_value) || num_input_values_to_optional == 0) {
Expand Down
14 changes: 8 additions & 6 deletions src/core/tests/pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/cos.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/exp.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/reduce_sum.hpp"
Expand Down Expand Up @@ -538,7 +540,7 @@ TEST(pattern, optional_half_match) {
ASSERT_TRUE(tm.match(pattern_relu1, model_relu));
}

TEST(pattern, optional_new_test) {
TEST(pattern, optional_testing) {
Shape shape{};
auto model_input1 = std::make_shared<op::v0::Parameter>(element::i32, shape);
auto model_input2 = std::make_shared<op::v0::Parameter>(element::i32, shape);
Expand All @@ -548,10 +550,10 @@ TEST(pattern, optional_new_test) {

TestMatcher tm;

ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v1::Divide, op::v0::Relu>(model_add), model_add));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Exp, op::v0::Relu>(model_add), model_add));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Abs, op::v0::Relu>(model_add), model_add));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Abs, op::v1::Multiply>(model_add), model_add));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v1::Divide, op::v1::Multiply>(model_add), model_add));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Abs, op::v0::Exp>(model_add), model_add));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Exp, op::v0::Cos>(model_add), model_add));

ASSERT_TRUE(
tm.match(ov::pass::pattern::optional<op::v0::Abs>(model_abs), std::make_shared<op::v0::Abs>(model_abs)));
Expand All @@ -560,8 +562,8 @@ TEST(pattern, optional_new_test) {
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Abs, op::v0::Relu>(model_abs),
std::make_shared<op::v0::Relu>(model_abs)));

ASSERT_FALSE(tm.match(ov::pass::pattern::optional<op::v1::Divide>(model_add), model_abs));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v1::Divide, op::v0::Abs>(model_add), model_abs));
ASSERT_FALSE(tm.match(ov::pass::pattern::optional<op::v0::Exp>(model_add), model_abs));
ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Exp, op::v0::Abs>(model_add), model_abs));

ASSERT_TRUE(tm.match(ov::pass::pattern::optional<op::v0::Relu>(model_relu),
std::make_shared<op::v0::Relu>(std::make_shared<op::v0::Relu>(model_add))));
Expand Down

0 comments on commit 96c4338

Please sign in to comment.