From 96c4338cc12a402f2d0f6e7bc59691c62007b08d Mon Sep 17 00:00:00 2001 From: Andrii Staikov Date: Mon, 18 Mar 2024 12:33:55 +0330 Subject: [PATCH] fix segfault on a wrong vector access --- src/core/src/pattern/op/optional.cpp | 56 ++++++++++++++-------------- src/core/tests/pattern.cpp | 14 ++++--- 2 files changed, 35 insertions(+), 35 deletions(-) diff --git a/src/core/src/pattern/op/optional.cpp b/src/core/src/pattern/op/optional.cpp index d34586a970cc36..612ec6738afc58 100644 --- a/src/core/src/pattern/op/optional.cpp +++ b/src/core/src/pattern/op/optional.cpp @@ -10,28 +10,27 @@ using namespace ov::pass::pattern::op; - -/* - ┌──────────────┐ - │ Relu │ - ┌──────────────┐ └──────┬───────┘ - │ Relu │ │ +/* + ┌──────────────┐ + │ Relu │ + ┌──────────────┐ └──────┬───────┘ + │ Relu │ │ └──────┬───────┘ ┌──────┴───────┐ ┌──────────────┐ - │ │Optional│ │ Relu │ - ┌──────┴───────┐ └──────────┬───┘ ┼ └──┬───────────┘ - │Optional│ Unfolds into │ │ - └──────┬───────┘ └────┐ ┌───┘ - │ │ │ - ┌─┴─┐ ┌┴──────┴┐ - │ABS│ │ Or │ - └───┘ └────┬───┘ - │ - ┌─┴─┐ - │ABS│ - └───┘ + │ │WrapType│ │ Relu │ + ┌──────┴───────┐ └──────┬───────┘ └───────┬──────┘ + │Optional│ Unfolds into │ │ + └──────┬───────┘ └────────┐ ┌────────┘ + │ │ │ + ┌─┴─┐ ┌┴──────┴┐ + │ABS│ │ Or │ + └───┘ └────┬───┘ + │ + ┌─┴─┐ + │ABS│ + └───┘ - In case there're no inputs to the Optional, there's no second branch - hence no need in the Or node and we may use WrapType with the entry of the Optional node + In case there're no inputs to the Optional, there's no second branch hence no need in the + Or node and we may omit it leaving only the WrapType node with the Optional entry inside. */ std::vector ov::pass::pattern::op::Optional::get_optional_types() const { @@ -44,16 +43,15 @@ bool ov::pass::pattern::op::Optional::match_value(Matcher* matcher, // Turn the Optional node into WrapType node to create a case where the Optional node is present ov::OutputVector input_values_to_optional = input_values(); size_t num_input_values_to_optional = input_values_to_optional.size(); - auto wrap_node = - std::make_shared(optional_types, m_predicate, input_values_to_optional); + auto wrap_node = std::make_shared(optional_types, m_predicate, input_values_to_optional); - // Using only the 0th input as a "data" input. (To be changed or considered when Optional starts supporting multiple inputs) - // Either continue using the WrapType if there're no inputs to it or create an Or node if - // there're other inputs to Optional creating another "branch" for matching - OutputVector input_values_to_or {wrap_node, input_values_to_optional[0]}; - auto pattern = num_input_values_to_optional == 0 ? - std::static_pointer_cast(wrap_node) : - std::static_pointer_cast(std::make_shared(input_values_to_or)); + // Either continue using the WrapType if there're no inputs to it or create an Or node, + // if there're other inputs to Optional creating another "branch" for matching. + // Use only the 0th input as a "data" input. (To be changed or considered when Optional + // starts supporting multiple inputs) + auto pattern = num_input_values_to_optional == 0 ? std::static_pointer_cast(wrap_node) + : std::static_pointer_cast(std::make_shared( + OutputVector{wrap_node, input_values_to_optional[0]})); // Add the newly created WrapType node to the list containing its inputs and create an Or node with the list if (matcher->match_value(pattern, graph_value) || num_input_values_to_optional == 0) { diff --git a/src/core/tests/pattern.cpp b/src/core/tests/pattern.cpp index 08d529a77c0511..7d794aa4a69350 100644 --- a/src/core/tests/pattern.cpp +++ b/src/core/tests/pattern.cpp @@ -18,7 +18,9 @@ #include "openvino/op/add.hpp" #include "openvino/op/broadcast.hpp" #include "openvino/op/constant.hpp" +#include "openvino/op/cos.hpp" #include "openvino/op/divide.hpp" +#include "openvino/op/exp.hpp" #include "openvino/op/multiply.hpp" #include "openvino/op/parameter.hpp" #include "openvino/op/reduce_sum.hpp" @@ -538,7 +540,7 @@ TEST(pattern, optional_half_match) { ASSERT_TRUE(tm.match(pattern_relu1, model_relu)); } -TEST(pattern, optional_new_test) { +TEST(pattern, optional_testing) { Shape shape{}; auto model_input1 = std::make_shared(element::i32, shape); auto model_input2 = std::make_shared(element::i32, shape); @@ -548,10 +550,10 @@ TEST(pattern, optional_new_test) { TestMatcher tm; - ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_add), model_add)); + ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_add), model_add)); ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_add), model_add)); - ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_add), model_add)); - ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_add), model_add)); + ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_add), model_add)); + ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_add), model_add)); ASSERT_TRUE( tm.match(ov::pass::pattern::optional(model_abs), std::make_shared(model_abs))); @@ -560,8 +562,8 @@ TEST(pattern, optional_new_test) { ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_abs), std::make_shared(model_abs))); - ASSERT_FALSE(tm.match(ov::pass::pattern::optional(model_add), model_abs)); - ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_add), model_abs)); + ASSERT_FALSE(tm.match(ov::pass::pattern::optional(model_add), model_abs)); + ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_add), model_abs)); ASSERT_TRUE(tm.match(ov::pass::pattern::optional(model_relu), std::make_shared(std::make_shared(model_add))));