Mutable module improvement (#3394)
Co-authored-by: Chengzhe Xu <[email protected]>
Co-authored-by: Dheeraj Peri <[email protected]>
3 people authored Feb 28, 2025
1 parent b33f393 commit f4219f7
Showing 6 changed files with 515 additions and 66 deletions.
2 changes: 1 addition & 1 deletion examples/dynamo/README.rst
@@ -21,4 +21,4 @@ Model Zoo
* :ref:`_torch_export_gpt2`: Compiling a GPT2 model using AOT workflow (`ir=dynamo`)
* :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`)
* :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`)
-* :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`)
* :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`)
160 changes: 144 additions & 16 deletions examples/dynamo/mutable_torchtrt_module_example.py
@@ -11,11 +11,13 @@
The Mutable Torch TensorRT Module is designed to address these challenges, making interaction with the Torch-TensorRT module easier than ever.
In this tutorial, we are going to walk through
-1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18
-2. Save a Mutable Torch TensorRT Module
-3. Integration with Huggingface pipeline in LoRA use case
1. Sample workflow of Mutable Torch TensorRT Module with ResNet 18
2. Save a Mutable Torch TensorRT Module
3. Integration with Huggingface pipeline in LoRA use case
4. Usage of dynamic shape with Mutable Torch TensorRT Module
"""

# %%
import numpy as np
import torch
import torch_tensorrt as torch_trt
@@ -63,16 +65,14 @@
# Saving Mutable Torch TensorRT Module
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-# Currently, saving is only enabled for C++ runtime, not python runtime.
# Currently, saving is only enabled when "use_python_runtime" is set to False in the settings.
torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
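# A quick sanity check on the reloaded module. This is a sketch: it assumes the
# ResNet-18 module above was compiled for (1, 3, 224, 224) CUDA inputs and that
# small numerical differences between the two runs are acceptable.
check_input = torch.rand(1, 3, 224, 224).to("cuda")
torch.testing.assert_close(reload(check_input), mutable_module(check_input), rtol=5e-3, atol=5e-3)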

# %%
# Stable Diffusion with Huggingface
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-# The LoRA checkpoint is from https://civitai.com/models/12597/moxin

from diffusers import DiffusionPipeline

with torch.no_grad():
@@ -83,33 +83,161 @@
"immutable_weights": False,
}

model_id = "runwayml/stable-diffusion-v1-5"
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
device = "cuda:0"

prompt = "house in forest, shuimobysim, wuchangshuo, best quality"
negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, out of focus, cloudy, (watermark:2),"
prompt = "cinematic photo elsa, police uniform <lora:princess_xl_v2:0.8>, . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
negative = "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, nude"

-    pipe = DiffusionPipeline.from_pretrained(
-        model_id, revision="fp16", torch_dtype=torch.float16
-    )
    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    pipe.to(device)

    # The only extra line you need
    pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)
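    # The engine is compiled lazily on the first forward call; afterwards the
    # module watches for weight changes and refits (or recompiles) as needed.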

-    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
    BATCH = torch.export.Dim("BATCH", min=2, max=24)
    _HEIGHT = torch.export.Dim("_HEIGHT", min=16, max=32)
    _WIDTH = torch.export.Dim("_WIDTH", min=16, max=32)
    HEIGHT = 4 * _HEIGHT
    WIDTH = 4 * _WIDTH
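    # HEIGHT and WIDTH are derived Dims (4 * a base Dim), so the UNet's latent
    # height/width stay multiples of 4 within the exported range (64 to 128).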
    args_dynamic_shapes = ({0: BATCH, 2: HEIGHT, 3: WIDTH}, {})
    kwargs_dynamic_shapes = {
        "encoder_hidden_states": {0: BATCH},
        "added_cond_kwargs": {
            "text_embeds": {0: BATCH},
            "time_ids": {0: BATCH},
        },
    }
    pipe.unet.set_expected_dynamic_shape_range(
        args_dynamic_shapes, kwargs_dynamic_shapes
    )
    image = pipe(
        prompt,
        negative_prompt=negative,
        num_inference_steps=30,
        height=1024,
        width=768,
        num_images_per_prompt=2,
    ).images[0]
    image.save("./without_LoRA_mutable.jpg")

    # Standard Huggingface LoRA loading procedure
    pipe.load_lora_weights(
        "stablediffusionapi/load_lora_embeddings",
-        weight_name="moxin.safetensors",
        weight_name="all-disney-princess-xl-lo.safetensors",
        adapter_name="lora1",
    )
    pipe.set_adapters(["lora1"], adapter_weights=[1])
    pipe.fuse_lora()
    pipe.unload_lora_weights()

    # Refit triggered
-    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
    image = pipe(
        prompt,
        negative_prompt=negative,
        num_inference_steps=30,
        height=1024,
        width=1024,
        num_images_per_prompt=1,
    ).images[0]
    image.save("./with_LoRA_mutable.jpg")


# %%
# Use Mutable Torch TensorRT module with dynamic shape
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# When adding a dynamic shape hint to MutableTorchTensorRTModule, the shape hint should EXACTLY follow the semantics of the arg_inputs and kwarg_inputs passed to the forward function
# and should not omit any entries (except None values in kwarg_inputs). If an input contains a nested dict/list, the dynamic shape for that entry should be a nested dict/list of the same structure.
# If dynamic shape is not required for an input, an empty dictionary should be given as the shape hint for that input.
# Note that you should exclude keyword arguments with value None, as those will be filtered out.


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, a, b, c={}):
        x = torch.matmul(a, b)
        x = torch.matmul(c["a"], c["b"].T)
        print(c["b"][0])
        x = 2 * c["b"]
        return x
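# The toy forward above overwrites x twice; its only purpose is to exercise
# several independent dynamic dimensions across positional args and nested kwargs.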


device = "cuda:0"
model = Model().eval().to(device)
inputs = (torch.rand(10, 3).to(device), torch.rand(3, 30).to(device))
kwargs = {
"c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(10, 30).to(device)},
}
dim_0 = torch.export.Dim("dim", min=1, max=50)
dim_1 = torch.export.Dim("dim", min=1, max=50)
dim_2 = torch.export.Dim("dim2", min=1, max=50)
args_dynamic_shapes = ({1: dim_1}, {0: dim_0})
kwarg_dynamic_shapes = {
    "c": {
        "a": {},
        "b": {0: dim_2},
    },  # a's shape does not change so we give it an empty dict
}
# Export the model first with custom dynamic shape constraints
model = torch_trt.MutableTorchTensorRTModule(model, debug=True, min_block_size=1)
model.set_expected_dynamic_shape_range(args_dynamic_shapes, kwarg_dynamic_shapes)
# Compile
model(*inputs, **kwargs)
# Change input shape
inputs_2 = (torch.rand(10, 5).to(device), torch.rand(5, 30).to(device))
kwargs_2 = {
    "c": {"a": torch.rand(10, 30).to(device), "b": torch.rand(5, 30).to(device)},
}
# Run without recompiling
model(*inputs_2, **kwargs_2)
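# Note: shapes outside the declared ranges (e.g., a dimension above max=50) do
# not fit the engine's optimization profile, so expect a full re-compilation
# (or an error) rather than a cheap re-run in that case.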

# %%
# Use Mutable Torch TensorRT module with persistent cache
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Leveraging engine caching, we can shortcut engine compilation and save a significant amount of time.
import os

from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH

model = models.resnet18(pretrained=True).eval().to("cuda")

times = []
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
model = torch_trt.MutableTorchTensorRTModule(
    model,
    use_python_runtime=True,
    enabled_precisions={torch.float},
    debug=True,
    min_block_size=1,
    immutable_weights=False,
    cache_built_engines=True,
    reuse_cached_engines=True,
    engine_cache_size=1 << 30,  # 1GB
)


def remove_timing_cache(path=TIMING_CACHE_PATH):
    if os.path.exists(path):
        os.remove(path)


remove_timing_cache()
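# Clearing the timing cache makes the first iteration below a true cold start,
# so the measured times isolate the effect of the engine cache.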

for i in range(4):
    inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]

    start.record()
    model(*inputs)  # Recompile
    end.record()
    torch.cuda.synchronize()
    times.append(start.elapsed_time(end))

print("----------------dynamo_compile----------------")
print("Without engine caching, used:", times[0], "ms")
print("With engine caching used:", times[1], "ms")
print("With engine caching used:", times[2], "ms")
print("With engine caching used:", times[3], "ms")
9 changes: 6 additions & 3 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -395,9 +395,12 @@ def refit_module_weights(
        try:
            weight_name_map = compiled_submodule.weight_name_map
        except AttributeError:
-            logger.warning(
-                "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
-            )
            if not isinstance(
                compiled_submodule, torch.fx.graph_module.GraphModule
            ):
                logger.warning(
                    "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
                )
        if not weight_name_map:
            use_weight_map_cache = False
            logger.warning(
19 changes: 12 additions & 7 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -375,7 +375,10 @@ def _construct_trt_network_def(self) -> None:

    @staticmethod
    def find_weight(
-        weight_name: str, np_map: dict[str, Any], state_dict: dict[str, Any]
        weight_name: str,
        np_map: dict[str, Any],
        state_dict: dict[str, Any],
        device: torch.device,
    ) -> str:
"""
We need to build map from engine weight name to state_dict weight name.
@@ -385,19 +388,21 @@ def find_weight(
        np_map: the map from weight name to np values in INetworkDefinition
        state_dict: state of the graph module
        """
-        network_weight = torch.from_numpy(np_map[weight_name]).cuda()
        network_weight = torch.from_numpy(np_map[weight_name]).to(device)
        for sd_w_name, sd_weight in state_dict.items():
-            if TRTInterpreter.check_weight_equal(sd_weight, network_weight):
            if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
                del state_dict[sd_w_name]
                return sd_w_name
        return ""

    @staticmethod
    def check_weight_equal(
-        sd_weight: torch.tensor, network_weight: Union[torch.Tensor, np.ndarray]
        sd_weight: torch.tensor,
        network_weight: Union[torch.Tensor, np.ndarray],
        device: torch.device,
    ) -> Any:
        if not isinstance(network_weight, torch.Tensor):
-            network_weight = torch.from_numpy(network_weight).cuda()
            network_weight = torch.from_numpy(network_weight).to(device)
        try:
            return sd_weight.shape == network_weight.shape and torch.all(
                torch.abs(sd_weight - network_weight) < 0.01
@@ -530,10 +535,10 @@ def _save_weight_mapping(self) -> None:
                # There is no direct connection in batch_norm layer. So skip it
                pass
            elif sd_weight_name not in sd or not TRTInterpreter.check_weight_equal(
-                sd[sd_weight_name], np_map[engine_weight_name]
                sd[sd_weight_name], np_map[engine_weight_name], torch_device
            ):
                weight_name_map[engine_weight_name] = TRTInterpreter.find_weight(
-                    engine_weight_name, np_map, sd
                    engine_weight_name, np_map, sd, torch_device
                )
                if (
                    weight_name_map[engine_weight_name] != ""