Skip to content

Commit

Permalink
Cleanup as a separate matcher to cover Loop/TI/If cases
Browse files Browse the repository at this point in the history
  • Loading branch information
itikhono committed Feb 17, 2025
1 parent 95e292b commit 37463e7
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ TRANSFORMATIONS_API void mark_as_dequantization_node(const std::shared_ptr<Node>

/**
 * @brief Checks whether the node's runtime info contains the DequantizationNode attribute.
 * @param node Node whose runtime info is inspected.
 * @return true if the attribute is found in the node's runtime info, false otherwise.
 */
TRANSFORMATIONS_API bool is_dequantization_node(const std::shared_ptr<const Node>& node);

/**
 * @brief Erases the DequantizationNode attribute entry from the node's runtime info.
 * @param node Node whose runtime info is modified.
 */
TRANSFORMATIONS_API void unmark_dequantization_node(const std::shared_ptr<Node>& node);

/**
* @ingroup ov_runtime_attr_api
* @brief DequantizationNode class represents runtime info attribute that marks operation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ void ov::mark_as_dequantization_node(const std::shared_ptr<Node>& node) {
rt_info[DequantizationNode::get_type_info_static()] = DequantizationNode();
}

// Drops the DequantizationNode entry from the runtime info of `node`.
void ov::unmark_dequantization_node(const std::shared_ptr<Node>& node) {
    auto& rt_info = node->get_rt_info();
    rt_info.erase(DequantizationNode::get_type_info_static());
}

bool ov::is_dequantization_node(const std::shared_ptr<const Node>& node) {
const auto& rt_info = node->get_rt_info();
return rt_info.find(DequantizationNode::get_type_info_static()) != rt_info.end();
Expand Down
30 changes: 25 additions & 5 deletions src/core/src/preprocess/pre_post_process.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,49 @@
#include "openvino/core/model.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/pattern/op/true.hpp"
#include "preprocess_impls.hpp"
#include "transformations/common_optimizations/convolution_to_group_convolution_fusion.hpp"
#include "transformations/common_optimizations/disable_random_uniform_constant_folding.hpp"
#include "transformations/common_optimizations/disable_shapeof_constant_folding.hpp"
#include "transformations/common_optimizations/mul_conv_fusion.hpp"
#include "transformations/common_optimizations/ric_fusion.hpp"
#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
#include "transformations/init_node_info.hpp"
#include "transformations/low_precision/mark_dequantization_subgraph.hpp"
#include "transformations/op_conversions/convert_divide.hpp"
#include "transformations/rt_info/dequantization_node.hpp"

namespace {

class RTInfoCleanup : public ov::pass::MatcherPass {
public:
OPENVINO_MATCHER_PASS_RTTI("RTInfoCleanup");
explicit RTInfoCleanup() {
auto any_op = std::make_shared<ov::pass::pattern::op::True>();
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
auto root = m.get_match_root();
ov::pass::enable_constant_folding(root);
unmark_dequantization_node(root);
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(any_op, "RTInfoCleanup");
this->register_matcher(m, callback);
}
};

void transformation_pipeline(std::shared_ptr<ov::Model>& model) {
using namespace ov;
using namespace ov::pass;
using namespace ov::element;

ov::pass::Manager manager("pre_post_processing");
manager.set_per_pass_validation(false);
REGISTER_PASS(manager, InitNodeInfo)

// 1. Set "disable_const_folding" attribute
manager.register_pass<MarkDequantization>(TypeVector{i8, u8, i4, u4, nf4, f4e2m1, f8e4m3, f8e5m2, f8e8m0});
// manager.register_pass<MarkDequantization>(TypeVector{i8, u8, i4, u4, nf4, f4e2m1, f8e4m3, f8e5m2, f8e8m0});
REGISTER_PASS(manager, DisableShapeOfConstantFolding);
REGISTER_PASS(manager, DisableRandomUniformConstantFolding)
// Mark quantized and f16/bf16 compressed constants to prevent CF for them,
Expand All @@ -50,12 +72,10 @@ void transformation_pipeline(std::shared_ptr<ov::Model>& model) {

// 3. CF call due to detected perf degradations
REGISTER_PASS(manager, ConstantFolding)
manager.run_passes(model);

// 4. RT info cleanup to not affect plugin compilation
for (const auto& op : model->get_ops()) {
enable_constant_folding(op);
}
REGISTER_PASS(manager, RTInfoCleanup);
manager.run_passes(model);
}

} // namespace
Expand Down

0 comments on commit 37463e7

Please sign in to comment.