Skip to content

Commit

Permalink
fix: Fix CI issues due to unintended fake tensor creation in torch.co…
Browse files Browse the repository at this point in the history
…mpile tests (#3416)
  • Loading branch information
peri044 authored Feb 27, 2025
1 parent 9b78101 commit 4f0bb6f
Showing 1 changed file with 41 additions and 34 deletions.
75 changes: 41 additions & 34 deletions py/torch_tensorrt/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import tensorrt as trt
import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
Expand Down Expand Up @@ -256,48 +257,54 @@ def prepare_inputs(
inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
disable_memory_format_check: bool = False,
) -> Any:
if inputs is None:
return None

elif isinstance(inputs, Input):
return inputs
"""
We take a nested group of torch.Tensors or scalars and convert them into torchtrt.Input's
"""
# Any tensors created inside this call will be FakeTensors if it's inside a torch.compile session
# So, we disable fake mode temporarily.
with unset_fake_temporarily():
if inputs is None:
return None

elif isinstance(inputs, (torch.Tensor, int, float, bool)):
return Input.from_tensor(
torch.tensor(inputs),
disable_memory_format_check=disable_memory_format_check,
)
elif isinstance(inputs, Input):
return inputs

elif isinstance(inputs, (list, tuple)):
torchtrt_input_list = []
for input_obj in inputs:
torchtrt_input = prepare_inputs(
input_obj, disable_memory_format_check=disable_memory_format_check
elif isinstance(inputs, (torch.Tensor, int, float, bool)):
return Input.from_tensor(
torch.tensor(inputs),
disable_memory_format_check=disable_memory_format_check,
)
torchtrt_input_list.append(torchtrt_input)

return (
torchtrt_input_list
if isinstance(inputs, list)
else tuple(torchtrt_input_list)
)

elif isinstance(inputs, dict):
torchtrt_inputs_dict: Dict[Any, Any] = dict()
elif isinstance(inputs, (list, tuple)):
torchtrt_input_list = []
for input_obj in inputs:
torchtrt_input = prepare_inputs(
input_obj, disable_memory_format_check=disable_memory_format_check
)
torchtrt_input_list.append(torchtrt_input)

for key, input_obj in inputs.items():
torchtrt_input = prepare_inputs(
input_obj, disable_memory_format_check=disable_memory_format_check
return (
torchtrt_input_list
if isinstance(inputs, list)
else tuple(torchtrt_input_list)
)
torchtrt_inputs_dict[key] = torchtrt_input

return torchtrt_inputs_dict
elif isinstance(inputs, dict):
torchtrt_inputs_dict: Dict[Any, Any] = dict()

else:
raise ValueError(
f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. "
+ "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
)
for key, input_obj in inputs.items():
torchtrt_input = prepare_inputs(
input_obj, disable_memory_format_check=disable_memory_format_check
)
torchtrt_inputs_dict[key] = torchtrt_input

return torchtrt_inputs_dict

else:
raise ValueError(
f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. "
+ "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
)


def parse_complex_tensor_structs(
Expand Down

0 comments on commit 4f0bb6f

Please sign in to comment.