Nccl ops correction changes #3387
base: main
Changes from all commits: 321f6de, f79560b, 82298c5, 66861f4, d2aaa44, 9cf0e21, 4b79bfb
@@ -364,6 +364,15 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:

```python
    (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
    for i in inputs
]

for i, contiguous_input in enumerate(contiguous_inputs):
    if contiguous_input.dtype == torch.complex64:
        contiguous_input_real = contiguous_input.real
        contiguous_input_imag = contiguous_input.imag
        contiguous_inputs[i] = torch.stack(
            (contiguous_input_real, contiguous_input_imag), dim=-1
        )

with (
    torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
    if self.profiling_enabled
```

Review comment: Does the C++ API not need these changes?

Reply: I am not clear on this aspect. Could you please let me know what would be required as part of this? I am running the distributed Python example and did not encounter the requirement for this.

Reply: I mean, try running the example using the C++ runtime; I'd expect that it doesn't handle complex numerics correctly.
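The handling above packs each complex64 input into a real tensor with a trailing dimension of size 2. A minimal sketch of why this round-trips losslessly, using a made-up tensor shape that is not part of the PR:

```python
import torch

# Hypothetical complex input, analogous to what forward() receives.
x = torch.randn(4, 8, dtype=torch.complex64)

# Same transformation as in the diff: stack real/imag along a new last dim.
packed = torch.stack((x.real, x.imag), dim=-1)  # shape (4, 8, 2), dtype float32
assert packed.dtype == torch.float32

# The packing is reversible, so no information is lost before TRT execution.
assert torch.equal(torch.view_as_complex(packed), x)
```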
@@ -0,0 +1,60 @@

```python
import logging
import os

import numpy as np
import tensorrt as trt
import torch
import torch.distributed as dist
from torch.distributed._tensor.device_mesh import init_device_mesh


def set_environment_variables_pytest():
    os.environ["WORLD_SIZE"] = str(1)
    os.environ["RANK"] = str(0)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(29500)
    os.environ["USE_TRTLLM_PLUGINS"] = "1"


def initialize_logger(rank, logger_file_name):
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)
    return logger


# This is required for env initialization since we use mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
    local_rank = int(
        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
    )
    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))

    # Set up environment variables to run with mpirun
    os.environ["RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)
    os.environ["TRTLLM_PLUGINS_PATH"] = "./tmp/lib/libnvinfer_plugin_tensorrt_llm.so"

    # Necessary to assign a device to each rank.
    torch.cuda.set_device(local_rank)

    # We use the nccl backend
    dist.init_process_group("nccl")

    # Set a manual seed for reproducibility
    torch.manual_seed(1111)

    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
    rank = device_mesh.get_rank()
    assert rank == local_rank
    logger = initialize_logger(rank, logger_file_name)
    device_id = (
        rank % torch.cuda.device_count()
    )  # Ensure each rank gets a unique device
    torch.cuda.set_device(device_id)

    return device_mesh, world_size, rank, logger
```
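These helpers live in the new `distributed_utils` module that the example and test files below import. A minimal usage sketch, assuming an `mpirun` launch; the script name and process count are illustrative and not part of the PR:

```python
# example_tp_script.py (hypothetical file name)
from distributed_utils import initialize_distributed_env

# Sets RANK/WORLD_SIZE from the OMPI_* variables, creates the NCCL process
# group and device mesh, and opens a per-rank log file.
device_mesh, world_size, rank, logger = initialize_distributed_env("./example_tp_script")
logger.info(f"Initialized rank {rank} of {world_size}")
```

This would be launched with something like `mpirun -n 2 python example_tp_script.py`, since `initialize_distributed_env` reads `OMPI_COMM_WORLD_LOCAL_RANK` and `OMPI_COMM_WORLD_SIZE`.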
@@ -0,0 +1,92 @@

```python
import time

import tensorrt as trt
import torch
import torch.nn as nn
import torch_tensorrt
from distributed_utils import initialize_distributed_env
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
    "./tensor_parallel_simple_example"
)

"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""


class ToyModel(nn.Module):
    """MLP based model"""

    def __init__(self):
        super(ToyModel, self).__init__()
        self.in_proj = nn.Linear(10, 3200)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(3200, 1600)
        self.in_proj2 = nn.Linear(1600, 500)
        self.out_proj2 = nn.Linear(500, 100)

    def forward(self, x):
        x = self.out_proj(self.relu(self.in_proj(x)))
        x = self.relu(x)
        x = self.out_proj2(self.relu(self.in_proj2(x)))
        return x


logger.info(f"Starting PyTorch TP example on rank {_rank}.")

# Create model and move it to GPU - init_device_mesh has already mapped GPU ids.
tp_model = ToyModel().to("cuda")


# Custom parallelization plan for the model
tp_model = parallelize_module(
    module=tp_model,
    device_mesh=device_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
        "in_proj2": ColwiseParallel(input_layouts=Shard(0)),
        "out_proj2": RowwiseParallel(output_layouts=Shard(0)),
    },
)
torch.manual_seed(0)
inp = torch.rand(20, 10, device="cuda")
python_result = tp_model(inp)


backend = "torch_tensorrt"
tp_model = torch.compile(
    tp_model,
    backend=backend,
    options={
        "truncate_long_and_double": True,
        "enabled_precisions": {torch.float32, torch.float16},
        "use_python_runtime": True,
        "min_block_size": 1,
        "use_aot_joint_export": False,
    },
    dynamic=False,
)

for i in range(10):
    # For TP, the input needs to be the same across all TP ranks.
    # Setting the random seed mimics the behavior of a dataloader.
    torch.manual_seed(i)
    inp = torch.rand(20, 10, device="cuda")
    start = time.time()
    output = tp_model(inp)
    end = time.time()
    if i == 0:
        logger.info(f"Compilation time is {end-start}")
        assert (
            python_result - output
        ).std() < 0.01, "Compilation result is not correct."
    elif _rank == 0:
        logger.info(f"Inference time is {end-start}")
```
@@ -0,0 +1,76 @@

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from distributed_utils import set_environment_variables_pytest
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

set_environment_variables_pytest()
dist.init_process_group(backend="nccl", init_method="env://")
group = dist.new_group(ranks=[0])
group_name = group.group_name
world_size = 1

from conversion.harness import DispatchTestCase


class TestGatherNcclOpsConverter(DispatchTestCase):
    @parameterized.expand([(8)])
    def test_nccl_ops(self, linear_layer_dim):
        class DistributedGatherModel(nn.Module):
            def __init__(self, input_dim):
                super().__init__()
                self.fc = torch.nn.Linear(input_dim, input_dim)

            def forward(self, x):
                x = self.fc(x)
                gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor(
                    x, world_size, group_name
                )
                gathered_tensor = torch.ops._c10d_functional.wait_tensor(
                    gathered_tensor
                )
                return gathered_tensor

        inputs = [torch.randn(1, linear_layer_dim).to("cuda")]
        self.run_test(
            DistributedGatherModel(linear_layer_dim).cuda(),
            inputs,
            use_dynamo_tracer=True,
            enable_passes=True,
        )

    @parameterized.expand([(8)])
    def test_nccl_ops_scatter(self, linear_layer_dim):

        class DistributedReduceScatterModel(nn.Module):
            def __init__(self, input_dim):
                super().__init__()
                self.fc = torch.nn.Linear(input_dim, input_dim)

            def forward(self, x):
                x = self.fc(x)
                scatter_reduce_tensor = (
                    torch.ops._c10d_functional.reduce_scatter_tensor(
                        x, "sum", world_size, group_name
                    )
                )
                scatter_reduce_tensor = torch.ops._c10d_functional.wait_tensor(
                    scatter_reduce_tensor
                )
                return scatter_reduce_tensor

        inputs = [torch.zeros(1, linear_layer_dim).to("cuda")]

        self.run_test(
            DistributedReduceScatterModel(linear_layer_dim).cuda(),
            inputs,
            use_dynamo_tracer=True,
            enable_passes=True,
        )


if __name__ == "__main__":
    run_tests()
```
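The tests exercise the functional collectives through the converter harness. For reference, their eager-mode contracts can be checked against the single-rank group created at import time: `all_gather_into_tensor` concatenates along dim 0 and `reduce_scatter_tensor` splits along dim 0. An illustrative sketch, not part of the test file, reusing the `world_size` and `group_name` defined above:

```python
x = torch.randn(1, 8, device="cuda")

gathered = torch.ops._c10d_functional.all_gather_into_tensor(x, world_size, group_name)
gathered = torch.ops._c10d_functional.wait_tensor(gathered)
assert gathered.shape == (world_size * x.shape[0], x.shape[1])  # (1, 8) when world_size == 1

scattered = torch.ops._c10d_functional.reduce_scatter_tensor(x, "sum", world_size, group_name)
scattered = torch.ops._c10d_functional.wait_tensor(scattered)
assert scattered.shape == (x.shape[0] // world_size, x.shape[1])
```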
Review comment: Will this apply to all cases, not just NCCL?

Reply: You mean in the non-distributed example? I am not sure about that answer. I added this for the llama3 example since I was seeing issues in the model lowering: it was generating graph breaks at the wrong points, leading to a complex-input error. It can be added to all cases if we want to not lower transpose to permute.
Reply: Regarding the discussion:

detach: `remove_detach` does not help, since the graph explicitly does not have detach ops for it to remove. Instead it hits the decomposition in https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L2153. This might be due to `%hook_result_3 = call_function[target=torch._dynamo.variables.tensor.prim_to_local](args = (%outputs_3,), kwargs = {})`, where the DTensor is moved to a local tensor and the distributed operation needs to be detached. This is in `tensor_parallel_simple_example.py`.

transpose: `transpose` is more about tackling `tensor_parallel_llama3.py`. The broad modifications I made to handle complex numbers are:
a. Modifying the placeholder node shape and type
b. Modifying the inputs to the reshape and slice ops that take complex inputs
c. Replacing the complex TensorRT mul

I see that if I decompose transpose to permute, the graph in gpu_0 has the output of the complex tensor mul as complex64 or complex128, which then goes as input to the acc_* graph and causes the complex-input error. Keeping transpose in the graph lets the complex input be handled within the gpu_0 graph partition only.

Regarding whether removing transpose from the decompositions would affect the result: I would think not, since this is not removal of an op like detach; we simply do not lower it to permute. But you could provide me more insight if that is not the case. This also touches the `torch_disabled_decomposition` dictionary, which applies to all models; I mention a UI since we are talking about model-specific disabled decompositions. As of now, since this part of the code applies only to the distributed models, it should be good to go.
Review comment: OK, so the code needs to be restructured to make it clear that this is not the main codepath. It's not immediately obvious that most models will run through it. And I am still not sure whether such a broad change should be made even in the MGMN case. How would a user, or we, know that this change is needed?
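As a concrete illustration of the disabled-decomposition mechanism discussed in this thread, the sketch below filters a standard ATen decomposition table so that `transpose` (and `detach`) are kept rather than decomposed. The table helper and op overloads are stock PyTorch; the exact set of ops to disable and how this would be wired into Torch-TensorRT's lowering are assumptions, not the PR's implementation:

```python
import torch
from torch._decomp import core_aten_decompositions

# Full decomposition table (contains rewrites such as lowering transpose via permute).
decompositions = core_aten_decompositions()

# Ops we do not want to decompose on this (distributed) path.
# The exact overloads listed here are illustrative assumptions.
disabled_ops = {
    torch.ops.aten.transpose.int,
    torch.ops.aten.detach.default,
}

filtered = {op: fn for op, fn in decompositions.items() if op not in disabled_ops}
print(f"disabled {len(decompositions) - len(filtered)} decompositions")
```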