
[CPU Inductor] Compile error when passing float16 tensors to vector_norm + remainder #97758

Closed
Kristoff-starling opened this issue Mar 28, 2023 · 4 comments
Labels: oncall: pt2, triaged

@Kristoff-starling
Contributor

Kristoff-starling commented Mar 28, 2023

🐛 Describe the bug

The following program works fine in eager mode but triggers a compilation error under torch.compile.

import torch

def fn(x, y):
    t = torch.linalg.vector_norm(x)
    return torch.remainder(y, t)

x = torch.rand([1], dtype=torch.float16)
y = torch.rand([1], dtype=torch.float16)

ret_eager = fn(x, y)
print('==== Eager mode OK! ====')

compiled = torch.compile(fn)
print('==== torchcomp compilation OK! ====')

ret_compiled = compiled(x, y)
print('==== torchcomp mode OK! ====')
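For reference, this is what `fn` computes, written as a plain-Python sketch on lists of floats. The name `fn_reference` is hypothetical and not part of PyTorch. With its defaults (`ord=2`, `dim=None`), `torch.linalg.vector_norm` is the Euclidean norm over all elements. `torch.remainder` follows Python's `%` convention, so the result takes the sign of the divisor. The sketch computes in double precision rather than float16, so it can flag gross mismatches between eager and compiled results but is not suitable for bit-exact comparison.

```python
import math


def fn_reference(x, y):
    """Hypothetical plain-Python reference for fn (not PyTorch code).

    x and y are lists of Python floats. The norm matches the defaults of
    torch.linalg.vector_norm (2-norm over all elements), and the remainder uses
    Python's % operator, whose sign convention torch.remainder follows.
    """
    t = math.sqrt(sum(v * v for v in x))  # Euclidean (2-)norm over all elements
    return [v % t for v in y]             # elementwise modulus, sign of the divisor
```

For example, `fn_reference([3.0, 4.0], [7.0])` gives `[2.0]`, because the norm is 5.0 and `7.0 % 5.0` is 2.0.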
Error log
==== Eager mode OK! ====
==== torchcomp compilation OK! ====
Traceback (most recent call last):
  File "repro.py", line 16, in <module>
    ret_compiled = compiled(x, y)
  File "python3.10/site-packages/torch/_dynamo/eval_frame.py", line 235, in _fn
    return fn(*args, **kwargs)
  File "python3.10/site-packages/torch/_dynamo/eval_frame.py", line 372, in catch_errors
    return callback(frame, cache_size, hooks)
  File "python3.10/site-packages/torch/_dynamo/convert_frame.py", line 412, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "python3.10/site-packages/torch/_dynamo/convert_frame.py", line 110, in _fn
    return fn(*args, **kwargs)
  File "python3.10/site-packages/torch/_dynamo/convert_frame.py", line 269, in _convert_frame_assert
    return _compile(
  File "python3.10/site-packages/torch/_dynamo/utils.py", line 166, in time_wrapper
    r = func(*args, **kwargs)
  File "python3.10/site-packages/torch/_dynamo/convert_frame.py", line 331, in _compile
    out_code = transform_code_object(code, transform)
  File "python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 530, in transform_code_object
    transformations(instructions, code_options)
  File "python3.10/site-packages/torch/_dynamo/convert_frame.py", line 318, in transform
    tracer.run()
  File "python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1854, in run
    super().run()
  File "python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 604, in run
    and self.step()
  File "python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 564, in step
    getattr(self, inst.opname)(inst)
  File "python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1933, in RETURN_VALUE
    self.output.compile_subgraph(
  File "python3.10/site-packages/torch/_dynamo/output_graph.py", line 581, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "python3.10/site-packages/torch/_dynamo/output_graph.py", line 651, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "python3.10/site-packages/torch/_dynamo/utils.py", line 166, in time_wrapper
    r = func(*args, **kwargs)
  File "python3.10/site-packages/torch/_dynamo/output_graph.py", line 730, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "python3.10/site-packages/torch/_dynamo/output_graph.py", line 726, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1088, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "python3.10/site-packages/torch/__init__.py", line 1527, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "python3.10/site-packages/torch/_inductor/compile_fx.py", line 568, in compile_fx
    return aot_autograd(
  File "python3.10/site-packages/torch/_dynamo/backends/common.py", line 62, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3047, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "python3.10/site-packages/torch/_dynamo/utils.py", line 166, in time_wrapper
    r = func(*args, **kwargs)
  File "python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2687, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
  File "python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1794, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
  File "python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1960, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
  File "python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1260, in aot_dispatch_base
    compiled_fw = compiler(fw_module, flat_args)
  File "python3.10/site-packages/torch/_dynamo/utils.py", line 166, in time_wrapper
    r = func(*args, **kwargs)
  File "python3.10/site-packages/torch/_inductor/compile_fx.py", line 532, in fw_compiler_base
    return inner_compile(
  File "python3.10/site-packages/torch/_dynamo/debug_utils.py", line 622, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs)
  File "python3.10/site-packages/torch/_inductor/debug.py", line 239, in inner
    return fn(*args, **kwargs)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "python3.10/site-packages/torch/_inductor/compile_fx.py", line 184, in compile_fx_inner
    compiled_fn = graph.compile_to_fn()
  File "python3.10/site-packages/torch/_inductor/graph.py", line 650, in compile_to_fn
    return self.compile_to_module().call
  File "python3.10/site-packages/torch/_dynamo/utils.py", line 166, in time_wrapper
    r = func(*args, **kwargs)
  File "python3.10/site-packages/torch/_inductor/graph.py", line 626, in compile_to_module
    mod = PyCodeCache.load(code, linemap=linemap)
  File "python3.10/site-packages/torch/_inductor/codecache.py", line 645, in load
    return cls.load_by_key_path(key, path, linemap)
  File "python3.10/site-packages/torch/_inductor/codecache.py", line 660, in load_by_key_path
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_/xz/cxznrwobivy2b4xrapdimk4nfs73werm7cdeq4yb7y5ttqsly6p5.py", line 42, in <module>
    async_compile.wait(globals())
  File "python3.10/site-packages/torch/_inductor/codecache.py", line 876, in wait
    scope[key] = result.result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/usr/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "python3.10/site-packages/torch/_inductor/codecache.py", line 853, in task
    return CppCodeCache.load(source_code).kernel
  File "python3.10/site-packages/torch/_inductor/codecache.py", line 629, in load
    raise exc.CppCompileError(cmd, e.output) from e
torch._dynamo.exc.BackendCompilerFailed: backend='debug_wrapper' raised:
CppCompileError: C++ compile error

Command:
g++ /tmp/torchinductor_/i7/ci74scdbqcmpwvdm4s5hwpfw5ppn2c7qpuewodtkp7jpyumupg5r.cpp -shared -fPIC -Wall -std=c++17 -Wno-unused-variable -Ipython3.10/site-packages/torch/include -Ipython3.10/site-packages/torch/include/torch/csrc/api/include -Ipython3.10/site-packages/torch/include/TH -Ipython3.10/site-packages/torch/include/THC -I/usr/include/python3.10 -Lpython3.10/site-packages/torch/lib -L/usr/lib/x86_64-linux-gnu -lc10 -ltorch -ltorch_cpu -ltorch_python -lgomp -DCPU_CAPABILITY_AVX2 -O3 -ffast-math -fno-finite-math-only -march=native -fopenmp -D C10_USING_CUSTOM_GENERATED_MACROS -o/tmp/torchinductor_/i7/ci74scdbqcmpwvdm4s5hwpfw5ppn2c7qpuewodtkp7jpyumupg5r.so

Output:
In file included from python3.10/site-packages/torch/include/ATen/cpu/vec/vec256/vec256.h:8,
                 from python3.10/site-packages/torch/include/ATen/cpu/vec/vec.h:6,
                 from python3.10/site-packages/torch/include/ATen/cpu/vec/functional_base.h:6,
                 from python3.10/site-packages/torch/include/ATen/cpu/vec/functional.h:3,
                 from /tmp/torchinductor_/hw/chwr6vy6e6sd25sfh42qtywkuf2emodexm2aomp3lbrcxwznfwyi.h:10,
                 from /tmp/torchinductor_/i7/ci74scdbqcmpwvdm4s5hwpfw5ppn2c7qpuewodtkp7jpyumupg5r.cpp:2:
python3.10/site-packages/torch/include/ATen/cpu/vec/vec_base.h:1025: warning: ignoring ‘#pragma unroll ’ [-Wunknown-pragmas]
 1025 | # pragma unroll
      | 
In file included from python3.10/site-packages/torch/include/ATen/cpu/vec/vec256/vec256.h:10,
                 from python3.10/site-packages/torch/include/ATen/cpu/vec/vec.h:6,
                 from python3.10/site-packages/torch/include/ATen/cpu/vec/functional_base.h:6,
                 from python3.10/site-packages/torch/include/ATen/cpu/vec/functional.h:3,
                 from /tmp/torchinductor_/hw/chwr6vy6e6sd25sfh42qtywkuf2emodexm2aomp3lbrcxwznfwyi.h:10,
                 from /tmp/torchinductor_/i7/ci74scdbqcmpwvdm4s5hwpfw5ppn2c7qpuewodtkp7jpyumupg5r.cpp:2:
python3.10/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_float.h:438: warning: ignoring ‘#pragma unroll ’ [-Wunknown-pragmas]
  438 | #pragma unroll
      | 
python3.10/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_float.h:442: warning: ignoring ‘#pragma unroll ’ [-Wunknown-pragmas]
  442 | #pragma unroll
      | 
In file included from python3.10/site-packages/torch/include/ATen/cpu/vec/vec256/vec256.h:12,
                 from python3.10/site-packages/torch/include/ATen/cpu/vec/vec.h:6,
                 from python3.10/site-packages/torch/include/ATen/cpu/vec/functional_base.h:6,
                 from python3.10/site-packages/torch/include/ATen/cpu/vec/functional.h:3,
                 from /tmp/torchinductor_/hw/chwr6vy6e6sd25sfh42qtywkuf2emodexm2aomp3lbrcxwznfwyi.h:10,
                 from /tmp/torchinductor_/i7/ci74scdbqcmpwvdm4s5hwpfw5ppn2c7qpuewodtkp7jpyumupg5r.cpp:2:
python3.10/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_bfloat16.h:693: warning: ignoring ‘#pragma unroll ’ [-Wunknown-pragmas]
  693 | #pragma unroll
      | 
python3.10/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_bfloat16.h:698: warning: ignoring ‘#pragma unroll ’ [-Wunknown-pragmas]
  698 | #pragma unroll
      | 
In file included from python3.10/site-packages/torch/include/ATen/cpu/vec/vec256/vec256.h:13,
                 from python3.10/site-packages/torch/include/ATen/cpu/vec/vec.h:6,
                 from python3.10/site-packages/torch/include/ATen/cpu/vec/functional_base.h:6,
                 from python3.10/site-packages/torch/include/ATen/cpu/vec/functional.h:3,
                 from /tmp/torchinductor_/hw/chwr6vy6e6sd25sfh42qtywkuf2emodexm2aomp3lbrcxwznfwyi.h:10,
                 from /tmp/torchinductor_/i7/ci74scdbqcmpwvdm4s5hwpfw5ppn2c7qpuewodtkp7jpyumupg5r.cpp:2:
python3.10/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_double.h:403: warning: ignoring ‘#pragma unroll ’ [-Wunknown-pragmas]
  403 | #pragma unroll
      | 
python3.10/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_double.h:407: warning: ignoring ‘#pragma unroll ’ [-Wunknown-pragmas]
  407 | #pragma unroll
      | 
In file included from python3.10/site-packages/torch/include/ATen/cpu/vec/vec256/vec256.h:14,
                 from python3.10/site-packages/torch/include/ATen/cpu/vec/vec.h:6,
                 from python3.10/site-packages/torch/include/ATen/cpu/vec/functional_base.h:6,
                 from python3.10/site-packages/torch/include/ATen/cpu/vec/functional.h:3,
                 from /tmp/torchinductor_/hw/chwr6vy6e6sd25sfh42qtywkuf2emodexm2aomp3lbrcxwznfwyi.h:10,
                 from /tmp/torchinductor_/i7/ci74scdbqcmpwvdm4s5hwpfw5ppn2c7qpuewodtkp7jpyumupg5r.cpp:2:
python3.10/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_int.h:287: warning: ignoring ‘#pragma unroll ’ [-Wunknown-pragmas]
  287 | # pragma unroll
      | 
python3.10/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_int.h:295: warning: ignoring ‘#pragma unroll ’ [-Wunknown-pragmas]
  295 | # pragma unroll
      | 
python3.10/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_int.h:307: warning: ignoring ‘#pragma unroll ’ [-Wunknown-pragmas]
  307 | # pragma unroll
      | 
python3.10/site-packages/torch/include/ATen/cpu/vec/vec256/vec256_int.h:315: warning: ignoring ‘#pragma unroll ’ [-Wunknown-pragmas]
  315 | # pragma unroll
      | 
In file included from /tmp/torchinductor_/i7/ci74scdbqcmpwvdm4s5hwpfw5ppn2c7qpuewodtkp7jpyumupg5r.cpp:2:
/tmp/torchinductor_/hw/chwr6vy6e6sd25sfh42qtywkuf2emodexm2aomp3lbrcxwznfwyi.h:77: warning: ignoring ‘#pragma unroll ’ [-Wunknown-pragmas]
   77 | #pragma unroll
      | 
/tmp/torchinductor_/hw/chwr6vy6e6sd25sfh42qtywkuf2emodexm2aomp3lbrcxwznfwyi.h:86: warning: ignoring ‘#pragma unroll ’ [-Wunknown-pragmas]
   86 | #pragma unroll
      | 
/tmp/torchinductor_/hw/chwr6vy6e6sd25sfh42qtywkuf2emodexm2aomp3lbrcxwznfwyi.h:122: warning: ignoring ‘#pragma unroll ’ [-Wunknown-pragmas]
  122 | #pragma unroll
      | 
/tmp/torchinductor_/i7/ci74scdbqcmpwvdm4s5hwpfw5ppn2c7qpuewodtkp7jpyumupg5r.cpp: In function ‘void kernel(const half*, const half*, half*)’:
/tmp/torchinductor_/i7/ci74scdbqcmpwvdm4s5hwpfw5ppn2c7qpuewodtkp7jpyumupg5r.cpp:14:24: error: no matching function for call to ‘mod(float&, c10::Half&)’
   14 |         auto tmp6 = mod(tmp0, tmp5);
      |                     ~~~^~~~~~~~~~~~
In file included from /tmp/torchinductor_/i7/ci74scdbqcmpwvdm4s5hwpfw5ppn2c7qpuewodtkp7jpyumupg5r.cpp:2:
/tmp/torchinductor_/hw/chwr6vy6e6sd25sfh42qtywkuf2emodexm2aomp3lbrcxwznfwyi.h:19:32: note: candidate: ‘template<class T> T mod(T, T)’
   19 | template <typename T> inline T mod(T a, T b) { return a % b; }
      |                                ^~~
/tmp/torchinductor_/hw/chwr6vy6e6sd25sfh42qtywkuf2emodexm2aomp3lbrcxwznfwyi.h:19:32: note:   template argument deduction/substitution failed:
/tmp/torchinductor_/i7/ci74scdbqcmpwvdm4s5hwpfw5ppn2c7qpuewodtkp7jpyumupg5r.cpp:14:24: note:   deduced conflicting types for parameter ‘T’ (‘float’ and ‘c10::Half’)
   14 |         auto tmp6 = mod(tmp0, tmp5);
      |                     ~~~^~~~~~~~~~~~



You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

Some details for debugging:

  • The program works fine on CUDA.
  • The program works fine when passing float32 tensors to the function.
  • Both operators (vector_norm and remainder) are needed to trigger the issue.
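Because the float32 path works, one workaround worth trying on affected builds is to upcast inside the function, e.g. `torch.remainder(y.float(), t.float()).half()`. This is an untested suggestion, not a confirmed fix. The failing kernel calls `mod` with a `float` operand and a `c10::Half` operand, and upcasting both keeps the two operands of `remainder` in a single type. The numeric pattern is to compute at higher precision and round only the final result to float16. It can be sketched in plain Python with `struct`'s `'e'` (IEEE 754 half-precision) format. The helpers `to_half` and `fn_upcast` are hypothetical. Results may differ slightly from eager float16, which also rounds the norm itself to float16.

```python
import math
import struct


def to_half(v):
    """Round a Python float to the nearest float16 value via struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", v))[0]


def fn_upcast(x, y):
    """Hypothetical sketch of the upcast workaround on plain Python floats.

    Inputs are first rounded to float16 to mimic float16 tensors. The norm and
    the remainder are then computed in double precision, and only the final
    result is rounded back to float16.
    """
    xs = [to_half(v) for v in x]           # emulate float16 inputs
    ys = [to_half(v) for v in y]
    t = math.sqrt(sum(v * v for v in xs))  # Euclidean norm at higher precision
    return [to_half(v % t) for v in ys]    # Python-style modulus, rounded once
```

For instance, `to_half(0.1)` returns `0.0999755859375`, the nearest float16 value to 0.1.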

Versions

Environment
PyTorch version: 2.1.0.dev20230327+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: 14.0.0-1ubuntu1
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.6 (main, Mar 10 2023, 10:55:28) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.19.0-35-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.6.124
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090

Nvidia driver version: 525.78.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.4.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy==0.960
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.1
[pip3] pytorch-triton==2.1.0+e650d3708b
[pip3] torch==2.1.0.dev20230327+cu117
[pip3] torchaudio==2.1.0.dev20230328+cu117
[pip3] torchvision==0.16.0.dev20230327+cu117
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.3.1               h2bc3f7f_2  
[conda] mkl                       2021.4.0           h06a4308_640  
[conda] mkl-service               2.4.0            py39h7f8727e_0  
[conda] mkl_fft                   1.3.1            py39hd3c417c_0  
[conda] mkl_random                1.2.2            py39h51133e4_0  
[conda] numpy                     1.21.5           py39he7a7128_1  
[conda] numpy-base                1.21.5           py39hf524024_1  
[conda] numpydoc                  1.2                pyhd3eb1b0_0

cc @ezyang @msaroufim @wconstab @ngimel @bdhirsh @anijain2305 @soumith


@anijain2305 added the triaged label on May 25, 2023
@zhuhaozhe
Collaborator

Hi @Kristoff-starling, thanks for reporting this. Could you recheck it with the latest version of PyTorch? FP16 should now work with torch.compile as well.

@zhuhaozhe
Collaborator

Feel free to reopen this if it is still a problem.

@ganler
Contributor

ganler commented Dec 25, 2023

I can confirm it works now in latest PyTorch. Nice job!



7 participants