Skip to content

Commit

Permalink
fix RIC tests, fix py tests
Browse files Browse the repository at this point in the history
  • Loading branch information
itikhono committed Feb 14, 2025
1 parent 53aa687 commit 0d34dae
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 39 deletions.
8 changes: 3 additions & 5 deletions src/bindings/python/tests/test_graph/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def test_graph_preprocess_steps(algorithm, color_format1, color_format2, is_fail
"Gather",
"Interpolate",
]
assert len(model_operators) == 15
assert len(model_operators) == 12
assert model.get_output_size() == 1
assert list(model.get_output_shape(0)) == [1, 3, 3, 3]
assert model.get_output_element_type(0) == Type.f32
Expand Down Expand Up @@ -459,10 +459,9 @@ def test_graph_preprocess_postprocess_layout():
"Constant",
"Result",
"Gather",
"Range",
"Transpose",
]
assert len(model_operators) == 14
assert len(model_operators) == 11
assert model.get_output_size() == 1
assert list(model.get_output_shape(0)) == [1, 1, 3, 3]
assert model.get_output_element_type(0) == Type.f32
Expand All @@ -489,9 +488,8 @@ def test_graph_preprocess_reverse_channels():
"Constant",
"Result",
"Gather",
"Range",
]
assert len(model_operators) == 10
assert len(model_operators) == 7
assert model.get_output_size() == 1
assert list(model.get_output_shape(0)) == [1, 2, 2, 2]
assert model.get_output_element_type(0) == Type.f32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/core/model.hpp"
#include "openvino/core/preprocess/pre_post_process.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/opsets/opset12.hpp"
#include "openvino/opsets/opset8.hpp"
#include "openvino/pass/constant_folding.hpp"
Expand Down Expand Up @@ -67,7 +68,7 @@ std::shared_ptr<GroupConvolution> create_group_conv_with_gather(Output<Node> inp
Constant::create(element::i64, Shape{order.size()}, order),
Constant::create(element::i64, Shape{1}, {0}));
return std::make_shared<GroupConvolution>(input,
gather,
ov::util::get_constant_from_source(gather),
ov::Strides{1, 1},
ov::CoordinateDiff{0, 0},
ov::CoordinateDiff{0, 0},
Expand All @@ -81,7 +82,7 @@ std::shared_ptr<Convolution> create_conv_with_gather(Output<Node> input,
Constant::create(element::i64, Shape{order.size()}, order),
Constant::create(element::i64, Shape{1}, {1}));
return std::make_shared<Convolution>(input,
gather,
ov::util::get_constant_from_source(gather),
ov::Strides{1, 1},
ov::CoordinateDiff{0, 0},
ov::CoordinateDiff{0, 0},
Expand Down Expand Up @@ -301,7 +302,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise1) {
{
auto input = create_param({1, 3, 64, 64});
auto gather = create_gather(Constant::create(element::f32, Shape{3, 1, 1}, {0.1, 0.2, 0.3}), {2, 1, 0}, 0);
auto add = std::make_shared<Add>(input, gather);
auto add = std::make_shared<Add>(input, ov::util::get_constant_from_source(gather));
auto conv = create_conv_with_gather(add, {6, 3, 3, 3}, {2, 1, 0});
model_ref = std::make_shared<Model>(NodeVector{conv}, ParameterVector{input});
}
Expand Down Expand Up @@ -371,7 +372,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise4) {
{
auto input = create_param({1, 3, 64, 64});
auto gather = create_gather(create_weights({3, 1, 1}), {2, 1, 0}, 0);
auto add = std::make_shared<Add>(gather, input);
auto add = std::make_shared<Add>(ov::util::get_constant_from_source(gather), input);
auto conv = create_conv_with_gather(add, {6, 3, 3, 3}, {2, 1, 0});
model_ref = std::make_shared<Model>(NodeVector{conv}, ParameterVector{input});
}
Expand All @@ -395,7 +396,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise5) {
{
auto input = create_param({1, 3, 64, 64});
auto gather = create_gather(create_weights({1, 3, 1, 1}), {2, 1, 0}, 1);
auto add = std::make_shared<Add>(gather, input);
auto add = std::make_shared<Add>(ov::util::get_constant_from_source(gather), input);
auto conv = create_conv_with_gather(add, {6, 3, 3, 3}, {2, 1, 0});
model_ref = std::make_shared<Model>(NodeVector{conv}, ParameterVector{input});
}
Expand Down Expand Up @@ -508,7 +509,7 @@ TEST_F(TransformationTestsF, RICFusionTranspose) {
{
auto input = create_param({1, 64, 64, 3});
auto gather = create_gather(create_weights({3}), {2, 1, 0}, 0);
auto add = std::make_shared<Add>(input, gather);
auto add = std::make_shared<Add>(input, ov::util::get_constant_from_source(gather));
auto transpose = std::make_shared<Transpose>(add, Constant::create(element::i64, Shape{4}, {0, 3, 1, 2}));
auto conv = create_conv_with_gather(transpose, {6, 3, 3, 3}, {2, 1, 0});
model_ref = std::make_shared<Model>(NodeVector{conv}, ParameterVector{input});
Expand All @@ -533,7 +534,8 @@ TEST_F(TransformationTestsF, RICFusionFQOnTheWay) {
{
auto input = create_param({1, 3, 64, 64});
auto fq = create_fq(input);
auto conv = create_conv(fq, create_fq(create_gather(create_weights({6, 3, 3, 3}), {2, 1, 0}, 1)));
auto weights = ov::util::get_constant_from_source(create_gather(create_weights({6, 3, 3, 3}), {2, 1, 0}, 1));
auto conv = create_conv(fq, create_fq(weights));

model_ref = std::make_shared<Model>(NodeVector{conv}, ParameterVector{input});
}
Expand Down Expand Up @@ -565,12 +567,13 @@ TEST_F(TransformationTestsF, RICFusionFQOnTheWay2) {
auto input = create_param({1, 3, 64, 64});
auto fq = create_fq(input);
auto weights_const = create_weights({6, 3, 3, 3});
auto fq_weights = std::make_shared<FakeQuantize>(create_gather(weights_const, {2, 1, 0}, 1),
create_gather(create_weights({1, 3, 1, 1}), {2, 1, 0}, 1),
create_weights({1, 1, 1}),
create_weights({1}),
create_gather(create_weights({3, 1, 1}), {2, 1, 0}, 0),
255);
auto fq_weights = std::make_shared<FakeQuantize>(
ov::util::get_constant_from_source(create_gather(weights_const, {2, 1, 0}, 1)),
ov::util::get_constant_from_source(create_gather(create_weights({1, 3, 1, 1}), {2, 1, 0}, 1)),
create_weights({1, 1, 1}),
create_weights({1}),
ov::util::get_constant_from_source(create_gather(create_weights({3, 1, 1}), {2, 1, 0}, 0)),
255);
auto conv = create_conv(fq, fq_weights);

model_ref = std::make_shared<Model>(NodeVector{conv}, ParameterVector{input});
Expand Down Expand Up @@ -604,12 +607,13 @@ TEST_F(TransformationTestsF, RICFusionFQOnTheWay3) {
auto input = create_param({1, 3, 64, 64});
auto fq = create_fq(input);
auto weights_const = create_weights({3, 1, 1, 3, 3});
auto fq_weights = std::make_shared<FakeQuantize>(create_gather(weights_const, {2, 1, 0}, 0),
create_gather(create_weights({3, 1, 1, 1, 1}), {2, 1, 0}, 0),
create_weights({1, 1, 1}),
create_weights({1}),
create_weights({1}),
255);
auto fq_weights = std::make_shared<FakeQuantize>(
ov::util::get_constant_from_source(create_gather(weights_const, {2, 1, 0}, 0)),
ov::util::get_constant_from_source(create_gather(create_weights({3, 1, 1, 1, 1}), {2, 1, 0}, 0)),
create_weights({1, 1, 1}),
create_weights({1}),
create_weights({1}),
255);
auto gconv = create_group_conv(fq, fq_weights);
auto conv = create_conv_with_gather(gconv, {6, 3, 1, 1}, {2, 1, 0});

Expand Down Expand Up @@ -893,7 +897,7 @@ TEST_F(TransformationTestsF, RICFusionConvertMultiply) {
std::shared_ptr<Node> weights = opset8::Constant::create(element::i8, Shape{4, 3, 1, 1}, {-2});
{
auto scale = opset8::Constant::create(element::f32, Shape{}, {0.2});
auto gather = create_gather(weights, {2, 1, 0}, 1);
auto gather = ov::util::get_constant_from_source(create_gather(weights, {2, 1, 0}, 1));
auto convert = std::make_shared<opset8::Convert>(gather, element::f32);
auto multiply = std::make_shared<opset8::Multiply>(convert, scale);
weights = multiply;
Expand Down Expand Up @@ -937,7 +941,7 @@ TEST_F(TransformationTestsF, RICFusionConvertMultiplyGroupConv) {
auto gather = create_gather(weights, {2, 1, 0}, 1);
auto convert = std::make_shared<opset8::Convert>(gather, element::f32);
auto scale = opset8::Constant::create(element::f32, Shape{}, {0.2});
auto multiply = std::make_shared<opset8::Multiply>(convert, scale);
auto multiply = ov::util::get_constant_from_source(std::make_shared<opset8::Multiply>(convert, scale));

auto group_conv = std::make_shared<opset8::GroupConvolution>(data,
multiply,
Expand All @@ -948,7 +952,7 @@ TEST_F(TransformationTestsF, RICFusionConvertMultiplyGroupConv) {
op::PadType::EXPLICIT);
auto relu = std::make_shared<Relu>(group_conv);
std::shared_ptr<Node> weights2 = opset8::Constant::create(element::f32, Shape{6, 9, 3, 3}, {-2});
auto gather2 = create_gather(weights2, {6, 7, 8, 3, 4, 5, 0, 1, 2}, 1);
auto gather2 = ov::util::get_constant_from_source(create_gather(weights2, {6, 7, 8, 3, 4, 5, 0, 1, 2}, 1));
auto conv = std::make_shared<opset8::Convolution>(relu,
gather2,
ov::Strides{1, 1},
Expand Down Expand Up @@ -1015,7 +1019,7 @@ TEST_F(TransformationTestsF, RICFusionConvertMultiplyNegative1) {

std::shared_ptr<Node> weights = opset8::Constant::create(element::i8, Shape{4, 3, 1, 1}, {-2});
{
auto gather = create_gather(weights, {2, 1, 0}, 1);
auto gather = ov::util::get_constant_from_source(create_gather(weights, {2, 1, 0}, 1));
auto convert = std::make_shared<opset8::Convert>(gather, element::f32);
auto scale = opset8::Constant::create(element::f32, Shape{1, 1, 1, 1}, {0.2});
auto multiply = std::make_shared<opset8::Multiply>(convert, scale);
Expand Down Expand Up @@ -1088,10 +1092,10 @@ TEST_F(TransformationTestsF, RICFusionConvertMultiplyNegativeBroadcast) {

std::shared_ptr<Node> weights = opset8::Constant::create(element::i8, Shape{3, 1, 1}, {-2});
{
auto gather = create_gather(weights, {2, 1, 0}, 0);
auto gather = ov::util::get_constant_from_source(create_gather(weights, {2, 1, 0}, 0));
auto convert = std::make_shared<opset8::Convert>(gather, element::f32);
auto scale = opset8::Constant::create(element::f32, Shape{4, 3, 1, 1}, {0.2});
auto gather2 = create_gather(scale, {2, 1, 0}, 1);
auto gather2 = ov::util::get_constant_from_source(create_gather(scale, {2, 1, 0}, 1));
auto multiply = std::make_shared<opset8::Multiply>(convert, gather2);
weights = multiply;
}
Expand Down Expand Up @@ -1174,7 +1178,7 @@ TEST_F(TransformationTestsF, RICFusionConvertMultiplyNonScalarFQInput) {
create_gather(std::make_shared<opset8::Constant>(element::f32, Shape{1, 3, 14, 14}), {2, 1, 0}, 1);
std::shared_ptr<Node> activations =
std::make_shared<opset8::FakeQuantize>(parameter,
gather,
ov::util::get_constant_from_source(gather),
opset8::Constant::create(element::f32, Shape{}, {20}),
opset8::Constant::create(element::f32, Shape{}, {0}),
opset8::Constant::create(element::f32, Shape{}, {254}),
Expand All @@ -1190,7 +1194,7 @@ TEST_F(TransformationTestsF, RICFusionConvertMultiplyNonScalarFQInput) {
std::shared_ptr<Node> weights = opset8::Constant::create(element::i8, Shape{4, 3, 1, 1}, {-2});
{
auto scale = opset8::Constant::create(element::f32, Shape{}, {0.2});
gather = create_gather(weights, {2, 1, 0}, 1);
auto gather = ov::util::get_constant_from_source(create_gather(weights, {2, 1, 0}, 1));
auto convert = std::make_shared<opset8::Convert>(gather, element::f32);
auto multiply = std::make_shared<opset8::Multiply>(convert, scale);
weights = multiply;
Expand Down
4 changes: 3 additions & 1 deletion src/core/src/pass/constant_folding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include "openvino/op/util/read_value_base.hpp"
#include "openvino/op/util/shape_of_base.hpp"
#include "openvino/op/util/sub_graph_base.hpp"
#include "transformations/rt_info/decompression.hpp"
#include "transformations/rt_info/dequantization_node.hpp"

/**
* \brief Check if \ref ov::Output<ov::Node> can be folded based on the `can_be_folded` attribute.
Expand Down Expand Up @@ -51,7 +53,7 @@ const auto friendly_name_from = [](const ov::Node& node, const size_t output_cou

static bool restore_original_input_precision(const std::shared_ptr<ov::Node>& node) {
bool restored = false;
if (ov::is_type<ov::op::v0::Convert>(node)) {
if (ov::is_type<ov::op::v0::Convert>(node) && !is_decompression(node) && !is_dequantization_node(node)) {
auto input = node->input(0);
ov::util::remove_original_input_precision_attribute(input);
return restored;
Expand Down
8 changes: 6 additions & 2 deletions src/core/src/preprocess/pre_post_process.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@
#include "itt.hpp"
#include "layout_utils.hpp"
#include "openvino/core/model.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "openvino/pass/manager.hpp"
#include "preprocess_impls.hpp"
#include "transformations/common_optimizations/convolution_to_group_convolution_fusion.hpp"
#include "transformations/common_optimizations/disable_shapeof_constant_folding.hpp"
#include "transformations/common_optimizations/disable_random_uniform_constant_folding.hpp"
#include "transformations/common_optimizations/disable_shapeof_constant_folding.hpp"
#include "transformations/common_optimizations/mul_conv_fusion.hpp"
#include "transformations/common_optimizations/ric_fusion.hpp"
#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
#include "transformations/low_precision/mark_dequantization_subgraph.hpp"
#include "transformations/op_conversions/convert_divide.hpp"
#include "openvino/pass/constant_folding.hpp"

namespace {

Expand All @@ -31,6 +32,9 @@ void transformation_pipeline(std::shared_ptr<ov::Model>& model) {
manager.register_pass<MarkDequantization>(TypeVector{i8, u8, i4, u4, nf4});
REGISTER_PASS(manager, DisableShapeOfConstantFolding);
REGISTER_PASS(manager, DisableRandomUniformConstantFolding)
// Mark quantized and f16/bf16 compressed constants to prevent constant folding (CF) for them,
// so that no extra memory is used for intermediate decompressed constants.
manager.register_pass<ov::pass::MarkCompressedFloatConstants>();

REGISTER_PASS(manager, ConvertDivideWithConstant)

Expand Down
42 changes: 38 additions & 4 deletions src/core/tests/preprocess.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,16 @@ static std::shared_ptr<Model> create_trivial(element::Type type, const PartialSh
return std::make_shared<Model>(ResultVector{res}, ParameterVector{data1});
}

static std::shared_ptr<Model> create_conv(element::Type type, const PartialShape& shape) {
auto data1 = std::make_shared<op::v0::Parameter>(type, shape);
static std::shared_ptr<Model> create_conv(element::Type in_type, const PartialShape& shape, element::Type weight_type) {
auto data1 = std::make_shared<op::v0::Parameter>(in_type, shape);
data1->set_friendly_name("input1");
data1->get_output_tensor(0).set_names({"tensor_input1"});

auto weights = std::make_shared<op::v0::Constant>(type, ov::Shape{1, 3, 3, 3}, 1);
std::shared_ptr<Node> weights = std::make_shared<op::v0::Constant>(weight_type, ov::Shape{1, 3, 3, 3}, 1);
if (weight_type == element::f16) {
// decompression subgraph
weights = std::make_shared<op::v0::Convert>(weights, element::f32);
}
auto conv =
std::make_shared<op::v1::Convolution>(data1, weights, Strides{}, CoordinateDiff{}, CoordinateDiff{}, Strides{});
auto res = std::make_shared<op::v0::Result>(conv);
Expand Down Expand Up @@ -2445,8 +2449,9 @@ TEST(pre_post_process, dump_error) {
TEST_F(TransformationTestsF, preprocessing_mul_conv_fusion) {
auto in_shape = Shape{1, 3, 32, 32};
auto in_type = element::f32;
auto weight_type = element::f32;
{
auto f = create_conv(in_type, in_shape);
auto f = create_conv(in_type, in_shape, weight_type);
auto p = PrePostProcessor(f);

p.input().tensor().set_layout(Layout("NCHW"));
Expand All @@ -2470,3 +2475,32 @@ TEST_F(TransformationTestsF, preprocessing_mul_conv_fusion) {
model_ref = std::make_shared<ov::Model>(ResultVector{res}, ParameterVector{input});
}
}

// Builds a convolution model whose f16 weights feed a Convert to f32, applies
// reverse_channels + scale(255) preprocessing to the NCHW input, and compares the
// built model against a hand-written reference graph.
TEST_F(TransformationTestsF, preprocessing_conv_decompression) {
    const Shape data_shape{1, 3, 32, 32};
    const element::Type data_type = element::f32;
    const element::Type weights_type = element::f16;

    // Model under test: preprocessing steps are attached to the single input.
    {
        PrePostProcessor ppp(create_conv(data_type, data_shape, weights_type));

        ppp.input().tensor().set_layout(Layout("NCHW"));
        ppp.input().preprocess().reverse_channels().scale(255.);
        model = ppp.build();
    }

    // Reference: we expect that MultiplyConvolutionFusion will be applied, so the
    // scale shows up as a Multiply on the Convert output of the f16 weights rather
    // than as an operation on the input data.
    {
        auto data = std::make_shared<op::v0::Parameter>(data_type, data_shape);

        auto f16_weights = op::v0::Constant::create(weights_type, ov::Shape({1, 3, 3, 3}), {1.f});
        auto decompress = std::make_shared<op::v0::Convert>(f16_weights, element::f32);
        auto scale_const = op::v0::Constant::create(data_type, ov::Shape({1}), {0.003922f});
        auto scaled_weights = std::make_shared<op::v1::Multiply>(decompress, scale_const);
        auto conv = std::make_shared<op::v1::Convolution>(data,
                                                          scaled_weights,
                                                          Strides{},
                                                          CoordinateDiff{},
                                                          CoordinateDiff{},
                                                          Strides{});
        model_ref = std::make_shared<ov::Model>(ResultVector{std::make_shared<op::v0::Result>(conv)},
                                                ParameterVector{data});
    }
}

0 comments on commit 0d34dae

Please sign in to comment.