Nccl ops correction changes #3387

Open
wants to merge 7 commits into base: main
2 changes: 1 addition & 1 deletion examples/distributed_inference/tensor_parallel_llama3.py
@@ -18,7 +18,7 @@
"./tensor_parallel_llama3"
)
# Import should be after initialization of the TRT-LLM plugin .so path
-import tensorrt_llm
+import torch_tensorrt

logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
examples/distributed_inference/tensor_parallel_simple_example.py
@@ -15,7 +15,6 @@
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_simple_example"
)
-import tensorrt_llm

"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
16 changes: 12 additions & 4 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -59,17 +59,25 @@ def aot_torch_tensorrt_aten_backend(
_pretraced_backend, settings=settings, engine_cache=engine_cache
)
settings_aot_autograd = {}
settings_aot_autograd["decompostions"] = get_decompositions(
settings_aot_autograd["decompositions"] = get_decompositions(
settings.enable_experimental_decompositions
)
# This is added since detach lowering leads to alias nodes
# Error - View operation returned a tensor that is the same as the input base tensor
# torch nop_decompositions in torch/_decomp/decompositions.py
-if aten.detach in settings_aot_autograd["decompositions"]:
-    del settings_aot_autograd["decompositions"][aten.detach]
+# transpose key deleted since not desirable to lower it to permute
+to_delete = {
Collaborator:
Will this apply to all cases not just NCCL?

Collaborator (Author):
You mean in the non-distributed example? I am not sure about that. I added this for the llama3 example since I was seeing issues in the model lowering: it was generating graph breaks at the wrong places, leading to a complex-input error. It can be added to all cases if we want to avoid lowering transpose to permute.

Collaborator (Author), @apbose, Mar 3, 2025:

Regarding the discussion:

  1. detach: remove_detach does not help, since the graph does not explicitly have detach ops whose nodes could be removed. Instead it hits the decomposition in https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L2153. This might be due to the %hook_result_3 = call_function[target=torch._dynamo.variables.tensor.prim_to_local](args = (%outputs_3,), kwargs = {}) node, where the DTensor is moved to a local tensor and the distributed operation needs to be detached. This is in tensor_parallel_simple_example.py.

  2. transpose: transpose is more about tackling tensor_parallel_llama3.py. The broad modifications I made to handle the complex numbers are:
    a. Modifying the placeholder node shape and type
    b. Modifying the inputs to the reshape and slice ops with complex inputs
    c. Replacing the complex TensorRT mul
    I see that if I decompose transpose to permute, the graph in gpu_0 has the output of the complex tensor mul as complex64 or complex128, which goes as input to the acc_* graph, causing a complex-input error. Keeping transpose in the graph confines the complex-input handling to the gpu_0 graph partition.

Regarding the question of whether removing transpose from the decompositions would affect the result: I would think not, since this is not the removal of an op like detach; we simply do not lower it to permute. But you could give me more insight if that is not the case.

  3. Also, if we want this to be model-specific and not apply to all models, I think the next step could be to expose it either in the UI or, as we do for the non-distributed models, in the torch_disabled_decomposition dictionary, which applies to all models. I mention the UI since we are talking about model-specific disabled decompositions. As of now, since this part of the code applies only to distributed models, it should be good to go.

Collaborator:
Ok, so the code needs to be restructured to make it clear that it is not the main codepath.

settings, engine_cache = parse_dynamo_kwargs(kwargs)
    if settings.use_aot_joint_export:
        return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
    logger.debug("Wrapping the backend with aot_autograd\n")
    _pretraced_backend_autograd = functools.partial(
        _pretraced_backend, settings=settings, engine_cache=engine_cache
    )
    settings_aot_autograd = {}
    settings_aot_autograd["decompositions"] = get_decompositions(
        settings.enable_experimental_decompositions
    )
    # This is added since detach lowering leads to alias nodes
    # Error - View operation returned a tensor that is the same as the input base tensor
    # torch nop_decompositions in torch/_decomp/decompositions.py
    # transpose key deleted since not desirable to lower it to permute
    to_delete = {...

It's not immediately obvious that most models will run through

if settings.use_aot_joint_export:
        return _pretraced_backend(gm, sample_inputs, settings, engine_cache)

And I am still not sure whether such a broad change should be made even in the MGMN case. How would a user (or we) know that this change is needed?

+    key
+    for key in settings_aot_autograd["decompositions"]
+    if "transpose" in key._name or "detach" in key._name
+}

+for key in to_delete:
+    del settings_aot_autograd["decompositions"][key]

return aot_autograd(
fw_compiler=_pretraced_backend_autograd,
-decompositions=get_decompositions(settings.enable_experimental_decompositions),
+decompositions=settings_aot_autograd["decompositions"],
)(gm, sample_inputs)
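Note: for reference, a minimal sketch of how a caller lands on this non-default aot_autograd branch. The option names follow the test added later in this PR (tests/py/dynamo/distributed/test_distributed_simple_example.py); the exact values are illustrative assumptions, not a prescribed configuration.

    import torch
    import torch.nn as nn
    import torch_tensorrt  # noqa: F401  (registers the "torch_tensorrt" backend)

    model = nn.Linear(8, 8).cuda()

    # use_aot_joint_export=False routes compilation through the aot_autograd
    # branch above; the default (True) returns early via _pretraced_backend.
    compiled = torch.compile(
        model,
        backend="torch_tensorrt",
        options={
            "use_python_runtime": True,
            "min_block_size": 1,
            "use_aot_joint_export": False,
        },
        dynamic=False,
    )
    out = compiled(torch.randn(2, 8, device="cuda"))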


8 changes: 3 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
@@ -3,6 +3,7 @@
import logging
from typing import Dict, Sequence, Tuple, Union

+import tensorrt as trt
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
@@ -16,8 +17,6 @@
tensorrt_fused_nccl_reduce_scatter_op,
)

-import tensorrt as trt

_LOGGER: logging.Logger = logging.getLogger(__name__)

if load_tensorrt_llm():
@@ -30,7 +29,7 @@ def fused_nccl_gather(
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-return impl.distributed.nccl_gather(
+return impl.nccl_ops.nccl_gather(
ctx,
target,
SourceIR.ATEN,
@@ -46,15 +45,14 @@ def fused_nccl_reduce_scatter(
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
-return impl.distributed.nccl_reduce_scatter(
+return impl.nccl_ops.nccl_reduce_scatter(
ctx,
target,
SourceIR.ATEN,
name,
[args[0]],
)

-breakpoint()
else:
_LOGGER.debug(
"Did not load torch.distributed converters since TensorRT-LLM is not available"
5 changes: 2 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py
@@ -3,12 +3,11 @@
from typing import Optional, Tuple, Union

import numpy as np
+import tensorrt as trt
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name

-import tensorrt as trt


# class for AllReduce
class AllReduceStrategy(IntEnum):
@@ -94,7 +93,7 @@ def nccl_reduce_scatter(
"group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
)

-p_dtype = trt.float16
+p_dtype = trt.float32
pf_dtype = trt.PluginField(
"type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
)
@@ -106,7 +106,12 @@ def update_node_meta(node: torch.fx.Node, fake_mode: FakeTensorMode) -> None:

if op_target in shape_inference_funcs:
new_shape = shape_inference_funcs[op_target](node)
-real_tensor = torch.empty(new_shape, dtype=node.meta["val"].dtype)
+new_node_dtype = None
+if node.meta["val"].dtype == torch.complex64:
+    new_node_dtype = torch.float32
+else:
+    new_node_dtype = torch.float64
+real_tensor = torch.empty(new_shape, dtype=new_node_dtype)
node.meta["val"] = fake_mode.from_tensor(real_tensor)
else:
print("No shape for the inference function", {op_name})
12 changes: 6 additions & 6 deletions py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py
@@ -49,7 +49,6 @@ def fuse_distributed_ops(
== torch.ops._c10d_functional.wait_tensor.default
):
wait_tensor_node = list(node.users)[0]
-fused_op = None
if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default:
with gm.graph.inserting_after(wait_tensor_node):
fused_node = gm.graph.create_node(
@@ -58,11 +57,12 @@
args=(node.args[0], node.args[1], node.args[2]),
)
else:
-fused_node = gm.graph.create_node(
-    op="call_function",
-    target=tensorrt_fused_nccl_reduce_scatter_op,  # Define your custom fused function
-    args=(node.args[0], node.args[1], node.args[2], node.args[3]),
-)
+with gm.graph.inserting_after(wait_tensor_node):
+    fused_node = gm.graph.create_node(
+        op="call_function",
+        target=tensorrt_fused_nccl_reduce_scatter_op,  # Define your custom fused function
+        args=(node.args[0], node.args[1], node.args[2], node.args[3]),
+    )

wait_tensor_node.replace_all_uses_with(fused_node)
fused_node.meta.update(node.meta)
@@ -364,6 +364,15 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
(i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
for i in inputs
]

+for i, contiguous_input in enumerate(contiguous_inputs):
Collaborator:
Does the C++ API not need these changes?

Collaborator (Author):
I am not clear on this aspect. Could you let me know what would be required here? I am running the distributed Python example and did not encounter a need for this.

Collaborator:
I mean try running the example using the C++ runtime; I'd expect that it doesn't handle complex numerics correctly.

+    if contiguous_input.dtype == torch.complex64:
+        contiguous_input_real = contiguous_input.real
+        contiguous_input_imag = contiguous_input.imag
+        contiguous_inputs[i] = torch.stack(
+            (contiguous_input_real, contiguous_input_imag), dim=-1
+        )

with (
torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
if self.profiling_enabled
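As a side note on the representation used in the forward() change above, a minimal sketch (plain PyTorch, not part of this diff) showing that stacking .real and .imag along the last dimension matches torch.view_as_real, and that the complex value can be recovered with torch.view_as_complex:

    import torch

    x = torch.randn(3, 4, dtype=torch.complex64)

    # What the change above does for complex64 inputs: a float32 tensor of
    # shape (3, 4, 2) holding (real, imag) pairs in the last dimension.
    stacked = torch.stack((x.real, x.imag), dim=-1)

    # Equivalent built-in view, and the inverse mapping back to complex.
    assert torch.equal(stacked, torch.view_as_real(x))
    recovered = torch.view_as_complex(stacked.contiguous())
    assert torch.equal(recovered, x)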
60 changes: 60 additions & 0 deletions tests/py/dynamo/distributed/distributed_utils.py
@@ -0,0 +1,60 @@
import logging
import os

import numpy as np
import tensorrt as trt
import torch
import torch.distributed as dist
from torch.distributed._tensor.device_mesh import init_device_mesh


def set_environment_variables_pytest():
os.environ["WORLD_SIZE"] = str(1)
os.environ["RANK"] = str(0)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(29500)
os.environ["USE_TRTLLM_PLUGINS"] = "1"


def initialize_logger(rank, logger_file_name):
logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)
return logger


# This is required for env initialization since we use mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))

# Set up environment variable to run with mpirun
os.environ["RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)
os.environ["TRTLLM_PLUGINS_PATH"] = "./tmp/lib/libnvinfer_plugin_tensorrt_llm.so"

# Necessary to assign a device to each rank.
torch.cuda.set_device(local_rank)

# We use nccl backend
dist.init_process_group("nccl")

# set a manual seed for reproducibility
torch.manual_seed(1111)

device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
rank = device_mesh.get_rank()
assert rank == local_rank
logger = initialize_logger(rank, logger_file_name)
device_id = (
rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

return device_mesh, world_size, rank, logger
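For orientation, a short sketch of how this helper is consumed by the tests in this directory (mirroring test_distributed_simple_example.py below); the mpirun or single-process launch itself is assumed to happen outside the snippet.

    # Per-process setup: mpirun (or a single process) supplies the OMPI_* variables;
    # the helper fills in the rest, initializes the NCCL process group, and returns
    # a device mesh plus a per-rank file logger.
    from distributed_utils import initialize_distributed_env

    device_mesh, world_size, rank, logger = initialize_distributed_env(
        "./tensor_parallel_simple_example"
    )
    logger.info(f"rank {rank} of {world_size} initialized on mesh {device_mesh}")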
92 changes: 92 additions & 0 deletions tests/py/dynamo/distributed/test_distributed_simple_example.py
@@ -0,0 +1,92 @@
import time

import tensorrt as trt
import torch
import torch.nn as nn
import torch_tensorrt
from distributed_utils import initialize_distributed_env
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_simple_example"
)

"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""


class ToyModel(nn.Module):
"""MLP based model"""

def __init__(self):
super(ToyModel, self).__init__()
self.in_proj = nn.Linear(10, 3200)
self.relu = nn.ReLU()
self.out_proj = nn.Linear(3200, 1600)
self.in_proj2 = nn.Linear(1600, 500)
self.out_proj2 = nn.Linear(500, 100)

def forward(self, x):
x = self.out_proj(self.relu(self.in_proj(x)))
x = self.relu(x)
x = self.out_proj2(self.relu(self.in_proj2(x)))
return x


logger.info(f"Starting PyTorch TP example on rank {_rank}.")

# # create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids.
tp_model = ToyModel().to("cuda")


# Custom parallelization plan for the model
tp_model = parallelize_module(
module=tp_model,
device_mesh=device_mesh,
parallelize_plan={
"in_proj": ColwiseParallel(input_layouts=Shard(0)),
"out_proj": RowwiseParallel(output_layouts=Shard(0)),
"in_proj2": ColwiseParallel(input_layouts=Shard(0)),
"out_proj2": RowwiseParallel(output_layouts=Shard(0)),
},
)
torch.manual_seed(0)
inp = torch.rand(20, 10, device="cuda")
python_result = tp_model(inp)


backend = "torch_tensorrt"
tp_model = torch.compile(
tp_model,
backend=backend,
options={
"truncate_long_and_double": True,
"enabled_precisions": {torch.float32, torch.float16},
"use_python_runtime": True,
"min_block_size": 1,
"use_aot_joint_export": False,
},
dynamic=False,
)

for i in range(10):
# For TP, input needs to be same across all TP ranks.
# Setting the random seed is to mimic the behavior of dataloader.
torch.manual_seed(i)
inp = torch.rand(20, 10, device="cuda")
start = time.time()
output = tp_model(inp)
end = time.time()
if i == 0:
logger.info(f"Compilation time is {end-start}")
assert (
python_result - output
).std() < 0.01, "Compilation result is not correct."
elif _rank == 0:
logger.info(f"Inference time is {end-start}")
76 changes: 76 additions & 0 deletions tests/py/dynamo/distributed/test_nccl_ops.py
@@ -0,0 +1,76 @@
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from distributed_utils import set_environment_variables_pytest
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

set_environment_variables_pytest()
dist.init_process_group(backend="nccl", init_method="env://")
group = dist.new_group(ranks=[0])
group_name = group.group_name
world_size = 1

from conversion.harness import DispatchTestCase


class TestGatherNcclOpsConverter(DispatchTestCase):
@parameterized.expand([(8)])
def test_nccl_ops(self, linear_layer_dim):
class DistributedGatherModel(nn.Module):
def __init__(self, input_dim):
super().__init__()
self.fc = torch.nn.Linear(input_dim, input_dim)

def forward(self, x):
x = self.fc(x)
gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor(
x, world_size, group_name
)
gathered_tensor = torch.ops._c10d_functional.wait_tensor(
gathered_tensor
)
return gathered_tensor

inputs = [torch.randn(1, linear_layer_dim).to("cuda")]
self.run_test(
DistributedGatherModel(linear_layer_dim).cuda(),
inputs,
use_dynamo_tracer=True,
enable_passes=True,
)

@parameterized.expand([(8)])
def test_nccl_ops_scatter(self, linear_layer_dim):

class DistributedReduceScatterModel(nn.Module):
def __init__(self, input_dim):
super().__init__()
self.fc = torch.nn.Linear(input_dim, input_dim)

def forward(self, x):
x = self.fc(x)
scatter_reduce_tensor = (
torch.ops._c10d_functional.reduce_scatter_tensor(
x, "sum", world_size, group_name
)
)
scatter_reduce_tensor = torch.ops._c10d_functional.wait_tensor(
scatter_reduce_tensor
)
return scatter_reduce_tensor

inputs = [torch.zeros(1, linear_layer_dim).to("cuda")]

self.run_test(
DistributedReduceScatterModel(linear_layer_dim).cuda(),
inputs,
use_dynamo_tracer=True,
enable_passes=True,
)


if __name__ == "__main__":
run_tests()