Fixed the comment and added engine cache example
cehongwang committed Feb 26, 2025
1 parent 1a309b8 commit 7cf4ad9
76 changes: 71 additions & 5 deletions examples/dynamo/mutable_torchtrt_module_example.py
@@ -11,12 +11,13 @@
The Mutable Torch TensorRT Module is designed to address these challenges, making interaction with the Torch-TensorRT module easier than ever.
In this tutorial, we are going to walk through
1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18
2. Save a Mutable Torch TensorRT Module
3. Integration with Huggingface pipeline in LoRA use case
4. Usage of dynamic shape with Mutable Torch TensorRT Module
"""

# %%
import numpy as np
import torch
import torch_tensorrt as torch_trt
@@ -144,6 +145,12 @@
# %%
# Use Mutable Torch TensorRT module with dynamic shape
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# When adding a dynamic shape hint to MutableTorchTensorRTModule, the shape hint should EXACTLY follow the semantics of arg_inputs and kwarg_inputs passed to the forward function
# and should not omit any entries (except None values in kwarg_inputs). If an input contains a nested dict/list, the dynamic shape for that entry should also be a nested dict/list.
# If dynamic shape is not required for an input, an empty dictionary should be given as the shape hint for that input.
# Note that you should exclude keyword arguments with value None, as those will be filtered out.
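#
# As a quick illustration (a hypothetical sketch, not the runnable example below):
# if the forward call receives kwarg_inputs = {"c": {"a": tensor_a, "b": tensor_b}, "d": None},
# a matching hint is kwarg_dynamic_shapes = {"c": {"a": {}, "b": {0: dim}}}, where
# "a" keeps an empty dict because its shape is static and "d" is omitted entirely
# because keyword arguments with value None are filtered out.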


class Model(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -167,7 +174,10 @@ def forward(self, a, b, c={}):
dim_2 = torch.export.Dim("dim2", min=1, max=50)
args_dynamic_shapes = ({1: dim_1}, {0: dim_0})
kwarg_dynamic_shapes = {
"c": {"a": {}, "b": {0: dim_2}},
"c": {
"a": {},
"b": {0: dim_2},
}, # a's shape does not change so we give it an empty dict
}
# Export the model first with custom dynamic shape constraints
model = torch_trt.MutableTorchTensorRTModule(model, debug=True, min_block_size=1)
@@ -181,3 +191,59 @@ def forward(self, a, b, c={}):
}
# Run without recompiling
model(*inputs_2, **kwargs_2)
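
# Note: if a later call passes a shape outside the hinted range (e.g., a dimension
# above the declared max of 50 for dim_2), the expectation is that the module
# detects the mismatch and recompiles instead of reusing stale constraints; treat
# this as an assumption about the module's shape checking, and widen the Dim
# bounds if you need larger inputs.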

# %%
# Use Mutable Torch TensorRT module with persistent cache
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# By leveraging engine caching, we can skip recompilation of previously built engines and save a significant amount of time.
import os

from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH

model = models.resnet18(pretrained=True).eval().to("cuda")
enabled_precisions = {torch.float}
debug = False
min_block_size = 1
use_python_runtime = True

times = []
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)


# The inputs in the loop below vary along dim 0 (the batch dimension), so each
# iteration triggers a fresh compilation.
model = torch_trt.MutableTorchTensorRTModule(
model,
use_python_runtime=use_python_runtime,
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
immutable_weights=False,
cache_built_engines=True,
reuse_cached_engines=True,
engine_cache_size=1 << 30, # 1GB
)
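
# cache_built_engines serializes newly built engines into the on-disk engine cache,
# reuse_cached_engines checks that cache before building a new engine, and
# engine_cache_size caps the cache at 1 GiB here. By default the cache lives under
# a temporary directory; the engine_cache_dir setting (not shown here) can point it
# somewhere else.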


def remove_timing_cache(path=TIMING_CACHE_PATH):
if os.path.exists(path):
os.remove(path)


remove_timing_cache()
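# Removing the timing cache ensures the first measured build below is a true cold
# start; leftover tactic timings from earlier TensorRT builds would otherwise speed
# up the first compilation and skew the comparison.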

for i in range(4):
inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]

start.record()
model(*inputs) # Recompile
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end))

print("----------------dynamo_compile----------------")
print("Without engine caching, used:", times[0], "ms")
print("With engine caching used:", times[1], "ms")
print("With engine caching used:", times[2], "ms")
print("With engine caching used:", times[3], "ms")
