-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[PT FE] Add ModuleExtension (#23536)
### Details: - *Continuation of #22867* ### Tickets: - *CVS-133733* --------- Co-authored-by: Sergey Lyalin <[email protected]>
- Loading branch information
Showing
10 changed files
with
231 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
|
||
# Copyright (C) 2018-2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
39 changes: 39 additions & 0 deletions
39
src/bindings/python/src/openvino/frontend/pytorch/module_extension.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Copyright (C) 2018-2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# flake8: noqa | ||
# mypy: ignore-errors | ||
|
||
class ModuleExtension: | ||
def __init__(self, module, target_op, evaluate=None, convert=None): | ||
""" | ||
Creates an extension that replaces entire PyTorch module by a single operation. | ||
This functionality works with PyTorch models only. A module can be identified by | ||
module type (e.g. torch.nn.Linear), module instance in the model or module name. | ||
Args: | ||
module (str, torch.nn.Module, type(torch.nn.Module)): PyTorch module to replace | ||
target_op (str): a target operation that will be used as a replacer for the module, | ||
could be a name of the extension operation or existing PyTorch operation | ||
(with prim:: or aten:: prefix following TorchScript syntax). | ||
evaluate (callable with args module, *args, **kwargs): a callable that will replace a target | ||
module in model execution it is responsible for producing valid output for | ||
the module to allow correct model tracing. By default it calls original module | ||
forward with the same arguments. The provided code will not be a part of the final | ||
traced model, it is used only to produce valid results in the tracing. | ||
convert (callable with args target_op, *args, **kwargs): a callable that will be traced and become | ||
a part of the final model instead of the target module. It accepts target_op as | ||
the first parameter, target_op is callable that will appear as a single node in the | ||
graph, the type of the node is target_op provided as another argument above. | ||
""" | ||
self.module = module | ||
self.target_op = target_op | ||
self.evaluate = evaluate | ||
if self.evaluate is None: | ||
self.evaluate = lambda module, *args, **kwargs: module(*args, **kwargs) | ||
self.convert = convert | ||
if self.convert is None: | ||
self.convert = lambda module, target_op, *args, **kwargs: target_op(*args, **kwargs) |
78 changes: 78 additions & 0 deletions
78
src/bindings/python/src/openvino/frontend/pytorch/patch_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright (C) 2018-2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# flake8: noqa | ||
# mypy: ignore-errors | ||
|
||
import torch | ||
|
||
|
||
class no_jit_trace: | ||
def __enter__(self): | ||
self.state = torch._C._get_tracing_state() | ||
torch._C._set_tracing_state(None) | ||
|
||
def __exit__(self, *args): | ||
torch._C._set_tracing_state(self.state) | ||
self.state = None | ||
|
||
|
||
def patch_model(model, module_extensions, orig_forward_name): | ||
for name, m in model.named_modules(): | ||
if hasattr(m, orig_forward_name): | ||
# already patched, skipping with a warning because it is unexpected | ||
print(f'[ WARNING ] Unexpectedly found already patched module {name} while applying ModuleExtension during PyTorch model conversion. ' | ||
'Result of the conversion maybe broken. Depending on the exact issue it may lead to broken original model.') | ||
continue | ||
extension = None | ||
if m in module_extensions: | ||
extension = module_extensions[m] | ||
elif m.__class__ in module_extensions: | ||
extension = module_extensions[m.__class__] | ||
elif name in module_extensions: | ||
extension = module_extensions[name] | ||
|
||
if extension: | ||
# The Trampoline class is instantiated for every module replacement, so we can use class members individually for each module. | ||
class Trampoline(torch.autograd.Function): | ||
target_extension = extension | ||
original_module = m | ||
stashed_args = None | ||
stashed_kwargs = None | ||
|
||
@staticmethod | ||
@torch.jit.ignore | ||
def forward(*args, **kwargs): | ||
with no_jit_trace(): | ||
# `module` is going to be passed to a user-defined function `evaluate` | ||
# `module` is patched: forward function was replaced, and we are actually in this patched function right in this code | ||
# if we pass `module` as-is to the user code below, and it happens to call forward it will lead to infinite recursion or fail | ||
# so we need to temporary patch the module back to the original forward and then return it back again | ||
# stash the current forward to be able to return it back | ||
patched_forward = m.forward | ||
# set original forward for the module | ||
m.forward = getattr(m, orig_forward_name) | ||
# call user code | ||
results = extension.evaluate( | ||
m, *Trampoline.stashed_args, **Trampoline.stashed_kwargs) | ||
m.forward = patched_forward # return patched forward back | ||
return results | ||
|
||
def new_forward(*args, **kwargs): | ||
Trampoline.stashed_args = args | ||
Trampoline.stashed_kwargs = kwargs | ||
return extension.convert(m, Trampoline.apply, *args, **kwargs) | ||
setattr(m, orig_forward_name, m.forward) | ||
m.forward = new_forward | ||
|
||
|
||
def unpatch_model(model, orig_forward_name): | ||
for _, m in model.named_modules(): | ||
if hasattr(m, orig_forward_name): | ||
try: | ||
m.forward = getattr(m, orig_forward_name) | ||
delattr(m, orig_forward_name) | ||
except Exception as error: | ||
print('[ WARNING ] Exception raised during model unpatching. Depending on the exact issue it may lead to broken original model.') | ||
print('Original exception details:') | ||
print(error) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
|
||
# Copyright (C) 2018-2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.