Skip to content

Commit

Permalink
feat: Automatically generate QDP plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
bowang007 committed Feb 25, 2025
1 parent 2d3e06f commit 31bdf77
Show file tree
Hide file tree
Showing 8 changed files with 593 additions and 5 deletions.
151 changes: 151 additions & 0 deletions examples/dynamo/auto_generate_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""
.. _auto_generate_converters:
Automatically Generate a Plugin for a Custom Kernel
===================================================================
We are going to demonstrate how to automatically generate a plugin for a custom kernel using Torch-TensorRT using
the new Python based plugin system in TensorRT 10.7.
Torch-TensorRT supports falling back to PyTorch implementations of operations in the case that Torch-TensorRT
does not know how to compile them in TensorRT. However, this comes at the cost of a graph break and will reduce the performance of the model.
The easiest way to fix lack of support for ops is by adding a decomposition (see:
`Writing lowering passes for the Dynamo frontend <https://pytorch.org/TensorRT/contributors/writing_dynamo_aten_lowering_passes.html>`_) - which defines the operator
in terms of PyTorch ops that are supported in Torch-TensorRT or a converter (see:
`Writing converters for the Dynamo frontend <https://pytorch.org/TensorRT/contributors/dynamo_converters.html>`_) - which defines the operator in terms of TensorRT operators.
In some cases there isn't a great way to do either of these, perhaps because the operator is a custom kernel that is not part of standard PyTorch or
TensorRT cannot support it natively.
For these cases, it is possible to use a TensorRT plugin to replace the operator **inside** the TensorRT engine, thereby avoiding
the performance and resource overhead from a graph break.
Previously this involved a complex process in not only building a performant kernel but setting it up to run in TensorRT (see: `Using Custom Kernels within TensorRT Engines with Torch-TensorRT <https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/custom_kernel_plugins.html>`_).
With TensorRT 10.7, there is a new Python native plugin system which greatly streamlines this process. This
plugin system also allows Torch-TensorRT to automatically generate the necessary conversion code to convert the
operation in PyTorch to TensorRT.
"""

# %%
# Writing Custom Operators in PyTorch
# -----------------------------------------
#
# Pervious tutorials already cover creating custom operators in PyTorch which later get used with Torch-TensorRT.
# Here we define a simple elementwise multiplication operator in Triton. This operator is then registered as a custom op in PyTorch.
# with its host launch code as well as a "meta-kernel", A meta-kernel is a function that describes the shape and data type
# transformations that the operator will perform. This meta-kernel is used by Dynamo and Torch-TensorRT, so it
# is necessary to define.
#

from typing import Tuple

import tensorrt_bindings.plugin as trtp
import torch
import torch_tensorrt
import triton
import triton.language as tl


@triton.jit
def elementwise_scale_mul_kernel(X, Y, Z, a, b, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
# Compute the range of elements that this thread block will work on
block_start = pid * BLOCK_SIZE
# Range of indices this thread will handle
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Load elements from the X and Y tensors
x_vals = tl.load(X + offsets)
y_vals = tl.load(Y + offsets)
# Perform the element-wise multiplication
z_vals = x_vals * y_vals * a + b
# Store the result in Z
tl.store(Z + offsets, z_vals)


@torch.library.custom_op("torchtrt_ex::elementwise_scale_mul", mutates_args=()) # type: ignore[misc]
def elementwise_scale_mul(
X: torch.Tensor, Y: torch.Tensor, b: float = 0.2, a: int = 2
) -> torch.Tensor:
# Ensure the tensors are on the GPU
assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
assert X.shape == Y.shape, "Tensors must have the same shape."

# Create output tensor
Z = torch.empty_like(X)

# Define block size
BLOCK_SIZE = 1024

# Grid of programs
grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],)

# Launch the kernel with parameters a and b
elementwise_scale_mul_kernel[grid](X, Y, Z, a, b, BLOCK_SIZE=BLOCK_SIZE)

return Z


# %%
# The meta kernel for an elementwise operation is just the shape and dtype of one of the inputs since we will not change the shape
# in the course of the operation.


@torch.library.register_fake("torchtrt_ex::elementwise_scale_mul")
def _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Tensor:
return x


# %%
# Here we use automatic plugin creation feature in Torch-TensorRT which enables plugin registration using
# TensorRT QDP APIs
torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
"torchtrt_ex::elementwise_scale_mul"
)


# # %%
# # Generating the Converter
# # -------------------------------------------------------------------
# # Given that we have defined the custom operator in PyTorch and TensorRT, we can now generate the converter for the operation.
# # As long as the namespace and names match, the following function will automatically generate the converter for the operation.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
"torchtrt_ex::elementwise_scale_mul", supports_dynamic_shapes=True
)


# # %%
# # Above two commands can be replaced with the following single one line:
# torch_tensorrt.dynamo.conversion.plugins.custom_op("torchtrt_ex::elementwise_scale_mul", supports_dynamic_shapes=True)


# %%
# Using our converter with a model
# -------------------------------------------------------------------
#
# Now we can use our custom operator in a model and compile it with Torch-TensorRT.
# We can see that the custom operator is used as one of the operations in the forward pass of the model.
# The process of compiling the model at this point is identical to standard Torch-TensorRT usage.
class MyModel(torch.nn.Module): # type: ignore[misc]
def __init__(self):
super().__init__()

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
z = torch.add(x, y)
res = torch.ops.torchtrt_ex.elementwise_scale_mul.default(x, z, b=0.5)

return res


my_model = MyModel().to("cuda")
m = torch.randint(0, 5, (64, 64), device="cuda", dtype=torch.float)
n = torch.randint(0, 5, (64, 64), device="cuda", dtype=torch.float)

with torch_tensorrt.logging.errors():
model_trt = torch_tensorrt.compile(
my_model, inputs=[m, n], debug=True, min_block_size=1
)
for i in range(300):
res = model_trt(m, n)
assert torch.allclose(res, my_model(m, n))

print("Ran with custom plugin!")
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from torch_tensorrt.dynamo.conversion.plugins._custom_op import custom_op
from torch_tensorrt.dynamo.conversion.plugins._generate_plugin import generate_plugin
from torch_tensorrt.dynamo.conversion.plugins._generate_plugin_converter import (
generate_plugin_converter,
)
33 changes: 33 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/plugins/_custom_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Callable, Optional

from torch.fx.node import Node
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterPriority
from torch_tensorrt.dynamo.conversion.plugins._generate_plugin import generate_plugin
from torch_tensorrt.dynamo.conversion.plugins._generate_plugin_converter import (
generate_plugin_converter,
)


def custom_op(
op_name: str,
capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None,
priority: ConverterPriority = ConverterPriority.STANDARD,
supports_dynamic_shapes: bool = False,
) -> None:
"""
Generate the Plugin and corresponding Plugin Converter using external kernels and TensorRT Quick Deployable Plugin APIs.
Args:
plugin_name: the plugin name that is used to generate the plugin automatically.
There should be existing kernels and pytorch custom operation for this plugin name.
capability_validator: A lambda that can take a ``torch.fx.Node`` and determine if the
converter can properly handle this Node. If the validator returns ``False``, the subgraph
partitioner will make sure this Node is run in PyTorch in the compiled graph.
priority: Allows developers to override existing converters in the converter registry
supports_dynamic_shapes: if dynamic shape is supported
"""
generate_plugin(op_name)
generate_plugin_converter(
op_name, capability_validator, priority, supports_dynamic_shapes
)
Loading

0 comments on commit 31bdf77

Please sign in to comment.