Skip to content

Commit

Permalink
store/restore rt info before/after transformations in the prepostproc…
Browse files Browse the repository at this point in the history
…essing
  • Loading branch information
itikhono committed Feb 17, 2025
1 parent f3c26ff commit 8073741
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 46 deletions.

This file was deleted.

48 changes: 42 additions & 6 deletions src/core/src/preprocess/pre_post_process.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,58 @@
#include "transformations/common_optimizations/disable_shapeof_constant_folding.hpp"
#include "transformations/common_optimizations/mul_conv_fusion.hpp"
#include "transformations/common_optimizations/ric_fusion.hpp"
#include "transformations/common_optimizations/rt_info_cleanup.hpp"
#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
#include "transformations/init_node_info.hpp"
#include "transformations/low_precision/mark_dequantization_subgraph.hpp"
#include "transformations/op_conversions/convert_divide.hpp"
#include "transformations/utils/utils.hpp"

namespace {

struct RTInfoCache {
template <typename Func>
void traverse(const std::shared_ptr<ov::Model>& model, Func&& func) {
for (const auto& op : model->get_ordered_ops()) {
func(op);
if (const auto& multi_subgraph_op = ov::as_type_ptr<ov::op::util::MultiSubGraphOp>(op)) {
for (const auto& sub_graph : multi_subgraph_op->get_functions()) {
if (sub_graph)
traverse(sub_graph, func);
}
}
}
}

void store(const std::shared_ptr<ov::Model>& model) {
traverse(model, [this](const std::shared_ptr<ov::Node>& op) {
auto& rt_info = op->get_rt_info();
m_rt_info_cache[op.get()] = rt_info;
rt_info.clear();
});
}

void restore(const std::shared_ptr<ov::Model>& model) {
traverse(model, [this](const std::shared_ptr<ov::Node>& op) {
auto it = m_rt_info_cache.find(op.get());
if (it != m_rt_info_cache.end()) {
op->get_rt_info() = it->second;
}
});
}

std::unordered_map<ov::Node*, ov::RTMap> m_rt_info_cache;
};

void transformation_pipeline(std::shared_ptr<ov::Model>& model) {
using namespace ov;
using namespace ov::pass;
using namespace ov::element;

// 0. Store RT info to not affect plugin compilation
RTInfoCache rt_info_cache;
rt_info_cache.store(model);

Manager manager("pre_post_processing");
manager.set_per_pass_validation(false);
REGISTER_PASS(manager, InitNodeInfo)

// 1. Set "disable_const_folding" attribute
REGISTER_PASS(manager, MarkDequantization, TypeVector{i8, u8, i4, u4, nf4, f4e2m1, f8e4m3, f8e5m2, f8e8m0});
Expand All @@ -55,10 +91,10 @@ void transformation_pipeline(std::shared_ptr<ov::Model>& model) {

// 3. CF call due to detected perf degradations
REGISTER_PASS(manager, ConstantFolding)

// 4. RT info cleanup to not affect plugin compilation
REGISTER_PASS(manager, RTInfoCleanup);
manager.run_passes(model);

// 4. Restore old RT info to not affect plugin compilation
rt_info_cache.restore(model);
}

} // namespace
Expand Down

0 comments on commit 8073741

Please sign in to comment.