Skip to content

Commit

Permalink
Create a WrapType node instead of Or node in case there're no inputs to
Browse files Browse the repository at this point in the history
Optional
  • Loading branch information
CuriousPanCake committed Mar 15, 2024
1 parent c4bd8d1 commit 9258489
Showing 1 changed file with 36 additions and 5 deletions.
41 changes: 36 additions & 5 deletions src/core/src/pattern/op/optional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,32 @@
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

using namespace ov::pass::pattern::op;


/*
┌──────────────┐
│ Relu │
┌──────────────┐ └──────┬───────┘
│ Relu │ │
└──────┬───────┘ ┌──────┴───────┐ ┌──────────────┐
│ │Optional<Relu>│ │ Relu │
┌──────┴───────┐ └──────────┬───┘ ┼ └──┬───────────┘
│Optional<Relu>│ Unfolds into │ │
└──────┬───────┘ └────┐ ┌───┘
│ │ │
┌─┴─┐ ┌┴──────┴┐
│ABS│ │ Or │
└───┘ └────┬───┘
┌─┴─┐
│ABS│
└───┘
In case there're no inputs to the Optional, there's no second branch
hence no need in the Or node and we may use WrapType with the entry of the Optional node
*/

std::vector<ov::DiscreteTypeInfo> ov::pass::pattern::op::Optional::get_optional_types() const {
return optional_types;
}
Expand All @@ -19,13 +45,18 @@ bool ov::pass::pattern::op::Optional::match_value(Matcher* matcher,
ov::OutputVector input_values_to_optional = input_values();
size_t num_input_values_to_optional = input_values_to_optional.size();
auto wrap_node =
std::make_shared<ov::pass::pattern::op::WrapType>(optional_types, m_predicate, input_values_to_optional);
std::make_shared<WrapType>(optional_types, m_predicate, input_values_to_optional);

// Add the newly created WrapType node to the list containing its inputs and create an Or node with the list
input_values_to_optional.push_back(wrap_node);
auto or_node = std::make_shared<ov::pass::pattern::op::Or>(input_values_to_optional);
// Using only the 0th input as a "data" input. (To be changed or considered when Optional starts supporting multiple inputs)
// Either continue using the WrapType if there're no inputs to it or create an Or node if
// there're other inputs to Optional creating another "branch" for matching
OutputVector input_values_to_or {wrap_node, input_values_to_optional[0]};
auto pattern = num_input_values_to_optional == 0 ?
std::static_pointer_cast<Pattern>(wrap_node) :
std::static_pointer_cast<Pattern>(std::make_shared<Or>(input_values_to_or));

if (matcher->match_value(or_node, graph_value) || num_input_values_to_optional == 0) {
// Add the newly created WrapType node to the list containing its inputs and create an Or node with the list
if (matcher->match_value(pattern, graph_value) || num_input_values_to_optional == 0) {
auto& pattern_map = matcher->get_pattern_value_map();
if (pattern_map.count(wrap_node)) {
pattern_map[shared_from_this()] = graph_value;
Expand Down

0 comments on commit 9258489

Please sign in to comment.