🐛 [Bug] torch._subclasses.fake_tensor.MetadataMismatchError: Devices cpu and cuda:0 are not equal! (_scaled_dot_product_flash_attention) #3408

Open · chohk88 opened this issue on Feb 24, 2025 · 3 comments

Comments

chohk88 commented Feb 24, 2025

Bug Description

After resolving the issues from pytorch/pytorch#147096, a MetadataMismatchError is raised at _scaled_dot_product_flash_attention:

Traceback (most recent call last):
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1854, in _maybe_infer_fake
    _check_fake_real_tensors(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_utils.py", line 196, in _check_fake_real_tensors
    torch._prims.utils.compare_tensor_meta(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 193, in compare_tensor_meta
    raise MetadataMismatchError(msg)
torch._subclasses.fake_tensor.MetadataMismatchError: Devices cpu and cuda:0 are not equal!

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/develop/TensorRT/examples/dynamo/torch_export_pg.py", line 151, in <module>
    trt_model = torch_tensorrt.dynamo.compile(
  File "/develop/TensorRT/py/torch_tensorrt/dynamo/_compiler.py", line 670, in compile
    exported_program = exported_program.run_decompositions(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/exported_program.py", line 1405, in run_decompositions
    return _decompose_exported_program(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/exported_program.py", line 872, in _decompose_exported_program
    ) = _decompose_and_get_gm_with_new_signature_constants(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/exported_program.py", line 491, in _decompose_and_get_gm_with_new_signature_constants
    aten_export_artifact = _export_to_aten_ir(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/_trace.py", line 771, in _export_to_aten_ir
    gm, graph_signature = transform(aot_export_module)(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1345, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1584, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 671, in _create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 197, in inner
    flat_f_outs = f(*flat_f_args)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 875, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/fx/interpreter.py", line 171, in run
    self.env[node] = self.run_node(node)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6955, in run_node
    result = super().run_node(n)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/fx/interpreter.py", line 236, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/fx/interpreter.py", line 316, in call_function
    return target(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_ops.py", line 756, in __call__
    return self._op(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 527, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/utils/_stats.py", line 27, in wrapper
    return fn(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1269, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1812, in dispatch
    return self._dispatch_impl(func, types, args, kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2388, in _dispatch_impl
    return maybe_propagate_real_tensors(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2220, in maybe_propagate_real_tensors
    self._maybe_infer_fake_kernel_from_pytree_out(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1959, in _maybe_infer_fake_kernel_from_pytree_out
    fake_leaves = [
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1960, in <listcomp>
    self._maybe_infer_fake(func, _fake_path, _fake_out, _real_out)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1873, in _maybe_infer_fake
    raise MetadataMismatchError(
torch._subclasses.fake_tensor.MetadataMismatchError: Real tensor propagation found a metadata mismatch between fake tensor FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64) and real tensor 140486520641728,  at output[6], for func: aten._scaled_dot_product_flash_attention.default

While executing %_scaled_dot_product_flash_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%transpose_1, %transpose_2, %transpose_3), kwargs = {scale: 0.11785113019775793})

....

Original traceback:
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/paligemma/modeling_paligemma.py", line 504, in forward
    image_features = self.get_image_features(pixel_values)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1190, in forward
    return self.vision_model(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1091, in forward
    encoder_outputs = self.encoder(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 902, in forward
    layer_outputs = encoder_layer(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 643, in forward
    hidden_states, attn_weights = self.self_attn(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 574, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(


To Reproduce

Steps to reproduce the behavior:

import torch
import torch_tensorrt
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
from transformers.image_utils import load_image


# 1. Model
DEVICE = torch.device("cuda:0")
model_id = "google/paligemma2-3b-pt-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16
).eval().to(DEVICE)
processor = PaliGemmaProcessor.from_pretrained(model_id)

prompt = ""
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(DEVICE)
input_len = model_inputs["input_ids"].shape[-1]

# 2. PyTorch
with torch.inference_mode():
    pyt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False) #, use_cache=False)
    pyt_generation = pyt_generation[0][input_len:]
    pyt_decoded = processor.decode(pyt_generation, skip_special_tokens=True)
    print("=============================")
    print("PyTorch generated text:")
    print(pyt_decoded)
    print("=============================")

# (a) Dummy inputs  
batch_size = 1
dummy_input_ids = model_inputs["input_ids"] 
dummy_attention_mask = model_inputs["attention_mask"] 
dummy_pixel_values = model_inputs["pixel_values"]

dummy_inputs = {
    "input_ids": dummy_input_ids,
    "attention_mask": dummy_attention_mask,
    "pixel_values": dummy_pixel_values,
}

# (b) Dynamic shape 
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=1024)
dynamic_shapes = {
    "input_ids": {0: BATCH, 1: SEQ_LEN},
    "attention_mask": {0: BATCH, 1: SEQ_LEN},
    "pixel_values": {0: BATCH},
}
# (c) ExportedProgram  
# torch.export.export(
#     model,
#     args=(),
#     kwargs=dummy_inputs,
#     dynamic_shapes=dynamic_shapes,
#     strict=False,
# )


import torch
import torch.utils._pytree as pytree
import transformers

def flatten_hybridcache(hc: transformers.cache_utils.HybridCache):
    flat_tensors = []
    flat_tensors.append(hc.is_sliding)               # shape: [num_hidden_layers], bool
    flat_tensors.extend(hc.key_cache)                # List[Tensor]
    flat_tensors.extend(hc.value_cache)              # List[Tensor]

    context = {
        "max_cache_len": hc.max_cache_len,
        "max_batch_size": hc.max_batch_size,
        "head_dim": hc.head_dim,
        "dtype": hc.dtype,
        "num_key_value_heads": hc.num_key_value_heads,
        "num_layers": len(hc.key_cache),  # = len(hc.value_cache) = config.num_hidden_layers
    }

    return flat_tensors, context


def unflatten_hybridcache(flat_tensors, context):
    num_layers = context["num_layers"]

    is_sliding = flat_tensors[0]
    key_cache = flat_tensors[1 : 1 + num_layers]
    value_cache = flat_tensors[1 + num_layers : 1 + 2*num_layers]

    hc = transformers.cache_utils.HybridCache.__new__(transformers.cache_utils.HybridCache)

    hc.max_cache_len = context["max_cache_len"]
    hc.max_batch_size = context["max_batch_size"]
    hc.head_dim = context["head_dim"]
    hc.dtype = context["dtype"]
    hc.num_key_value_heads = context["num_key_value_heads"]
    hc.is_sliding = is_sliding
    hc.key_cache = list(key_cache)
    hc.value_cache = list(value_cache)

    return hc

# pytree register
pytree.register_pytree_node(
    transformers.cache_utils.HybridCache,
    flatten_hybridcache,
    unflatten_hybridcache
)

# from torch.export._trace import _export  
# exported_program = _export(
#     model,
#     args=(),
#     kwargs=dummy_inputs,
#     dynamic_shapes=dynamic_shapes,
#     strict=False,
#     allow_complex_guards_as_runtime_asserts=True,
# )

# torch.export._draft_export.draft_export
import torch.export._draft_export
exported_program = torch.export._draft_export.draft_export(
    model,
    args=(),
    kwargs=dummy_inputs,
    dynamic_shapes=dynamic_shapes,
    strict=False,
    # allow_complex_guards_as_runtime_asserts=True,
)


trt_model = torch_tensorrt.dynamo.compile(
    exported_program[0],
    inputs=dummy_inputs,
    enabled_precisions={torch.float32},
    truncate_double=True,
    device=DEVICE,
    disable_tf32=True,
    use_explicit_typing=True,
    use_fp32_acc=True,  
)

# TensorRT
model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
with torch.inference_mode():
    trt_generation = trt_model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    trt_generation = trt_generation[0][input_len:]
    trt_decoded = processor.decode(trt_generation, skip_special_tokens=True)
    print("TensorRT generated text:")
    print(trt_decoded)

Environment

pytorch-triton 3.2.0+git4b3bb1f8
torch 2.7.0.dev20250207+cu124
torch-tensorrt 2.7.0.dev0+5a4dd33ef /develop/TensorRT/py
torchvision 0.22.0.dev20250207+cu124

chohk88 added the bug label on Feb 24, 2025
HolyWu commented Feb 27, 2025

Probably also fixed by #3336. I cannot access paligemma2 to confirm it.

chohk88 commented Feb 28, 2025

> Probably also fixed by #3336. I cannot access paligemma2 to confirm it.

I rebased onto the latest branch and tested it, but the same error still occurs.

HolyWu commented Feb 28, 2025

It looks more like a PyTorch issue than a Torch-TRT issue.

import torch
import torch.nn.functional as F
from torch.export._draft_export import draft_export


class MyModule(torch.nn.Module):
    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        return F.scaled_dot_product_attention(query, key, value)


with torch.inference_mode():
    model = MyModule().eval().cuda().half()

    inputs = (
        torch.rand(32, 8, 128, 64, dtype=torch.half, device="cuda"),
        torch.rand(32, 8, 128, 64, dtype=torch.half, device="cuda"),
        torch.rand(32, 8, 128, 64, dtype=torch.half, device="cuda"),
    )

    exported_program = draft_export(model, inputs, strict=False)[0]
    print(exported_program._report)

When executing the above script with torch 2.7.0.dev20250207+cu126, I get the following message:

###################################################################################################
WARNING: 1 issue(s) found during export, and it was not able to soundly produce a graph.
Please follow the instructions to fix the errors.
###################################################################################################

1. Mismatched fake kernel.
    torch.ops.aten._scaled_dot_product_flash_attention.default has a fake kernel implementation, but it has incorrect behavior, based on the real kernel.
    The reason for the mismatch is: Devices cpu and cuda:0 are not equal!.

    Please refer to https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz for more detailed instructions on how to write a fake implementation.

But when executing with torch 2.7.0.dev20250228+cu126 (where the [0] after draft_export() has to be removed), I get the following message instead:

##############################################################################################
Congratuations: No issues are found during export, and it was able to soundly produce a graph.
You can now change back to torch.export.export()
##############################################################################################
