After resolving issues from pytorch/pytorch#147096, a MetadataMismatchError occurs at _scaled_dot_product_flash_attention.
Traceback (most recent call last):
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1854, in _maybe_infer_fake
_check_fake_real_tensors(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_utils.py", line 196, in _check_fake_real_tensors
torch._prims.utils.compare_tensor_meta(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 193, in compare_tensor_meta
raise MetadataMismatchError(msg)
torch._subclasses.fake_tensor.MetadataMismatchError: Devices cpu and cuda:0 are not equal!
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/develop/TensorRT/examples/dynamo/torch_export_pg.py", line 151, in <module>
trt_model = torch_tensorrt.dynamo.compile(
File "/develop/TensorRT/py/torch_tensorrt/dynamo/_compiler.py", line 670, in compile
exported_program = exported_program.run_decompositions(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
return fn(*args, **kwargs)
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/exported_program.py", line 1405, in run_decompositions
return _decompose_exported_program(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/exported_program.py", line 872, in _decompose_exported_program
) = _decompose_and_get_gm_with_new_signature_constants(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/exported_program.py", line 491, in _decompose_and_get_gm_with_new_signature_constants
aten_export_artifact = _export_to_aten_ir(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/_trace.py", line 771, in _export_to_aten_ir
gm, graph_signature = transform(aot_export_module)(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1345, in aot_export_module
fx_g, metadata, in_spec, out_spec = _aot_export_function(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1584, in _aot_export_function
fx_g, meta = create_aot_dispatcher_function(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 671, in _create_aot_dispatcher_function
fw_metadata = run_functionalized_fw_and_collect_metadata(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 197, in inner
flat_f_outs = f(*flat_f_args)
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
tree_out = fn(*args, **kwargs)
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 875, in functional_call
out = PropagateUnbackedSymInts(mod).run(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/fx/interpreter.py", line 171, in run
self.env[node] = self.run_node(node)
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6955, in run_node
result = super().run_node(n)
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/fx/interpreter.py", line 236, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/fx/interpreter.py", line 316, in call_function
return target(*args, **kwargs)
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_ops.py", line 756, in __call__
return self._op(*args, **kwargs)
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 527, in __torch_dispatch__
outs_unwrapped = func._op_dk(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/utils/_stats.py", line 27, in wrapper
return fn(*args, **kwargs)
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1269, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1812, in dispatch
return self._dispatch_impl(func, types, args, kwargs)
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2388, in _dispatch_impl
return maybe_propagate_real_tensors(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2220, in maybe_propagate_real_tensors
self._maybe_infer_fake_kernel_from_pytree_out(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1959, in _maybe_infer_fake_kernel_from_pytree_out
fake_leaves = [
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1960, in <listcomp>
self._maybe_infer_fake(func, _fake_path, _fake_out, _real_out)
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1873, in _maybe_infer_fake
raise MetadataMismatchError(
torch._subclasses.fake_tensor.MetadataMismatchError: Real tensor propagation found a metadata mismatch between fake tensor FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64) and real tensor 140486520641728, at output[6], for func: aten._scaled_dot_product_flash_attention.default
While executing %_scaled_dot_product_flash_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%transpose_1, %transpose_2, %transpose_3), kwargs = {scale: 0.11785113019775793})
....
Original traceback:
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/paligemma/modeling_paligemma.py", line 504, in forward
image_features = self.get_image_features(pixel_values)
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1190, in forward
return self.vision_model(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1091, in forward
encoder_outputs = self.encoder(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 902, in forward
layer_outputs = encoder_layer(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 643, in forward
hidden_states, attn_weights = self.self_attn(
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 574, in forward
attn_output = torch.nn.functional.scaled_dot_product_attention(
When executing the above script with torch 2.7.0.dev20250207+cu126, I get the following message:
###################################################################################################
WARNING: 1 issue(s) found during export, and it was not able to soundly produce a graph.
Please follow the instructions to fix the errors.
###################################################################################################
1. Mismatched fake kernel.
torch.ops.aten._scaled_dot_product_flash_attention.default has a fake kernel implementation, but it has incorrect behavior, based on the real kernel.
The reason for the mismatch is: Devices cpu and cuda:0 are not equal!.
Please refer to https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz for more detailed instructions on how to write a fake implementation.
But when executing with torch 2.7.0.dev20250228+cu126 (after deleting the [0] index that follows the draft_export() call), I get the following message instead:
##############################################################################################
Congratulations: No issues are found during export, and it was able to soundly produce a graph.
You can now change back to torch.export.export()
##############################################################################################
Bug Description
After resolving issues from pytorch/pytorch#147096, a MetadataMismatchError occurs at _scaled_dot_product_flash_attention.

To Reproduce
Steps to reproduce the behavior:
Environment
pytorch-triton 3.2.0+git4b3bb1f8
torch 2.7.0.dev20250207+cu124
torch-tensorrt 2.7.0.dev0+5a4dd33ef /develop/TensorRT/py
torchvision 0.22.0.dev20250207+cu124