diff --git a/examples/distributed_inference/tensor_parallel_llama3.py b/examples/distributed_inference/tensor_parallel_llama3.py
index 998c378be2..5db6691599 100644
--- a/examples/distributed_inference/tensor_parallel_llama3.py
+++ b/examples/distributed_inference/tensor_parallel_llama3.py
@@ -18,7 +18,7 @@
     "./tensor_parallel_llama3"
 )
 # Import should be after initialization of the TRT-LLM plugin .so path
-import tensorrt_llm
+import torch_tensorrt
 
 logger.info(f"Starting PyTorch TP example on rank {_rank}.")
 assert (
diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py
index 837648fdb4..9fe1a33bc5 100755
--- a/examples/distributed_inference/tensor_parallel_simple_example.py
+++ b/examples/distributed_inference/tensor_parallel_simple_example.py
@@ -15,7 +15,6 @@
 device_mesh, _world_size, _rank, logger = initialize_distributed_env(
     "./tensor_parallel_simple_example"
 )
-import tensorrt_llm
 
 """
 This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index ef04745562..2e8144275f 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -59,17 +59,25 @@ def aot_torch_tensorrt_aten_backend(
         _pretraced_backend, settings=settings, engine_cache=engine_cache
     )
     settings_aot_autograd = {}
-    settings_aot_autograd["decompostions"] = get_decompositions(
+    settings_aot_autograd["decompositions"] = get_decompositions(
         settings.enable_experimental_decompositions
     )
     # This is added since detach lowering leads to alias nodes
     # Error - View operation returned a tensor that is the same as the input base tensor
     # torch nop_decompositions in torch/_decomp/decompositions.py
-    if aten.detach in settings_aot_autograd["decompositions"]:
-        del settings_aot_autograd["decompositions"][aten.detach]
+    # transpose key deleted since not desirable to lower it to permute
+    to_delete = {
+        key
+        for key in settings_aot_autograd["decompositions"]
+        if "transpose" in key._name or "detach" in key._name
+    }
+
+    for key in to_delete:
+        del settings_aot_autograd["decompositions"][key]
+
     return aot_autograd(
         fw_compiler=_pretraced_backend_autograd,
-        decompositions=get_decompositions(settings.enable_experimental_decompositions),
+        decompositions=settings_aot_autograd["decompositions"],
     )(gm, sample_inputs)
diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
index 17850fabce..79611c7552 100644
--- a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
@@ -3,6 +3,7 @@
 import logging
 from typing import Dict, Sequence, Tuple, Union
 
+import tensorrt as trt
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -16,8 +17,6 @@
     tensorrt_fused_nccl_reduce_scatter_op,
 )
 
-import tensorrt as trt
-
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 if load_tensorrt_llm():
@@ -30,7 +29,7 @@ def fused_nccl_gather(
         kwargs: Dict[str, Argument],
         name: str,
     ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-        return impl.distributed.nccl_gather(
+        return impl.nccl_ops.nccl_gather(
             ctx,
             target,
             SourceIR.ATEN,
@@ -46,7 +45,7 @@ def fused_nccl_reduce_scatter(
         kwargs: Dict[str, Argument],
         name: str,
     ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-        return impl.distributed.nccl_reduce_scatter(
+        return impl.nccl_ops.nccl_reduce_scatter(
             ctx,
             target,
             SourceIR.ATEN,
@@ -54,7 +53,6 @@ def fused_nccl_reduce_scatter(
             [args[0]],
         )
 
-        breakpoint()
 else:
     _LOGGER.debug(
         "Did not load torch.distributed converters since TensorRT-LLM is not available"
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py
index 013268f803..c28c5bcc7d 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py
@@ -3,12 +3,11 @@
 from typing import Optional, Tuple, Union
 
 import numpy as np
+import tensorrt as trt
 from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name
 
-import tensorrt as trt
-
 
 # class for AllReduce
 class AllReduceStrategy(IntEnum):
@@ -94,7 +93,7 @@ def nccl_reduce_scatter(
         "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
     )
 
-    p_dtype = trt.float16
+    p_dtype = trt.float32
     pf_dtype = trt.PluginField(
         "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
     )
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py b/py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py
index e2edec3d28..c9512165a3 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py
@@ -106,7 +106,12 @@ def update_node_meta(node: torch.fx.Node, fake_mode: FakeTensorMode) -> None:
 
     if op_target in shape_inference_funcs:
         new_shape = shape_inference_funcs[op_target](node)
-        real_tensor = torch.empty(new_shape, dtype=node.meta["val"].dtype)
+        new_node_dtype = None
+        if node.meta["val"].dtype == torch.complex64:
+            new_node_dtype = torch.float32
+        else:
+            new_node_dtype = torch.float64
+        real_tensor = torch.empty(new_shape, dtype=new_node_dtype)
         node.meta["val"] = fake_mode.from_tensor(real_tensor)
     else:
         print("No shape for the inference function", {op_name})
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py
index f709f177d6..02cb2ccd56 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py
@@ -49,7 +49,6 @@ def fuse_distributed_ops(
             == torch.ops._c10d_functional.wait_tensor.default
         ):
             wait_tensor_node = list(node.users)[0]
-            fused_op = None
             if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default:
                 with gm.graph.inserting_after(wait_tensor_node):
                     fused_node = gm.graph.create_node(
@@ -58,11 +57,12 @@ def fuse_distributed_ops(
                         args=(node.args[0], node.args[1], node.args[2]),
                     )
             else:
-                fused_node = gm.graph.create_node(
-                    op="call_function",
-                    target=tensorrt_fused_nccl_reduce_scatter_op,  # Define your custom fused function
-                    args=(node.args[0], node.args[1], node.args[2], node.args[3]),
-                )
+                with gm.graph.inserting_after(wait_tensor_node):
+                    fused_node = gm.graph.create_node(
+                        op="call_function",
+                        target=tensorrt_fused_nccl_reduce_scatter_op,  # Define your custom fused function
+                        args=(node.args[0], node.args[1], node.args[2], node.args[3]),
+                    )
 
             wait_tensor_node.replace_all_uses_with(fused_node)
             fused_node.meta.update(node.meta)
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index 9086de657f..5a2d4b3a1f 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -364,6 +364,15 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
             for i in inputs
         ]
+
+        for i, contiguous_input in enumerate(contiguous_inputs):
+            if contiguous_input.dtype == torch.complex64:
+                contiguous_input_real = contiguous_input.real
+                contiguous_input_imag = contiguous_input.imag
+                contiguous_inputs[i] = torch.stack(
+                    (contiguous_input_real, contiguous_input_imag), dim=-1
+                )
+
         with (
             torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
             if self.profiling_enabled
diff --git a/tests/py/dynamo/distributed/distributed_utils.py b/tests/py/dynamo/distributed/distributed_utils.py
new file mode 100644
index 0000000000..e3062249fa
--- /dev/null
+++ b/tests/py/dynamo/distributed/distributed_utils.py
@@ -0,0 +1,60 @@
+import logging
+import os
+
+import numpy as np
+import tensorrt as trt
+import torch
+import torch.distributed as dist
+from torch.distributed._tensor.device_mesh import init_device_mesh
+
+
+def set_environment_variables_pytest():
+    os.environ["WORLD_SIZE"] = str(1)
+    os.environ["RANK"] = str(0)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(29500)
+    os.environ["USE_TRTLLM_PLUGINS"] = "1"
+
+
+def initialize_logger(rank, logger_file_name):
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
+    fh.setLevel(logging.INFO)
+    logger.addHandler(fh)
+    return logger
+
+
+# This is required for env initialization since we use mpirun
+def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
+    local_rank = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
+    )
+    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
+
+    # Set up environment variable to run with mpirun
+    os.environ["RANK"] = str(local_rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(port)
+    os.environ["TRTLLM_PLUGINS_PATH"] = "./tmp/lib/libnvinfer_plugin_tensorrt_llm.so"
+
+    # Necessary to assign a device to each rank.
+    torch.cuda.set_device(local_rank)
+
+    # We use nccl backend
+    dist.init_process_group("nccl")
+
+    # set a manual seed for reproducibility
+    torch.manual_seed(1111)
+
+    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
+    rank = device_mesh.get_rank()
+    assert rank == local_rank
+    logger = initialize_logger(rank, logger_file_name)
+    device_id = (
+        rank % torch.cuda.device_count()
+    )  # Ensure each rank gets a unique device
+    torch.cuda.set_device(device_id)
+
+    return device_mesh, world_size, rank, logger
diff --git a/tests/py/dynamo/distributed/test_distributed_simple_example.py b/tests/py/dynamo/distributed/test_distributed_simple_example.py
new file mode 100644
index 0000000000..845655b000
--- /dev/null
+++ b/tests/py/dynamo/distributed/test_distributed_simple_example.py
@@ -0,0 +1,92 @@
+import time
+
+import tensorrt as trt
+import torch
+import torch.nn as nn
+import torch_tensorrt
+from distributed_utils import initialize_distributed_env
+from torch.distributed._tensor import Shard
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    RowwiseParallel,
+    parallelize_module,
+)
+
+device_mesh, _world_size, _rank, logger = initialize_distributed_env(
+    "./tensor_parallel_simple_example"
+)
+
+"""
+This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
+"""
+
+
+class ToyModel(nn.Module):
+    """MLP based model"""
+
+    def __init__(self):
+        super(ToyModel, self).__init__()
+        self.in_proj = nn.Linear(10, 3200)
+        self.relu = nn.ReLU()
+        self.out_proj = nn.Linear(3200, 1600)
+        self.in_proj2 = nn.Linear(1600, 500)
+        self.out_proj2 = nn.Linear(500, 100)
+
+    def forward(self, x):
+        x = self.out_proj(self.relu(self.in_proj(x)))
+        x = self.relu(x)
+        x = self.out_proj2(self.relu(self.in_proj2(x)))
+        return x
+
+
+logger.info(f"Starting PyTorch TP example on rank {_rank}.")
+
+# Create the model and move it to GPU - init_device_mesh has already mapped GPU ids.
+tp_model = ToyModel().to("cuda")
+
+
+# Custom parallelization plan for the model
+tp_model = parallelize_module(
+    module=tp_model,
+    device_mesh=device_mesh,
+    parallelize_plan={
+        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
+        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
+        "in_proj2": ColwiseParallel(input_layouts=Shard(0)),
+        "out_proj2": RowwiseParallel(output_layouts=Shard(0)),
+    },
+)
+torch.manual_seed(0)
+inp = torch.rand(20, 10, device="cuda")
+python_result = tp_model(inp)
+
+
+backend = "torch_tensorrt"
+tp_model = torch.compile(
+    tp_model,
+    backend=backend,
+    options={
+        "truncate_long_and_double": True,
+        "enabled_precisions": {torch.float32, torch.float16},
+        "use_python_runtime": True,
+        "min_block_size": 1,
+        "use_aot_joint_export": False,
+    },
+    dynamic=False,
+)
+
+for i in range(10):
+    # For TP, input needs to be same across all TP ranks.
+    # Setting the random seed is to mimic the behavior of dataloader.
+    torch.manual_seed(i)
+    inp = torch.rand(20, 10, device="cuda")
+    start = time.time()
+    output = tp_model(inp)
+    end = time.time()
+    if i == 0:
+        logger.info(f"Compilation time is {end-start}")
+        assert (
+            python_result - output
+        ).std() < 0.01, "Compilation result is not correct."
+    elif _rank == 0:
+        logger.info(f"Inference time is {end-start}")
diff --git a/tests/py/dynamo/distributed/test_nccl_ops.py b/tests/py/dynamo/distributed/test_nccl_ops.py
new file mode 100644
index 0000000000..964aae481d
--- /dev/null
+++ b/tests/py/dynamo/distributed/test_nccl_ops.py
@@ -0,0 +1,76 @@
+import os
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from distributed_utils import set_environment_variables_pytest
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+set_environment_variables_pytest()
+dist.init_process_group(backend="nccl", init_method="env://")
+group = dist.new_group(ranks=[0])
+group_name = group.group_name
+world_size = 1
+
+from conversion.harness import DispatchTestCase
+
+
+class TestGatherNcclOpsConverter(DispatchTestCase):
+    @parameterized.expand([(8)])
+    def test_nccl_ops(self, linear_layer_dim):
+        class DistributedGatherModel(nn.Module):
+            def __init__(self, input_dim):
+                super().__init__()
+                self.fc = torch.nn.Linear(input_dim, input_dim)
+
+            def forward(self, x):
+                x = self.fc(x)
+                gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor(
+                    x, world_size, group_name
+                )
+                gathered_tensor = torch.ops._c10d_functional.wait_tensor(
+                    gathered_tensor
+                )
+                return gathered_tensor
+
+        inputs = [torch.randn(1, linear_layer_dim).to("cuda")]
+        self.run_test(
+            DistributedGatherModel(linear_layer_dim).cuda(),
+            inputs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    @parameterized.expand([(8)])
+    def test_nccl_ops_scatter(self, linear_layer_dim):
+
+        class DistributedReduceScatterModel(nn.Module):
+            def __init__(self, input_dim):
+                super().__init__()
+                self.fc = torch.nn.Linear(input_dim, input_dim)
+
+            def forward(self, x):
+                x = self.fc(x)
+                scatter_reduce_tensor = (
+                    torch.ops._c10d_functional.reduce_scatter_tensor(
+                        x, "sum", world_size, group_name
+                    )
+                )
+                scatter_reduce_tensor = torch.ops._c10d_functional.wait_tensor(
+                    scatter_reduce_tensor
+                )
+                return scatter_reduce_tensor
+
+        inputs = [torch.zeros(1, linear_layer_dim).to("cuda")]
+
+        self.run_test(
+            DistributedReduceScatterModel(linear_layer_dim).cuda(),
+            inputs,
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/py/dynamo/distributed/test_nccl_ops.sh b/tests/py/dynamo/distributed/test_nccl_ops.sh
new file mode 100644
index 0000000000..5a46a85149
--- /dev/null
+++ b/tests/py/dynamo/distributed/test_nccl_ops.sh
@@ -0,0 +1,132 @@
+#!/bin/bash
+
+check_command() {
+    command -v "$1" >/dev/null 2>&1
+}
+
+ensure_installed() {
+    local pkg="$1"
+    if ! check_command "$pkg"; then
+        echo "$pkg is not installed. Installing $pkg..."
+
+        # Determine if sudo is needed
+        if check_command sudo; then
+            SUDO="sudo"
+        else
+            SUDO=""
+        fi
+
+        # Detect OS and install accordingly
+        OS="$(uname -s)"
+        if [[ "$OS" == "Linux" ]]; then
+            if check_command apt-get; then
+                $SUDO apt-get update && $SUDO apt-get install -y "$pkg"
+            fi
+        else
+            echo "Unsupported OS: $OS. Please install $pkg manually."
+            exit 1
+        fi
+    else
+        echo "$pkg is already installed."
+    fi
+}
+
+ensure_mpi_installed() {
+    local pkg="$1"
+    if dpkg -l | grep -q "$pkg"; then
+        echo "$pkg is already installed."
+    else
+        echo "$pkg is not installed. Installing $pkg..."
+
+        # Determine if sudo is needed
+        if check_command sudo; then
+            SUDO="sudo"
+        else
+            SUDO=""
+        fi
+
+        # Detect OS and install accordingly
+        OS="$(uname -s)"
+        if [[ "$OS" == "Linux" ]]; then
+            if check_command apt-get; then
+                $SUDO apt-get update && $SUDO apt-get install -y "$pkg"
+            fi
+        else
+            echo "Unsupported OS: $OS. Please install $pkg manually."
+            exit 1
+        fi
+    fi
+}
+
+ensure_pytest_installed() {
+    if check_command pip; then
+        echo "pip is installed, installing pytest..."
+        pip install pytest
+    else
+        echo "pip is not installed. Please install pip first."
+        exit 1
+    fi
+}
+
+echo "Setting up the environment"
+
+OS="$(uname -s)"
+ARCH="$(uname -m)"
+
+
+# Get the wheel file name for the TensorRT-LLM download
+if [[ "$OS" == "Linux" && "$ARCH" == "x86_64" ]]; then
+    FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_x86_64.whl"
+elif [[ "$OS" == "Linux" && "$ARCH" == "aarch64" ]]; then
+    FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_aarch64.whl"
+else
+    echo "Unsupported platform: OS=$OS ARCH=$ARCH"
+    exit 1
+fi
+
+# Download the selected file
+URL="https://pypi.nvidia.com/tensorrt-llm/$FILE"
+echo "Downloading $FILE from $URL..."
+
+# Install wget if needed
+ensure_installed wget
+# Download the package
+wget "$URL"
+echo "Download complete: $FILE"
+
+UNZIP_DIR="tensorrt_llm_unzip"
+if [[ ! -d "$UNZIP_DIR" ]]; then
+    echo "Creating directory: $UNZIP_DIR"
+    mkdir -p "$UNZIP_DIR"
+    echo "Extracting $FILE to $UNZIP_DIR ..."
+    # Install unzip if needed
+    ensure_installed unzip
+    # Unzip the TensorRT-LLM package
+    unzip -q "$FILE" -d "$UNZIP_DIR"
+    echo "Unzip complete"
+fi
+
+
+export TRTLLM_PLUGINS_PATH="$(pwd)/${UNZIP_DIR}/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
+echo ${TRTLLM_PLUGINS_PATH}
+
+ensure_mpi_installed libmpich-dev
+ensure_mpi_installed libopenmpi-dev
+
+run_tests() {
+    cd ..
+    export PYTHONPATH=$(pwd)
+    echo "Running pytest on distributed/test_nccl_ops.py..."
+    pytest distributed/test_nccl_ops.py
+}
+
+run_mpi_tests() {
+    cd distributed
+    echo "Running test_distributed_simple_example with mpirun..."
+    mpirun -n 1 --allow-run-as-root python test_distributed_simple_example.py
+}
+
+ensure_pytest_installed
+run_tests
+run_mpi_tests
\ No newline at end of file