Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Transformations] Added If operation to NMS path propagation for ignore negative indices in Gather #23451

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "openvino/core/rt_info.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/if.hpp"
#include "openvino/op/non_max_suppression.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/slice.hpp"
Expand All @@ -18,7 +19,9 @@
#include "openvino/op/unsqueeze.hpp"
#include "openvino/op/util/broadcast_base.hpp"
#include "openvino/op/util/gather_base.hpp"
#include "openvino/op/util/multi_subgraph_base.hpp"
#include "openvino/op/variadic_split.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/nms_selected_indices.hpp"

Expand Down Expand Up @@ -60,14 +63,53 @@ class PropagateNMSPath : public pass::MatcherPass {
ov::op::v1::VariadicSplit,
op::util::GatherBase,
ov::op::v0::Concat,
ov::op::v0::Convert>();
ov::op::v0::Convert,
ov::op::v8::If>();
matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto propagate_path = [](const ov::OutputVector& input_nodes, ov::Node* target_node) {
if (any_of(input_nodes.begin(), input_nodes.end(), [](const Output<Node>& output) {
return ov::has_nms_selected_indices(output.get_node());
})) {
ov::set_nms_selected_indices(target_node);
}
};
auto handle_params = [&propagate_path](std::shared_ptr<ov::op::util::MultiSubGraphOp> node,
std::shared_ptr<ov::Model> body,
int body_index) {
const auto& params = body->get_parameters();
for (auto input_desc : node->get_input_descriptions(body_index)) {
auto param = params[input_desc->m_body_parameter_index];
auto input_node = node->input(input_desc->m_input_index).get_source_output();
propagate_path({input_node}, param.get());
}
};
auto handle_results = [&propagate_path](std::shared_ptr<ov::op::util::MultiSubGraphOp> node,
std::shared_ptr<ov::Model> body,
int body_index) {
const auto& results = body->get_results();
for (auto output_desc : node->get_output_descriptions(body_index)) {
auto result = results[output_desc->m_body_value_index];
const auto& result_inputs = result->input_values();
auto output_node = node->output(output_desc->m_output_index).get_node();
propagate_path(result_inputs, output_node);
}
};

auto node = m.get_match_root();
const auto& inputs = node->input_values();
if (any_of(inputs.begin(), inputs.end(), [](const Output<Node>& output) {
return ov::has_nms_selected_indices(output.get_node());
})) {
ov::set_nms_selected_indices(node.get());
if (ov::is_type<ov::op::util::MultiSubGraphOp>(node)) {
auto multi_subgraph_op = ov::as_type_ptr<ov::op::util::MultiSubGraphOp>(node);
const auto& models = multi_subgraph_op->get_functions();

for (size_t body_idx = 0; body_idx < models.size(); ++body_idx) {
handle_params(multi_subgraph_op, models[body_idx], static_cast<int>(body_idx));
ov::pass::Manager manager;
manager.register_pass<ov::pass::PropagateNMSPath>();
manager.run_passes(models[body_idx]);
handle_results(multi_subgraph_op, models[body_idx], static_cast<int>(body_idx));
}
} else {
const auto& inputs = node->input_values();
propagate_path(inputs, node.get());
}
return false;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,68 @@ TEST(TransformationTests, test_convert_to_unsigned_nms_gather_3) {
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(count_ops_of_type<ov::op::v0::Convert>(f), 0);
}

TEST(TransformationTests, test_convert_to_unsigned_nms_gather_with_if_condition) {
auto boxes = make_shared<opset8::Parameter>(element::f32, PartialShape{1, -1, 4});
auto scores = make_shared<opset8::Parameter>(element::f32, PartialShape{1, 1, -1});
auto nms = make_shared<opset8::NonMaxSuppression>(boxes, scores);

auto gather = make_shared<opset8::Gather>(nms->output(0),
opset8::Constant::create(element::i32, Shape{1}, {2}),
opset8::Constant::create(element::i32, Shape{1}, {0}));

auto shape_of = make_shared<opset8::ShapeOf>(gather);
auto gather_shape = make_shared<opset8::Gather>(shape_of,
opset8::Constant::create(element::i32, Shape{1}, {0}),
opset8::Constant::create(element::i32, Shape{1}, {0}));
auto equal = make_shared<opset8::Equal>(gather_shape, opset8::Constant::create(element::i64, Shape{1}, {1}));
auto if_op = make_shared<opset8::If>(equal);

auto input_then = make_shared<opset8::Parameter>(element::i32, PartialShape{-1, 1});

auto start = opset8::Constant::create(element::i32, Shape{1}, {3});
auto stop = opset8::Constant::create(element::i32, Shape{1}, {4});
auto step = opset8::Constant::create(element::i32, Shape{1}, {1});
auto slice = make_shared<opset8::Slice>(input_then, start, stop, step);

auto then_op_result = make_shared<op::v0::Result>(slice);
auto body_then_function = make_shared<Model>(NodeVector{then_op_result}, ParameterVector{input_then});

auto input_else = make_shared<opset8::Parameter>(element::i32, PartialShape{-1, 1});
auto reshape =
make_shared<opset8::Reshape>(input_else, opset8::Constant::create(element::i32, Shape{1}, {-1}), true);
auto else_op_result = make_shared<op::v0::Result>(reshape);
auto body_else_function = make_shared<Model>(NodeVector{else_op_result}, ParameterVector{input_else});

if_op->set_then_body(body_then_function);
if_op->set_else_body(body_else_function);
if_op->set_input(gather, input_then, input_else);

auto result_if = if_op->set_output(then_op_result, else_op_result);

auto begin = opset8::Constant::create(element::i32, Shape{1}, {3});
auto end = opset8::Constant::create(element::i32, Shape{1}, {4});
auto strides = opset8::Constant::create(element::i32, Shape{1}, {1});
auto ss_node =
make_shared<opset8::StridedSlice>(result_if, begin, end, strides, vector<int64_t>{1, 0}, vector<int64_t>{1, 0});

auto data = make_shared<op::v0::Parameter>(element::f32, PartialShape{-1});
auto axis = opset8::Constant::create(element::i32, Shape{1}, {0});
auto target_gather = make_shared<opset8::Gather>(data, ss_node, axis);

shared_ptr<Model> f = make_shared<Model>(NodeVector{target_gather}, ParameterVector{boxes, scores, data});

pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertNmsGatherPathToUnsigned>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));

const auto& ops = f->get_ops();
const auto& gather_it = find(ops.begin(), ops.end(), target_gather);
ASSERT_NE(gather_it, ops.end());

const auto& rti = (*gather_it)->get_rt_info();
const auto& reverse = rti.find("dontReverseIndices");
ASSERT_NE(reverse, rti.end());
}
Loading