Skip to content

Commit

Permalink
[LPT] integration branch: Reshape fix, Concat generalization, runtime…
Browse files Browse the repository at this point in the history
… info usage extending (openvinotoolkit#2930)

* [LPT] Concat transformation generalization

* [LPT] Reshape transformation fix

* [LPT] Legacy callback fix

* [LPT] * added rt_info propagation
      * functional tests: added rt_info
      * functional tests: added MoveDequantizationAfter tests

Co-authored-by: Vladislav Golubev <[email protected]>
  • Loading branch information
2 people authored and mryzhov committed Dec 15, 2020
1 parent 8432449 commit e476d35
Show file tree
Hide file tree
Showing 36 changed files with 1,037 additions and 153 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,6 @@ class TRANSFORMATIONS_API NetworkHelper {
// handles only specific case: Constant -> [dequantization operations] -> [node]
static void foldDequantization(std::shared_ptr<Node>& node, const size_t branchIndex, const bool inPlace = false);

static std::shared_ptr<Node> markAsDequantizationOp(std::shared_ptr<Node> op);

private:
static std::shared_ptr<Node> foldFakeQuantize(const std::shared_ptr<opset1::FakeQuantize>& fq, const bool roundValues, const bool roundValuesWasSet);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ bool AddTransformation::transform(TransformationContext& context, ngraph::patter
}

newMultiply = NetworkHelper::swapMultiplyAndAdd(add, multiplyBranch.first);

ngraph::copy_runtime_info({ add, newMultiply }, newMultiply);
if (is_type<opset1::Add>(newMultiply->get_input_node_shared_ptr(0))) {
newAddOrSubtract = newMultiply->get_input_node_shared_ptr(0);

Expand Down Expand Up @@ -186,6 +186,7 @@ bool AddTransformation::transform(TransformationContext& context, ngraph::patter

replace_node(add, newMultiply);
NetworkHelper::copyInfo(add, newAddOrSubtract);
ngraph::copy_runtime_info({ add, newMultiply }, newMultiply);
}

updateOutput(context, newMultiply, newAddOrSubtract);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,18 +261,15 @@ void ConcatTransformation::addDequantizationLayers(

if (layerDequantizations.size() > 1ul) {
auto broadcastElementWiseConst = [](
// FakeQuantize constant shape must be broadcastable to the shape on data.
std::shared_ptr<ngraph::opset1::Constant> operation,
const ngraph::Shape targetShape) -> std::shared_ptr<Node> {
auto unsqueeze = ngraph::pass::low_precision::fold<ngraph::opset1::Unsqueeze>(
operation->shared_from_this(),
std::make_shared<ngraph::opset1::Constant>(element::i64, ngraph::Shape{ 4 }, std::vector<size_t>{ 0, 1, 2, 3 }));

auto targetShapeConst = std::make_shared<ngraph::opset1::Constant>(
element::i64, ngraph::Shape{ targetShape.size() },
targetShape);

auto broadcast = ngraph::pass::low_precision::fold<ngraph::opset1::Broadcast>(
unsqueeze,
operation,
targetShapeConst,
ngraph::op::AutoBroadcastType::NUMPY);

Expand Down Expand Up @@ -342,6 +339,7 @@ void ConcatTransformation::addDequantizationLayers(
std::shared_ptr<ngraph::Node> convert =
convertNodes[0]->clone_with_new_inputs({ destination->get_input_source_output(sourceOutputIdx) });
insert_new_node_between(source, destination, convert);
ngraph::copy_runtime_info({ layer, convert }, convert);
source = convert;
}

Expand All @@ -354,6 +352,7 @@ void ConcatTransformation::addDequantizationLayers(
subtractNodes[0] :
ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(subtractNodes, 1)));
insert_new_node_between(source, destination, subtract);
ngraph::copy_runtime_info({ layer, subtract }, subtract);
source = subtract;
}

Expand All @@ -365,6 +364,7 @@ void ConcatTransformation::addDequantizationLayers(
multiplyNodes[0] :
ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(multiplyNodes, 1)));
insert_new_node_between(source, destination, multiply);
ngraph::copy_runtime_info({ layer, multiply }, multiply);
source = multiply;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph

std::shared_ptr<ngraph::opset1::Multiply> finalDequantization = NetworkHelper::optimizeMultipliesAfter(
convolution->output(0).get_target_inputs().begin()->get_node()->shared_from_this());

ngraph::copy_runtime_info({ convolution, finalDequantization }, finalDequantization);
updateOutput(context, finalDequantization, convolution);

auto onWeights = convolution->get_input_node_shared_ptr(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,9 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
fakeQuantize->input_value(4) }));

replace_node(fakeQuantize, newFakeQuantize);
NetworkHelper::copyInfo(fakeQuantize, newFakeQuantize);

ngraph::copy_runtime_info({ fakeQuantize, eltwise }, newFakeQuantize);
newFakeQuantize->set_friendly_name(fakeQuantize->get_friendly_name());
NetworkHelper::cleanRunTimeInfo(newFakeQuantize);
return newFakeQuantize;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ bool FuseConvertTransformation::transform(TransformationContext& context, ngraph
}

if (newOp != nullptr) {
NetworkHelper::copyInfo(op, newOp);
ngraph::copy_runtime_info({ convert, op }, newOp);
newOp->set_friendly_name(op->get_friendly_name());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
matMul->get_transpose_a(),
matMul->get_transpose_b());
NetworkHelper::setOutDataPrecisionForTypeRelaxed(newMatMul, matMul->get_output_element_type(0));
NetworkHelper::copyInfo(matMul, newMatMul);

auto transpose = [](const std::shared_ptr<Node>& node) -> std::shared_ptr<Node> {
const Shape outputShape = node->get_output_shape(0);
Expand Down Expand Up @@ -95,6 +96,7 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat
NetworkHelper::toScalar(as_type_ptr<opset1::Constant>(const1)),
const2)));
replace_node(matMul, newMultiply);
ngraph::copy_runtime_info({ newMultiply, matMul }, newMultiply);

updateOutput(context, newMultiply, matMul);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,10 @@ bool MVNTransformation::transform(TransformationContext &context, ngraph::patter
mvn->get_normalize_variance(),
mvn->get_eps()),
type);
NetworkHelper::copyInfo(mvn, newMVN);

auto newMultiply = std::make_shared<DequantizationMultiply>(newMVN, newScalesConst);
newMVN->set_friendly_name(mvn->get_friendly_name());
ngraph::copy_runtime_info({ mvn, newMultiply }, newMultiply);

replace_node(mvn, newMultiply);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> NetworkHelper::decompos
fq->get_levels(),
fq->get_auto_broadcast()),
true);
newFQ->set_friendly_name(fq->get_friendly_name());
NetworkHelper::copyInfo(fq, newFQ);

std::shared_ptr<ngraph::Node> convert2;
if (updatePrecision) {
Expand All @@ -650,10 +650,12 @@ std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> NetworkHelper::decompos

convert2 = std::make_shared<DequantizationConvert>(convert, element::f32);
convert2->set_friendly_name(convert->get_friendly_name() + "/DequantizationConvert");
ngraph::copy_runtime_info({ newFQ, convert2 }, convert2);
} else {
if (newFQ->get_output_element_type(0) != element::f32) {
convert2 = std::make_shared<DequantizationConvert>(newFQ, element::f32);
convert2->set_friendly_name(newFQ->get_friendly_name() + "/DequantizationConvert");
ngraph::copy_runtime_info({ newFQ, convert2 }, convert2);
}
}

Expand All @@ -663,12 +665,14 @@ std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> NetworkHelper::decompos
std::make_shared<ngraph::op::TypeRelaxed<DequantizationSubtract>>(convert2 == nullptr ? newFQ : convert2, shift);
if (sub != nullptr) {
sub->set_friendly_name(newFQ->get_friendly_name() + "/DequantizationSubtract");
ngraph::copy_runtime_info({ newFQ, sub }, sub);
}

const std::shared_ptr<ngraph::opset1::Multiply> dequantize = std::make_shared<DequantizationMultiply>(
sub == nullptr ? (convert2 == nullptr ? newFQ : convert2) : sub,
scale);
dequantize->set_friendly_name(newFQ->get_friendly_name() + "/DequantizationMultiply");
ngraph::copy_runtime_info({ newFQ, dequantize }, dequantize);

replace_node(fq, dequantize);

Expand Down Expand Up @@ -929,7 +933,7 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter

const std::shared_ptr<ngraph::Node> newOperation = operation->clone_with_new_inputs(inputs);
newOperation->set_friendly_name(operation->get_friendly_name());
// copyInfo(operation, newOperation);
ngraph::copy_runtime_info(operation, newOperation);

if (updatePrecision) {
auto op = std::dynamic_pointer_cast<ngraph::op::TypeRelaxedBase>(newOperation);
Expand All @@ -945,18 +949,22 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
auto parent = newOperation;
if (shouldConvert) {
parent = std::make_shared<DequantizationConvert>(parent, dequantization.convert->get_output_element_type(0));
ngraph::copy_runtime_info({ newOperation, parent }, parent);
}
if (moveSubtract && (dequantization.subtract != nullptr)) {
auto subtractConstant = dequantization.subtract->get_input_node_shared_ptr(1);
parent = std::make_shared<DequantizationSubtract>(parent, subtractConstant);
ngraph::copy_runtime_info({ newOperation, parent }, parent);
}
if (dequantization.multiply != nullptr) {
auto multiplyConstant = dequantization.multiply->get_input_node_shared_ptr(1);
parent = std::make_shared<DequantizationMultiply>(parent, multiplyConstant);
ngraph::copy_runtime_info({ newOperation, parent }, parent);
}
replace_node(operation, parent);

if ((!moveSubtract) && (dequantization.convert != nullptr) && (dequantization.subtract != nullptr)) {
NetworkHelper::cleanRunTimeInfo(dequantization.subtract);
optimizeSubtract(dequantization.subtract);
}

Expand Down Expand Up @@ -1036,13 +1044,6 @@ std::shared_ptr<Node> NetworkHelper::toScalarIfPossible(std::shared_ptr<Node> no
return NetworkHelper::toScalar(constant);
}

// Returns a copy of `op` tagged as a dequantization operation.
//
// The node is cloned with its own current input values, and the
// "DEQUANTIZATION" runtime-info entry is then set on the clone (not on `op`).
// The caller receives the clone; nothing here replaces `op` in the graph.
std::shared_ptr<Node> NetworkHelper::markAsDequantizationOp(std::shared_ptr<Node> op) {
auto opCopy = op->clone_with_new_inputs(op->input_values());
// get_rt_info() is taken by reference so the assignment below modifies the clone's own map.
auto& rtInfo = opCopy->get_rt_info();
rtInfo["DEQUANTIZATION"] = std::make_shared<VariantWrapper<DequantizationAttr>>(DequantizationAttr());
return opCopy;
}

} // namespace low_precision
} // namespace pass
} // namespace ngraph
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ bool NormalizeL2Transformation::transform(TransformationContext &context, ngraph
ngraph::op::TemporaryReplaceOutputType(newScalesConst, element::f32).get());

replace_node(normalize, newMultiply);
ngraph::copy_runtime_info({ normalize, newMultiply }, newMultiply);

updateOutput(context, newMultiply, normalize);
return true;
Expand Down
Loading

0 comments on commit e476d35

Please sign in to comment.