From b60c4c69562ffe59da0675605426806f7521d15c Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Tue, 25 Feb 2025 22:26:27 +0000
Subject: [PATCH] feat: Automatically generate QDP plugins

---
 examples/dynamo/auto_generate_plugin.py        | 151 +++++++++++++
 .../dynamo/conversion/plugins/__init__.py      |   2 +
 .../dynamo/conversion/plugins/_custom_op.py    |  33 +++
 .../conversion/plugins/_generate_plugin.py     | 213 ++++++++++++++++++
 .../plugins/_generate_plugin_converter.py      |  20 +-
 pyproject.toml                                 |   1 +
 .../conversion/test_automatic_plugin.py        |  92 ++++++++
 .../test_automatic_plugin_with_attrs.py        |  86 +++++++
 8 files changed, 593 insertions(+), 5 deletions(-)
 create mode 100644 examples/dynamo/auto_generate_plugin.py
 create mode 100644 py/torch_tensorrt/dynamo/conversion/plugins/_custom_op.py
 create mode 100644 py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py
 create mode 100644 tests/py/dynamo/conversion/test_automatic_plugin.py
 create mode 100644 tests/py/dynamo/conversion/test_automatic_plugin_with_attrs.py

diff --git a/examples/dynamo/auto_generate_plugin.py b/examples/dynamo/auto_generate_plugin.py
new file mode 100644
index 0000000000..2ea50b87f6
--- /dev/null
+++ b/examples/dynamo/auto_generate_plugin.py
@@ -0,0 +1,151 @@
"""
.. _auto_generate_converters:

Automatically Generate a Plugin for a Custom Kernel
===================================================================

We are going to demonstrate how to automatically generate a plugin for a custom kernel with Torch-TensorRT,
using the new Python-based plugin system introduced in TensorRT 10.7.

Torch-TensorRT supports falling back to PyTorch implementations of operations in the case that Torch-TensorRT
does not know how to compile them in TensorRT. However, this comes at the cost of a graph break and will reduce the performance of the model.
The easiest way to fix a lack of support for an op is by adding a decomposition (see:
`Writing lowering passes for the Dynamo frontend `_) - which defines the operator
in terms of PyTorch ops that are supported in Torch-TensorRT - or a converter (see:
`Writing converters for the Dynamo frontend `_) - which defines the operator in terms of TensorRT operators.

In some cases there isn't a great way to do either of these, perhaps because the operator is a custom kernel that is not part of standard PyTorch,
or because TensorRT cannot support it natively.

For these cases, it is possible to use a TensorRT plugin to replace the operator **inside** the TensorRT engine, thereby avoiding
the performance and resource overhead of a graph break.

Previously, this involved a complex process of not only building a performant kernel but also setting it up to run in TensorRT (see: `Using Custom Kernels within TensorRT Engines with Torch-TensorRT `_).
With TensorRT 10.7, there is a new Python-native plugin system which greatly streamlines this process. This
plugin system also allows Torch-TensorRT to automatically generate the necessary conversion code to convert the
operation in PyTorch to TensorRT.
"""

# %%
# Writing Custom Operators in PyTorch
# -----------------------------------------
#
# Previous tutorials already cover creating custom operators in PyTorch which later get used with Torch-TensorRT.
# Here we define a simple elementwise multiplication operator in Triton. This operator is then registered as a custom op in PyTorch,
# with its host launch code as well as a "meta-kernel". A meta-kernel is a function that describes the shape and data type
# transformations that the operator will perform. This meta-kernel is used by Dynamo and Torch-TensorRT, so it
# is necessary to define.
#

from typing import Tuple

import tensorrt_bindings.plugin as trtp
import torch
import torch_tensorrt
import triton
import triton.language as tl


@triton.jit
def elementwise_scale_mul_kernel(X, Y, Z, a, b, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    # Compute the range of elements that this thread block will work on
    block_start = pid * BLOCK_SIZE
    # Range of indices this thread will handle
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Load elements from the X and Y tensors
    x_vals = tl.load(X + offsets)
    y_vals = tl.load(Y + offsets)
    # Perform the element-wise multiplication
    z_vals = x_vals * y_vals * a + b
    # Store the result in Z
    tl.store(Z + offsets, z_vals)


@torch.library.custom_op("torchtrt_ex::elementwise_scale_mul", mutates_args=())  # type: ignore[misc]
def elementwise_scale_mul(
    X: torch.Tensor, Y: torch.Tensor, b: float = 0.2, a: int = 2
) -> torch.Tensor:
    # Ensure the tensors are on the GPU
    assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
    assert X.shape == Y.shape, "Tensors must have the same shape."

    # Create output tensor
    Z = torch.empty_like(X)

    # Define block size
    BLOCK_SIZE = 1024

    # Grid of programs
    grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],)

    # Launch the kernel with parameters a and b
    elementwise_scale_mul_kernel[grid](X, Y, Z, a, b, BLOCK_SIZE=BLOCK_SIZE)

    return Z


# %%
# The meta-kernel for an elementwise operation simply returns the shape and dtype of one of the inputs, since we will not change the shape
# in the course of the operation.


@torch.library.register_fake("torchtrt_ex::elementwise_scale_mul")
def _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Tensor:
    return x


# %%
# Here we use the automatic plugin creation feature in Torch-TensorRT, which enables plugin registration using
# the TensorRT QDP (Quick Deployable Plugin) APIs.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin(
    "torchtrt_ex::elementwise_scale_mul"
)


# %%
# Generating the Converter
# -------------------------------------------------------------------
# Given that we have defined the custom operator in PyTorch and TensorRT, we can now generate the converter for the operation.
# As long as the namespace and names match, the following function will automatically generate the converter for the operation.
torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "torchtrt_ex::elementwise_scale_mul", supports_dynamic_shapes=True
)


# %%
# The above two calls can be replaced with the following single line:
# torch_tensorrt.dynamo.conversion.plugins.custom_op("torchtrt_ex::elementwise_scale_mul", supports_dynamic_shapes=True)


# %%
# Using our converter with a model
# -------------------------------------------------------------------
#
# Now we can use our custom operator in a model and compile it with Torch-TensorRT.
# We can see that the custom operator is used as one of the operations in the forward pass of the model.
# The process of compiling the model at this point is identical to standard Torch-TensorRT usage.
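
# %%
# As a quick sanity check (an illustrative addition to this example, not required for compilation),
# we can run the op eagerly once to confirm that the Triton kernel computes x * y * a + b with the
# default attributes a=2 and b=0.2 before we build a model around it.
x_check = torch.full((64, 64), 2.0, device="cuda")
y_check = torch.full((64, 64), 3.0, device="cuda")
out_check = torch.ops.torchtrt_ex.elementwise_scale_mul.default(x_check, y_check)
# Every element should be 2 * 3 * 2 + 0.2 = 12.2
assert torch.allclose(out_check, x_check * y_check * 2 + 0.2)
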
+class MyModel(torch.nn.Module): # type: ignore[misc] + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + z = torch.add(x, y) + res = torch.ops.torchtrt_ex.elementwise_scale_mul.default(x, z, b=0.5) + + return res + + +my_model = MyModel().to("cuda") +m = torch.randint(0, 5, (64, 64), device="cuda", dtype=torch.float) +n = torch.randint(0, 5, (64, 64), device="cuda", dtype=torch.float) + +with torch_tensorrt.logging.errors(): + model_trt = torch_tensorrt.compile( + my_model, inputs=[m, n], debug=True, min_block_size=1 + ) + for i in range(300): + res = model_trt(m, n) + assert torch.allclose(res, my_model(m, n)) + +print("Ran with custom plugin!") diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/__init__.py b/py/torch_tensorrt/dynamo/conversion/plugins/__init__.py index 379a39943e..fc5e973560 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/__init__.py @@ -1,3 +1,5 @@ +from torch_tensorrt.dynamo.conversion.plugins._custom_op import custom_op +from torch_tensorrt.dynamo.conversion.plugins._generate_plugin import generate_plugin from torch_tensorrt.dynamo.conversion.plugins._generate_plugin_converter import ( generate_plugin_converter, ) diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_custom_op.py b/py/torch_tensorrt/dynamo/conversion/plugins/_custom_op.py new file mode 100644 index 0000000000..ef5ed59a56 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_custom_op.py @@ -0,0 +1,33 @@ +from typing import Callable, Optional + +from torch.fx.node import Node +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterPriority +from torch_tensorrt.dynamo.conversion.plugins._generate_plugin import generate_plugin +from torch_tensorrt.dynamo.conversion.plugins._generate_plugin_converter import ( + generate_plugin_converter, +) + + +def custom_op( + op_name: str, + capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None, + priority: ConverterPriority = ConverterPriority.STANDARD, + supports_dynamic_shapes: bool = False, +) -> None: + """ + Generate the Plugin and corresponding Plugin Converter using external kernels and TensorRT Quick Deployable Plugin APIs. + + Args: + plugin_name: the plugin name that is used to generate the plugin automatically. + There should be existing kernels and pytorch custom operation for this plugin name. + capability_validator: A lambda that can take a ``torch.fx.Node`` and determine if the + converter can properly handle this Node. If the validator returns ``False``, the subgraph + partitioner will make sure this Node is run in PyTorch in the compiled graph. 
        priority: Allows developers to override existing converters in the converter registry
        supports_dynamic_shapes: Whether the converter supports dynamic shapes
    """
    generate_plugin(op_name)
    generate_plugin_converter(
        op_name, capability_validator, priority, supports_dynamic_shapes
    )

diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py
new file mode 100644
index 0000000000..4211bae1fa
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py
@@ -0,0 +1,213 @@
import logging
from types import FunctionType
from typing import Any, Callable, Tuple

import tensorrt.plugin as trtp
import torch
from sympy import lambdify
from torch._dynamo.source import LocalSource
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv

_LOGGER: logging.Logger = logging.getLogger(__name__)


def mksym(
    shape_env: ShapeEnv, value: int, source: LocalSource, dynamic_dim: DimDynamic
) -> torch.SymInt:
    return shape_env.create_symintnode(
        shape_env.create_symbol(
            value,
            source=source,
            dynamic_dim=dynamic_dim,
        ),
        hint=value,
        source=source,
    )


def _generate_plugin(plugin_name: str) -> None:
    namespace, name = plugin_name.split("::")

    # retrieve the corresponding torch operation using the passed in string
    torch_op = getattr(getattr(torch.ops, namespace), name)

    # helper function that generates the required signature based on the torch operation
    def generate_signature(
        torch_op: Callable[[Any], Any],
    ) -> Tuple[str, str, str, dict[str, Any], dict[str, Any]]:
        schema = torch_op._schemas[""]

        arg_list = []

        register_func_annotation = {}
        impl_func_annotation = {}

        for arg in schema.arguments:
            arg_list.append(arg.name)

            # TODO: Torch types need to be converted to python primitive types here
            # Some other types are not handled:
            # - torch._C.ListType.ofT()
            # - torch._C.TupleType.get()
            # - torch._C.DictType.get(, )
            # - torch._C.OptionalType.ofT()
            # - torch._C.DeviceObjType.get()
            # - torch._C.FunctionType.get()
            # - torch._C.ClassType

            if arg.type.isSubtypeOf(torch._C.TensorType.get()):
                register_func_annotation[arg.name] = trtp.TensorDesc
                impl_func_annotation[arg.name] = trtp.Tensor
            elif arg.type.isSubtypeOf(torch._C.FloatType.get()):
                register_func_annotation[arg.name] = float
                impl_func_annotation[arg.name] = float
            elif arg.type.isSubtypeOf(torch._C.IntType.get()):
                register_func_annotation[arg.name] = int
                impl_func_annotation[arg.name] = int
            elif arg.type.isSubtypeOf(torch._C.BoolType.get()):
                register_func_annotation[arg.name] = bool
                impl_func_annotation[arg.name] = bool
            elif arg.type.isSubtypeOf(torch._C.StringType.get()):
                register_func_annotation[arg.name] = str
                impl_func_annotation[arg.name] = str
            else:
                raise ValueError(f"arg type {arg.type} is not handled")

        input_signature = ", ".join(arg_list)

        plugin_signature = f"def add_plugin_desc({input_signature}):"

        plugin_impl_arg_list = arg_list
        plugin_impl_arg_list.append("outputs")
        plugin_impl_arg_list.append("stream")
        plugin_impl_input = ", ".join(plugin_impl_arg_list)
        plugin_impl_signature = f"def add_plugin_impl({plugin_impl_input}):"

        register_func_annotation["return"] = Tuple[trtp.TensorDesc]

        impl_func_annotation["outputs"] = Tuple[trtp.Tensor]
        impl_func_annotation["stream"] = int

        return (
            input_signature,
            plugin_signature,
            plugin_impl_signature,
register_func_annotation, + impl_func_annotation, + ) + + # Use the helper function to get the required signatures + ( + input_signature, + plugin_signature, + plugin_impl_signature, + register_func_annotation, + impl_func_annotation, + ) = generate_signature(torch_op) + + def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]: + shape_env = ShapeEnv() + fake_mode = FakeTensorMode(shape_env=shape_env) + syms_args = [] + tensor_args = [elem for elem in args if isinstance(elem, trtp.TensorDesc)] + + for tensor_arg in tensor_args: + + sample = {f"{i}": 5 for i in range(tensor_arg.ndim)} + syms_arg = [ + mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC) + for k, v in sample.items() + ] + syms_args.append(syms_arg) + + with FakeTensorMode() as fake_mode: + fake_args = [] + for syms_arg in syms_args: + fake_arg = torch.randn(syms_arg) + fake_args.append(fake_arg) + + output = torch_op(*fake_args, **kwargs) + + # We assume that number of dimensions are the same in torch op + shape_calc_fns = [None] * args[0].ndim + for i in range(args[0].ndim): + input_node_expr = [syms_arg[i].node.expr for syms_arg in syms_args] + shape_calc_fns[i] = lambdify( + tuple(input_node_expr), output.shape[i].node.expr, "math" + ) + + out_desc = tensor_args[0].like() + for i in range(out_desc.ndim): + input_shape_expr = [tensor_arg.shape_expr[i] for tensor_arg in tensor_args] + if output.shape[i].node.expr is None: + raise ValueError(f"output.shape[{i}].node.expr cannot be None") + out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr) # type: ignore[misc] + + return (out_desc,) + + codegen_plugin = f""" +{plugin_signature} + return _generic_plugin_desc({input_signature}) + """ + + _LOGGER.warning(f"Plugin registration function: \n{codegen_plugin}") + + plugin_code = compile(codegen_plugin, "", "exec") + + globals()["_generic_plugin_desc"] = _generic_plugin_desc + + plugin = FunctionType( + plugin_code.co_consts[0], + globals(), + "plugin", + ) + + # Function annotation is required for dynamic function to work in TensorRT.Plugin + plugin.__annotations__ = register_func_annotation + + trtp.register(plugin_name)(plugin) + + def _generic_plugin_impl( + outputs: Tuple[trtp.Tensor], stream: int, *args: Any, **kwargs: Any + ) -> None: + tensor_args = [elem for elem in args if isinstance(elem, trtp.Tensor)] + non_tensor_args = [elem for elem in args if not isinstance(elem, trtp.Tensor)] + in_tensors = [torch.as_tensor(i, device="cuda") for i in tensor_args] + + dest_tensors = [torch.as_tensor(o, device="cuda") for o in outputs] + + stream = torch.cuda.ExternalStream(stream) + with torch.cuda.stream(stream): + out_tensors = torch_op(*in_tensors, *non_tensor_args, **kwargs) + if isinstance(out_tensors, torch.Tensor): + out_tensors = (out_tensors,) + [d.copy_(o) for (d, o) in zip(dest_tensors, out_tensors)] + + plugin_impl_func = f""" +{plugin_impl_signature} + _generic_plugin_impl(outputs, stream, {input_signature}) + """ + + _LOGGER.warning(f"Plugin implementation function: \n{plugin_impl_func}") + + plugin_impl_code = compile(plugin_impl_func, "", "exec") + + globals()["_generic_plugin_impl"] = _generic_plugin_impl + + plugin_impl = FunctionType(plugin_impl_code.co_consts[0], globals(), "plugin_impl") + + plugin_impl.__annotations__ = impl_func_annotation + + trtp.impl(plugin_name)(plugin_impl) + + +def generate_plugin(plugin_name: str) -> None: + """ + Generate the Plugin using external kernels and TensorRT Quick Deployable Plugin APIs. 
+ + Args: + plugin_name: the plugin name that is used to generate the plugin automatically. + There should be existing kernels and pytorch custom operation for this plugin name. + """ + _generate_plugin(plugin_name) diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py index b880939b17..f6343fdb34 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py +++ b/py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py @@ -2,6 +2,7 @@ from typing import Callable, Dict, Optional, Sequence, Tuple, Union import numpy as np +import tensorrt as trt # Seems like a bug in TensorRT import tensorrt.plugin as trtp @@ -18,8 +19,6 @@ ) from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor -import tensorrt as trt - _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -58,13 +57,24 @@ def custom_kernel_converter( # Assuming TensorRT preserves kwargs order like PyTorch does non_tensor_inputs = plugin.input_attrs + kwargs = {} + + for arg in torch_schema.arguments: + if arg.default_value is not None: + kwargs[arg.name] = arg.default_value + non_tensor_args = args[len(tensor_inputs) :] non_tensor_kwargs = dict(zip(list(non_tensor_inputs.keys()), non_tensor_args)) - for k, v in non_tensor_kwargs.items(): + + for k, v in kwargs.items(): + if k in non_tensor_kwargs: + kwargs[k] = non_tensor_kwargs[k] + + for k, v in kwargs.items(): if isinstance(v, torch.fx.immutable_collections.immutable_list): - non_tensor_kwargs[k] = np.array(v) + kwargs[k] = np.array(v) - layer = ctx.net.add_plugin(plugin(*itensor_args, **non_tensor_kwargs)) + layer = ctx.net.add_plugin(plugin(*itensor_args, **kwargs)) assert layer, f"{namespace}::{name} plugin layer was not able to be created" _LOGGER.debug( f"Adding generated plugin for {namespace}::{name} to tensorrt network" diff --git a/pyproject.toml b/pyproject.toml index 87236f2c02..bc01a84038 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ requires = [ "torch>=2.7.0.dev,<2.8.0", "pybind11==2.6.2", "numpy", + "sympy", ] build-backend = "setuptools.build_meta" diff --git a/tests/py/dynamo/conversion/test_automatic_plugin.py b/tests/py/dynamo/conversion/test_automatic_plugin.py new file mode 100644 index 0000000000..e843686f9f --- /dev/null +++ b/tests/py/dynamo/conversion/test_automatic_plugin.py @@ -0,0 +1,92 @@ +from typing import Tuple + +import torch +import torch.nn as nn +import torch_tensorrt +import triton +import triton.language as tl +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +@triton.jit +def elementwise_mul_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr): + # Program ID determines the block of data each thread will process + pid = tl.program_id(0) + # Compute the range of elements that this thread block will work on + block_start = pid * BLOCK_SIZE + # Range of indices this thread will handle + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Load elements from the X and Y tensors + x_vals = tl.load(X + offsets) + y_vals = tl.load(Y + offsets) + # Perform the element-wise multiplication + z_vals = x_vals * y_vals + # Store the result in Z + tl.store(Z + offsets, z_vals) + + +@torch.library.custom_op("torchtrt_ex::elementwise_mul", mutates_args=()) # type: ignore[misc] +def elementwise_mul(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor: + # Ensure the tensors are on the GPU + 
assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device." + assert X.shape == Y.shape, "Tensors must have the same shape." + + # Create output tensor + Z = torch.empty_like(X) + + # Define block size + BLOCK_SIZE = 1024 + + # Grid of programs + grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],) + + # Launch the kernel + elementwise_mul_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE) + + return Z + + +@torch.library.register_fake("torchtrt_ex::elementwise_mul") +def elementwise_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + + +torch_tensorrt.dynamo.conversion.plugins.custom_op( + "torchtrt_ex::elementwise_mul", supports_dynamic_shapes=True +) + + +class TestAutomaticPlugin(DispatchTestCase): + @parameterized.expand( + [ + ((64, 64), torch.float), + ((256, 256), torch.int), + ] + ) + def test_mul_plugin_float(self, input_shape, dtype): + class elementwise_mul(nn.Module): + def forward(self, lhs, rhs): + return torch.ops.torchtrt_ex.elementwise_mul.default(lhs, rhs) + + inputs = [ + torch.randint(0, 5, input_shape, device="cuda", dtype=dtype), + torch.randint(0, 5, input_shape, device="cuda", dtype=dtype), + ] + + self.run_test(elementwise_mul(), inputs) + + +if __name__ == "__main__": + run_tests() + +# Example Usage +# A = torch.full((64, 64), 2, device="cuda", dtype=torch.float) +# B = torch.full((64, 64), 3, device="cuda", dtype=torch.float) + +# C, D = torch.ops.torchtrt_ex.elementwise_add_mul.default(A, B) + +# print("C (Addition):", C) +# print("D (Multiplication):", D) diff --git a/tests/py/dynamo/conversion/test_automatic_plugin_with_attrs.py b/tests/py/dynamo/conversion/test_automatic_plugin_with_attrs.py new file mode 100644 index 0000000000..da0d6bfdfb --- /dev/null +++ b/tests/py/dynamo/conversion/test_automatic_plugin_with_attrs.py @@ -0,0 +1,86 @@ +from typing import Tuple + +import torch +import torch.nn as nn +import torch_tensorrt +import triton +import triton.language as tl +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +@triton.jit +def elementwise_scale_mul_kernel(X, Y, Z, a, b, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + # Compute the range of elements that this thread block will work on + block_start = pid * BLOCK_SIZE + # Range of indices this thread will handle + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Load elements from the X and Y tensors + x_vals = tl.load(X + offsets) + y_vals = tl.load(Y + offsets) + # Perform the element-wise multiplication + z_vals = x_vals * y_vals * a + b + # Store the result in Z + tl.store(Z + offsets, z_vals) + + +@torch.library.custom_op("torchtrt_ex::elementwise_scale_mul", mutates_args=()) # type: ignore[misc] +def elementwise_scale_mul( + X: torch.Tensor, Y: torch.Tensor, b: float = 0.2, a: int = 2 +) -> torch.Tensor: + # Ensure the tensors are on the GPU + assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device." + assert X.shape == Y.shape, "Tensors must have the same shape." 
+ + # Create output tensor + Z = torch.empty_like(X) + + # Define block size + BLOCK_SIZE = 1024 + + # Grid of programs + grid = lambda meta: (X.numel() // meta["BLOCK_SIZE"],) + + # Launch the kernel with parameters a and b + elementwise_scale_mul_kernel[grid](X, Y, Z, a, b, BLOCK_SIZE=BLOCK_SIZE) + + return Z + + +@torch.library.register_fake("torchtrt_ex::elementwise_scale_mul") +def _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Tensor: + return x + + +torch_tensorrt.dynamo.conversion.plugins.custom_op( + "torchtrt_ex::elementwise_scale_mul", supports_dynamic_shapes=True +) + + +class TestAutomaticPlugin(DispatchTestCase): + @parameterized.expand( + [ + ((64, 64), torch.float), + ((256, 256), torch.int), + ] + ) + def test_scale_mul_plugin_float(self, input_shape, dtype): + class elementwise_scale_mul(nn.Module): + def forward(self, lhs, rhs): + return torch.ops.torchtrt_ex.elementwise_scale_mul.default( + lhs, rhs, b=1, a=0 + ) + + inputs = [ + torch.randint(0, 5, input_shape, device="cuda", dtype=dtype), + torch.randint(0, 5, input_shape, device="cuda", dtype=dtype), + ] + + self.run_test(elementwise_scale_mul(), inputs) + + +if __name__ == "__main__": + run_tests()
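
# Example Usage (illustrative only, mirroring the commented example at the end of
# test_automatic_plugin.py; the inputs and attribute values below are assumptions)
# A = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
# B = torch.full((64, 64), 3, device="cuda", dtype=torch.float)

# C = torch.ops.torchtrt_ex.elementwise_scale_mul.default(A, B, b=0.5, a=2)

# print("C (Scaled Multiplication):", C)  # every element is 2 * 3 * 2 + 0.5 = 12.5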