[PT FE] Add ModuleExtension (openvinotoolkit#23536)
### Details:
 - *Continuation of openvinotoolkit#22867*

### Tickets:
 - *CVS-133733*

---------

Co-authored-by: Sergey Lyalin <[email protected]>
2 people authored and alvoron committed Apr 29, 2024
1 parent 0776362 commit b8c96cb
Showing 10 changed files with 231 additions and 29 deletions.
@@ -13,6 +13,7 @@
from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
from openvino.frontend.pytorch.py_pytorch_frontend import ConversionExtensionPytorch as ConversionExtension
from openvino.frontend.pytorch.py_pytorch_frontend import OpExtensionPytorch as OpExtension
from openvino.frontend.pytorch.module_extension import ModuleExtension
except ImportError as err:
raise ImportError("OpenVINO PyTorch frontend is not available, please make sure the frontend is built."
"{}".format(err))
1 change: 0 additions & 1 deletion src/bindings/python/src/openvino/frontend/pytorch/gptq.py
@@ -1,4 +1,3 @@

# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

39 changes: 39 additions & 0 deletions src/bindings/python/src/openvino/frontend/pytorch/module_extension.py
@@ -0,0 +1,39 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# flake8: noqa
# mypy: ignore-errors

class ModuleExtension:
def __init__(self, module, target_op, evaluate=None, convert=None):
"""
Creates an extension that replaces an entire PyTorch module with a single operation.
This functionality works with PyTorch models only. A module can be identified by
module type (e.g. torch.nn.Linear), by module instance in the model, or by module name.
Args:
module (str, torch.nn.Module, type(torch.nn.Module)): PyTorch module to replace.
target_op (str): a target operation that will be used as a replacement for the module;
can be the name of an extension operation or of an existing PyTorch operation
(with a prim:: or aten:: prefix, following TorchScript syntax).
evaluate (callable with args module, *args, **kwargs): a callable that replaces the target
module during model execution; it is responsible for producing valid output for
the module so that the model can be traced correctly. By default it calls the original
module's forward with the same arguments. The provided code will not be part of the
final traced model; it is used only to produce valid results during tracing.
convert (callable with args module, target_op, *args, **kwargs): a callable that will be traced and
become part of the final model instead of the target module. It accepts target_op as
its second parameter; target_op is a callable that will appear as a single node in the
graph, and the type of that node is the target_op provided as the argument above.
"""
self.module = module
self.target_op = target_op
self.evaluate = evaluate
if self.evaluate is None:
self.evaluate = lambda module, *args, **kwargs: module(*args, **kwargs)
self.convert = convert
if self.convert is None:
self.convert = lambda module, target_op, *args, **kwargs: target_op(*args, **kwargs)
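For orientation, a minimal usage sketch of the class above; the MyCos module is hypothetical, while the ModuleExtension arguments and the convert_model call mirror test_module_extension added later in this commit:

import torch
from openvino import convert_model
from openvino.frontend.pytorch import ModuleExtension

class MyCos(torch.nn.Module):  # hypothetical module to be replaced
    def forward(self, x):
        return torch.cos(x)

model = torch.nn.Sequential(MyCos())
# Each MyCos instance appears in the converted model as a single aten::sin node
# instead of the ops from its own forward.
ov_model = convert_model(model,
                         example_input=(torch.randn(100),),
                         extension=[ModuleExtension(MyCos, "aten::sin")])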
78 changes: 78 additions & 0 deletions src/bindings/python/src/openvino/frontend/pytorch/patch_model.py
@@ -0,0 +1,78 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# flake8: noqa
# mypy: ignore-errors

import torch


class no_jit_trace:
def __enter__(self):
self.state = torch._C._get_tracing_state()
torch._C._set_tracing_state(None)

def __exit__(self, *args):
torch._C._set_tracing_state(self.state)
self.state = None
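In short, anything executed under no_jit_trace is invisible to an in-progress torch.jit.trace; that is how forward below runs user-supplied evaluate code without recording it. A tiny sketch (the called function is hypothetical):

with no_jit_trace():
    results = run_python_only_logic()  # hypothetical call; not recorded in the trace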


def patch_model(model, module_extensions, orig_forward_name):
for name, m in model.named_modules():
if hasattr(m, orig_forward_name):
# already patched, skipping with a warning because it is unexpected
print(f'[ WARNING ] Unexpectedly found an already patched module {name} while applying ModuleExtension during PyTorch model conversion. '
'The result of the conversion may be broken. Depending on the exact issue, it may also leave the original model broken.')
continue
extension = None
if m in module_extensions:
extension = module_extensions[m]
elif m.__class__ in module_extensions:
extension = module_extensions[m.__class__]
elif name in module_extensions:
extension = module_extensions[name]

if extension:
# A new Trampoline class is defined for every module replacement, so class attributes can be used to store per-module state.
class Trampoline(torch.autograd.Function):
target_extension = extension
original_module = m
stashed_args = None
stashed_kwargs = None

@staticmethod
@torch.jit.ignore
def forward(*args, **kwargs):
with no_jit_trace():
# `module` is going to be passed to a user-defined function `evaluate`.
# `module` is patched: its forward function was replaced, and we are currently inside that patched function.
# If we passed `module` as-is to the user code below and it happened to call forward, it would lead to infinite recursion or failure,
# so we temporarily patch the module back to the original forward and then restore the patch afterwards.
# stash the current forward to be able to restore it later
patched_forward = m.forward
# set original forward for the module
m.forward = getattr(m, orig_forward_name)
# call user code
results = extension.evaluate(
m, *Trampoline.stashed_args, **Trampoline.stashed_kwargs)
m.forward = patched_forward # restore the patched forward
return results

def new_forward(*args, **kwargs):
Trampoline.stashed_args = args
Trampoline.stashed_kwargs = kwargs
return extension.convert(m, Trampoline.apply, *args, **kwargs)
setattr(m, orig_forward_name, m.forward)
m.forward = new_forward


def unpatch_model(model, orig_forward_name):
for _, m in model.named_modules():
if hasattr(m, orig_forward_name):
try:
m.forward = getattr(m, orig_forward_name)
delattr(m, orig_forward_name)
except Exception as error:
print('[ WARNING ] Exception raised during model unpatching. Depending on the exact issue, it may leave the original model broken.')
print('Original exception details:')
print(error)
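A sketch of the patch/unpatch round trip as ts_decoder.py drives it below; the model and extension here are placeholders, and the dict keys follow the lookup order above (instance, then class, then name):

import torch
from openvino.frontend.pytorch import patch_model
from openvino.frontend.pytorch.module_extension import ModuleExtension

model = torch.nn.Sequential(torch.nn.Linear(8, 8))
extensions = {torch.nn.Linear: ModuleExtension(torch.nn.Linear, "aten::relu")}
orig_name = '_openvino_module_extension_patch_orig_forward'  # same attribute name ts_decoder.py uses
patch_model.patch_model(model, extensions, orig_name)
try:
    # The Linear submodule is traced as a single Trampoline PythonOp node.
    traced = torch.jit.trace(model, torch.randn(2, 8))
finally:
    patch_model.unpatch_model(model, orig_name)  # always restore the original forwards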
46 changes: 42 additions & 4 deletions src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py
@@ -10,20 +10,32 @@
from openvino.frontend.pytorch.utils import ivalue_to_constant, get_value_from_getattr, pt_to_ov_type_map, prepare_example_inputs_and_model, convert_quantized_tensor, graph_has_ops
from openvino.runtime import opset11 as ops
from openvino.frontend.pytorch import gptq
from openvino.frontend.pytorch import patch_model
from openvino.frontend.pytorch.module_extension import ModuleExtension

import typing
import torch


class TorchScriptPythonDecoder (Decoder):
def __init__(self, pt_module, graph_element=None, example_input=None, alias_db=None, shared_memory=True, skip_freeze=False, constant_cache=None):
def __init__(
self,
pt_module,
graph_element=None,
example_input=None,
alias_db=None,
shared_memory=True,
skip_freeze=False,
constant_cache=None,
module_extensions=None):
Decoder.__init__(self)
# We store every decoder created by this decoder so that none of them is deleted until this top-level decoder is deleted
self.m_decoders = []
self._input_signature = None
self._shared_memory = shared_memory
self._input_is_list = False
self.constant_cache = constant_cache if constant_cache is not None else dict()
self.module_extensions = module_extensions
if graph_element is None:
try:
pt_module = self._get_scripted_model(
@@ -89,14 +101,22 @@ def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False)
input_params = inspect.signature(pt_module.forward if hasattr(
pt_module, "forward") else pt_module.__call__).parameters
input_signature = list(input_params)

if example_inputs is None:
if self.module_extensions:
raise RuntimeError("ModuleExtension is not supported for scripting. Please provide valid example_input argument to run tracing.")
scripted = torch.jit.script(pt_module)
freeze_by_default = True
else:
input_parameters, input_signature, pt_module, self._input_is_list = prepare_example_inputs_and_model(
example_inputs, input_params, pt_module)
gptq_patched = False

# name of attribute in a patched module where the original forward method is kept
orig_forward_name = '_openvino_module_extension_patch_orig_forward'
if self.module_extensions:
patch_model.patch_model(pt_module, self.module_extensions, orig_forward_name)

gptq_patched = False
if gptq.detect_gptq_model(pt_module):
try:
gptq.patch_model(pt_module)
@@ -115,6 +135,8 @@ def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False)
finally:
if gptq_patched:
gptq.unpatch_model(pt_module)
if self.module_extensions:
patch_model.unpatch_model(pt_module, orig_forward_name)

if not freeze_by_default and graph_has_ops(scripted.inlined_graph, ["prim::Uninitialized", "prim::unchecked_cast", "aten::append"]):
# freeze models with unsupported ops
@@ -232,7 +254,8 @@ def visit_subgraph(self, node_visitor) -> None:
node,
alias_db=self.alias_db,
shared_memory=self._shared_memory,
constant_cache=self.constant_cache)
constant_cache=self.constant_cache,
module_extensions=self.module_extensions)
self.m_decoders.append(decoder)
node_visitor(decoder)

@@ -255,13 +278,28 @@ def get_subgraph_decoder(self, index: int):
decoder = TorchScriptPythonDecoder(self.pt_module,
self.get_subgraphs()[index],
alias_db=self.alias_db,
shared_memory=self._shared_memory)
shared_memory=self._shared_memory,
module_extensions=self.module_extensions)
self.m_decoders.append(decoder)
return decoder

def get_op_type(self) -> str:
assert isinstance(
self.graph_element, torch.Node), "Function can be called only when self.graph_element is of type torch.Node"
if self.graph_element.kind() == "prim::PythonOp":
if hasattr(self.graph_element, 'pyobj') and callable(self.graph_element.pyobj) and hasattr(self.graph_element.pyobj(), '__self__'):
trampoline = self.graph_element.pyobj().__self__
if hasattr(trampoline, 'target_extension') and isinstance(trampoline.target_extension, ModuleExtension):
target_op = trampoline.target_extension.target_op
if callable(target_op):
target = target_op(trampoline.original_module)
elif isinstance(target_op, str):
target = target_op
# TODO: Support target as a callable that will play the role of a ConversionExtension for an entire module instead of a single op.
# Without a callable target here, the same functionality can still be achieved by combining two extensions:
# a ModuleExtension that uses a temporary name as the target op, and a ConversionExtension that translates
# that temporary name to a custom graph. But providing the conversion code as a callable `target` is more convenient.
return target
return self.graph_element.kind()
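As the code above shows, target_op may also be a callable that receives the original module and returns the op name for that particular instance. A hedged sketch (the predicate is hypothetical):

# Choose the replacement op per module instance:
ModuleExtension(torch.nn.Linear,
                target_op=lambda module: "aten::relu" if module.bias is None else "aten::sin")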

def get_schema(self) -> str:
1 change: 0 additions & 1 deletion src/bindings/python/src/openvino/frontend/pytorch/utils.py
@@ -1,4 +1,3 @@

# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

2 changes: 1 addition & 1 deletion src/frontends/pytorch/src/frontend.cpp
@@ -290,7 +290,7 @@ ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& va
}

std::map<std::string, CreatorFunction> FrontEnd::get_supported_ops(const ov::frontend::InputModel::Ptr& model) const {
std::map<std::string, CreatorFunction> supported_ops = get_supported_ops_fx();
std::map<std::string, CreatorFunction> supported_ops;
if (std::dynamic_pointer_cast<pytorch::InputModel>(model)->decoder_type_name() == "fx")
supported_ops = get_supported_ops_fx();
else
74 changes: 54 additions & 20 deletions tests/layer_tests/py_frontend_tests/test_torch_frontend.py
@@ -291,17 +291,17 @@ def forward(self, x):
"Parameter", "ReluCustom", "Result"]


class CosModel(torch.nn.Module):
    def __init__(self):
        super(CosModel, self).__init__()

    def forward(self, x):
        return torch.cos(x.to(torch.float32))


def test_op_extension():
    from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
    from openvino.frontend.pytorch import OpExtension

    model = CosModel()
    decoder = TorchScriptPythonDecoder(get_scripted_model(model))

@@ -327,13 +327,6 @@ def test_op_extension_generic():
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
from openvino.frontend import OpExtension

class CosModel(torch.nn.Module):
def __init__(self):
super(CosModel, self).__init__()

def forward(self, x):
return torch.cos(x.to(torch.float32))

model = CosModel()
decoder = TorchScriptPythonDecoder(get_scripted_model(model))

@@ -355,6 +348,49 @@ def forward(self, x):
"Parameter", "Convert", "Sin", "Result"]


def test_module_extension():
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
from openvino.frontend.pytorch import ModuleExtension
from openvino import convert_model

class ModelWithModule(torch.nn.Module):
def __init__(self):
super(ModelWithModule, self).__init__()
self.cos_module = CosModel()

def forward(self, x):
return self.cos_module(x)

model = ModelWithModule()
decoder = TorchScriptPythonDecoder(model)

fem = FrontEndManager()
fe = fem.load_by_framework(framework="pytorch")
assert fe

input_model = fe.load(decoder)
assert input_model
converted_model = fe.convert(input_model)
assert converted_model
assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [
"Parameter", "Convert", "Cos", "Result"]

converted_model = convert_model(model, example_input=(torch.randn(100),), extension=[ModuleExtension(CosModel, "aten::sin")])
assert converted_model
assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [
"Parameter", "Sin", "Result"]

converted_model = convert_model(model, example_input=(torch.randn(100),), extension=[ModuleExtension(model.cos_module, "aten::sin")])
assert converted_model
assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [
"Parameter", "Sin", "Result"]

converted_model = convert_model(model, example_input=(torch.randn(100),), extension=[ModuleExtension("cos_module", "aten::sin")])
assert converted_model
assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [
"Parameter", "Sin", "Result"]


def test_pytorch_telemetry():
from openvino.frontend import TelemetryExtension
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
@@ -547,7 +583,7 @@ def forward(self, x: float, y: torch.Tensor):
assert PartialShape(pt_out_shape) == om.get_output_partial_shape(0)


class TestModel1(torch.nn.Module):
class ModelTest1(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.pool = torch.nn.AdaptiveAvgPool2d(1)
@@ -559,8 +595,7 @@ def forward(self, x):
def test_output_dict_names():
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder

input = torch.ones((1, 3, 224, 224))
model = TestModel1()
model = ModelTest1()
decoder = TorchScriptPythonDecoder(
model, example_input=(torch.randn(1, 3, 224, 224),))
fe_manager = FrontEndManager()
@@ -570,7 +605,7 @@ def test_output_dict_names():
assert om.outputs[0].any_name == "x1" and om.outputs[1].any_name == "x2", "Output dict names are not expected"


class TestModel2(torch.nn.Module):
class ModelTest2(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.pool = torch.nn.AdaptiveAvgPool2d(1)
@@ -582,8 +617,7 @@ def forward(self, x):
def test_output_tuple_names():
from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder

input = torch.ones((1, 3, 224, 224))
model = TestModel2()
model = ModelTest2()
decoder = TorchScriptPythonDecoder(
model, example_input=(torch.randn(1, 3, 224, 224),))
fe_manager = FrontEndManager()
4 changes: 3 additions & 1 deletion tools/ovc/openvino/tools/ovc/convert_impl.py
@@ -38,6 +38,7 @@

# pylint: disable=no-name-in-module,import-error
from openvino.frontend import FrontEndManager, OpConversionFailure, TelemetryExtension
from openvino.frontend.pytorch.module_extension import ModuleExtension
from openvino.runtime import get_version as get_rt_version
from openvino.runtime import Type, PartialShape

@@ -173,7 +174,8 @@ def prepare_ir(argv: argparse.Namespace):
moc_front_end.add_extension(TelemetryExtension("ovc", t.send_event, t.send_error, t.send_stack_trace))
if any_extensions_used(argv):
for extension in argv.extension:
moc_front_end.add_extension(extension)
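# ModuleExtension is consumed on the Python side by patching the PyTorch model
# before tracing (see patch_model.py), so it is not registered with the frontend here.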
if not isinstance(extension, ModuleExtension):
moc_front_end.add_extension(extension)
ov_model = moc_pipeline(argv, moc_front_end)
return ov_model
