fix: Split addmm nodes to not cast bias for FP32 accumulation and flux example fixes. (#3395)

Co-authored-by: Dheeraj Peri <[email protected]>
peri044 and Dheeraj Peri authored Feb 25, 2025
1 parent a3db469 commit 7ab637e
Showing 7 changed files with 112 additions and 12 deletions.
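In short, the new lowering pass decomposes ``aten.addmm`` into an explicit matmul plus a scaled bias add, so the FP32-accumulation casts can wrap only the matmul while the bias keeps its original dtype. A minimal sketch of the algebraic identity the split relies on (illustration only, not part of the diff):

# Illustration only, not part of the commit:
# addmm(bias, mat1, mat2, beta, alpha) == beta * bias + alpha * (mat1 @ mat2)
import torch

bias = torch.rand(3, 5)
mat1 = torch.rand(3, 4)
mat2 = torch.rand(4, 5)

ref = torch.ops.aten.addmm.default(bias, mat1, mat2, beta=20, alpha=2)
split = 20 * bias + 2 * (mat1 @ mat2)
assert torch.allclose(ref, split, atol=1e-5)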
18 changes: 12 additions & 6 deletions examples/dynamo/torch_export_flux_dev.py
@@ -9,11 +9,11 @@
**FLUX.1 [dev]** is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications.
Install the following dependencies before compilation
To run this demo, you need access to the Flux model (request access on the `FLUX.1-dev <https://huggingface.co/black-forest-labs/FLUX.1-dev>`_ page if you do not already have it) and install the following dependencies
.. code-block:: python
pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2"
pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2" protobuf=="5.29.3"
There are different components of the ``FLUX.1-dev`` pipeline such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler``. In this example,
we demonstrate optimizing the ``transformer`` component of the model (which typically consumes >95% of the e2e diffusion latency)
@@ -38,11 +38,10 @@
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.float16,
)
pipe.to(DEVICE).to(torch.float16)

# Store the config and transformer backbone
config = pipe.transformer.config
backbone = pipe.transformer

backbone = pipe.transformer.to(DEVICE)

# %%
# Export the backbone using torch.export
@@ -63,6 +62,8 @@
"txt_ids": {0: SEQ_LEN},
"img_ids": {0: IMG_ID},
"guidance": {0: BATCH},
"joint_attention_kwargs": {},
"return_dict": None,
}
# The guidance factor is of type torch.float32
dummy_inputs = {
@@ -79,6 +80,8 @@
"txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
"img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
"guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
"joint_attention_kwargs": {},
"return_dict": False,
}
# This will create an exported program which is going to be compiled with Torch-TensorRT
ep = _export(
@@ -116,8 +119,11 @@
# ---------------------------
# Release the GPU memory occupied by the exported program and the pipe.transformer
# Set the transformer in the Flux pipeline to the Torch-TRT compiled model
backbone.to("cpu")

del ep
backbone.to("cpu")
pipe.to(DEVICE)
torch.cuda.empty_cache()
pipe.transformer = trt_gm
pipe.transformer.config = config
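For context, the compile step that produces ``trt_gm`` falls between the hunks above. A hedged sketch of roughly what it looks like, assuming ``ep`` and ``dummy_inputs`` from this example and the ``use_fp32_acc`` setting referenced elsewhere in this diff (the exact arguments used by the example may differ):

# Rough sketch only; argument values are assumptions, not the exact example code.
trt_gm = torch_tensorrt.dynamo.compile(
    ep,
    inputs=dummy_inputs,
    enabled_precisions={torch.float16},
    min_block_size=1,
    use_fp32_acc=True,  # enables the accumulate_fp32_matmul lowering pass
)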

@@ -10,7 +10,7 @@
from .fuse_prims_broadcast import fuse_prims_broadcast
from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
from .pass_manager import DynamoPassManager
from .remove_assert_scalar import remove_assert_scalar
from .remove_assert_nodes import remove_assert_nodes
from .remove_detach import remove_detach
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .repair_input_as_output import repair_input_as_output
@@ -27,7 +27,7 @@
        replace_max_pool_with_indices,
        lower_scaled_dot_product_attention,
        view_to_reshape,
        remove_assert_scalar,
        remove_assert_nodes,
        accumulate_fp32_matmul,
    ]
)
41 changes: 39 additions & 2 deletions py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
@@ -9,17 +9,54 @@
logger = logging.getLogger(__name__)


def split_addmm_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    target = torch.ops.aten.addmm.default
    addmm_nodes = [node for node in gm.graph.nodes if node.target == target]
    for addmm_node in addmm_nodes:
        bias, mat1, mat2 = addmm_node.all_input_nodes
        beta = addmm_node.kwargs.get("beta")
        alpha = addmm_node.kwargs.get("alpha")

        with gm.graph.inserting_before(addmm_node):
            mm_node = gm.graph.call_function(
                torch.ops.aten.mm.default,
                args=(mat1, mat2),
            )
            if alpha:
                mm_node = gm.graph.call_function(
                    torch.ops.aten.mul.Tensor,
                    args=(mm_node, alpha),
                )

            if beta:
                bias = gm.graph.call_function(
                    torch.ops.aten.mul.Tensor,
                    args=(bias, beta),
                )
            add_node = gm.graph.call_function(
                torch.ops.aten.add.Tensor,
                args=(bias, mm_node),
            )

        addmm_node.replace_all_uses_with(add_node, propagate_meta=True)
        gm.graph.erase_node(addmm_node)

    return gm


def accumulate_fp32_matmul(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Replace a matmul layer with fp32 accumulation nodes"""
    """Add cast to FP32/16 nodes around a matmul layer. This pattern is detected by TensorRT and will enable FP32 accumulation during execution."""
    if settings.use_fp32_acc:
        matmul_targets = [
            torch.ops.aten.mm.default,
            torch.ops.aten.bmm.default,
            torch.ops.aten.addmm.default,
        ]

        # Split torch.addmm nodes into add + mm and only add cast nodes around mm nodes
        split_addmm_nodes(gm)

        matmul_nodes = [
            node for node in gm.graph.nodes if node.target in matmul_targets
        ]
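For reference, a rough eager-mode sketch (not part of this diff) of the pattern this pass produces around each ``mm`` node once ``use_fp32_acc`` is enabled: operands are cast up to FP32, the matmul runs there, and the result is cast back to the original dtype. The new test below expects exactly ``_to_copy``, ``mm`` and ``add`` ops after lowering an ``addmm``.

# Illustration only; the real pass inserts aten._to_copy nodes into the FX graph.
import torch

def mm_with_fp32_acc(mat1: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
    out = torch.ops.aten.mm.default(mat1.to(torch.float32), mat2.to(torch.float32))
    return out.to(mat1.dtype)  # cast back to the original (e.g. FP16) dtype

a = torch.rand(3, 4, dtype=torch.float16)
b = torch.rand(4, 5, dtype=torch.float16)
print(mm_with_fp32_acc(a, b).dtype)  # torch.float16, accumulated in FP32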
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -51,6 +51,8 @@ def constant_fold(
        gm.graph.erase_node(node)

    gm = clean_up_graph_after_modifications(gm)
    # Delete the constant folder instance which holds GPU memory
    del cf

    logger.debug(f"Graph after constant folding:\n{gm.graph}")

@@ -9,15 +9,15 @@
logger = logging.getLogger(__name__)


def remove_assert_scalar(
def remove_assert_nodes(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Remove assert_scalar ops in the graph"""
    count = 0
    for node in gm.graph.nodes:
        if (
            node.target == torch.ops.aten._assert_scalar.default
            or node == torch.ops.aten._assert_tensor_metadata.default
            or node.target == torch.ops.aten._assert_tensor_metadata.default
        ):
            gm.graph.erase_node(node)
            count += 1
13 changes: 13 additions & 0 deletions py/torch_tensorrt/dynamo/utils.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import gc
import logging
import warnings
from dataclasses import fields, replace
@@ -30,6 +31,7 @@
DYNAMIC_DIM = -1
RTOL = 5e-3
ATOL = 5e-3
CPU_DEVICE = "cpu"


class Frameworks(Enum):
@@ -81,6 +83,17 @@ class Frameworks(Enum):
}


def delete_module(module: torch.fx.GraphModule) -> None:
    """
    This is a helper function to delete the instance of module. We first move it to CPU and then
    delete the object. This function ensures the GPU memory occupied by the module is released effectively after this call
    """
    module.to(CPU_DEVICE)
    del module
    torch.cuda.empty_cache()
    gc.collect()
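A small usage sketch of the new helper (illustration only; ``gm`` is a hypothetical torch.fx.GraphModule you are done with):

# Illustration only, not part of the commit.
from torch_tensorrt.dynamo.utils import delete_module

delete_module(gm)  # moves gm to CPU, drops the helper's reference to it, then
                   # calls torch.cuda.empty_cache() and gc.collect()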


def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool:
"""Parses a user-provided input argument regarding Python runtime
42 changes: 42 additions & 0 deletions tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -269,6 +269,48 @@ def forward(self, input, weight):
)
torch._dynamo.reset()

    def test_fp32_acc_for_addmm(self):
        class FP32Acc(torch.nn.Module):
            def forward(self, input, mat1, mat2):
                out = torch.ops.aten.addmm.default(input, mat1, mat2, beta=20, alpha=2)
                return out

        inputs = [
            torch.rand((3, 5)).cuda(),
            torch.rand((3, 4)).cuda(),
            torch.rand((4, 5)).cuda(),
        ]

        fx_graph = torch.fx.symbolic_trace(FP32Acc())
        expected_ops = {
            torch.ops.aten._to_copy.default,
            torch.ops.aten.mm.default,
            torch.ops.aten.add.Tensor,
        }
        unexpected_ops = {}

        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
            fx_graph,
            inputs,
            expected_ops=expected_ops,
            unexpected_ops=unexpected_ops,
            min_block_size=1,
            use_fp32_acc=True,
        )

        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        self.assertEqual(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )
        torch._dynamo.reset()


class TestLowerEfficientAttention(TestCase):
    def test_lower_efficient_attention(self):
