RuntimeError when running matrix multiplication example: CUDA: Error - invalid ptx #187

Closed
theevann opened this issue Aug 5, 2021 · 6 comments

theevann commented Aug 5, 2021

Hello,

First: Thank you for this great piece of work!

I installed triton from pip.
When I try to run the matrix multiplication example, I get the error:

RuntimeError: CUDA: Error- invalid ptx

GPU: GeForce GTX 1080 Ti

Output of nvcc --version:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Tue_Sep_15_19:10:02_PDT_2020
Cuda compilation tools, release 11.1, V11.1.74
Build cuda_11.1.TC455_06.29069683_0

And more details from pytorch collect_env.py:

PyTorch version: 1.9.0+cu111
CUDA used to build PyTorch: 11.1

OS: Debian GNU/Linux 10 (buster) (x86_64)
GCC version: (Debian 8.3.0-6) 8.3.0
Libc version: glibc-2.28

Python version: 3.7.3 (default, Jan 22 2021, 20:04:44)  [GCC 8.3.0] (64-bit runtime)
Python platform: Linux-4.19.0-16-amd64-x86_64-with-debian-10.10
CUDA runtime version: 11.1.74
GPU models and configuration: GPU 0: GeForce GTX 1080 Ti
Nvidia driver version: 455.32.00

Full trace:

------------------------------------------------------------
RuntimeError               Traceback (most recent call last)
<ipython-input-37-c179406deb17> in <module>
      2 a = torch.randn((51, 51), device='cuda', dtype=torch.float16)
      3 b = torch.randn((51, 51), device='cuda', dtype=torch.float16)
----> 4 c_0 = matmul(a, b, activation=None)
      5 c_1 = torch.matmul(a, b)
      6 print(c_0)

<ipython-input-36-f7a6e3ad6d5f> in matmul(a, b, activation)
     13         a, b, c, M, N, K, \
     14         a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
---> 15         ACTIVATION = activation,
     16     )
     17     # done; return the output tensor

~/.local/lib/python3.7/site-packages/triton/code_gen.py in __call__(self, *wargs, **kwargs)
    597 
    598     def __call__(self, *wargs, **kwargs):
--> 599         return self.kernel(*wargs, **kwargs, grid=self.grid)
    600 
    601 

~/.local/lib/python3.7/site-packages/triton/code_gen.py in __call__(self, *args, **meta)
    629             if key not in self.cache:
    630                 timings = {config: self._bench(*args, config=config, **meta) \
--> 631                         for config in self.configs}
    632                 self.cache[key] = builtins.min(timings, key=timings.get)
    633             config = self.cache[key]

~/.local/lib/python3.7/site-packages/triton/code_gen.py in <dictcomp>(.0)
    629             if key not in self.cache:
    630                 timings = {config: self._bench(*args, config=config, **meta) \
--> 631                         for config in self.configs}
    632                 self.cache[key] = builtins.min(timings, key=timings.get)
    633             config = self.cache[key]

~/.local/lib/python3.7/site-packages/triton/code_gen.py in _bench(self, config, *args, **meta)
    622         current = dict(meta, **config.meta)
    623         kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
--> 624         return triton.testing.do_bench(kernel_call)
    625 
    626     def __call__(self, *args, **meta):

~/.local/lib/python3.7/site-packages/triton/testing.py in do_bench(fn, warmup, rep, grad_to_none, percentiles)
    109 
    110     # Estimate the runtime of the function
--> 111     fn()
    112     torch.cuda.synchronize()
    113     start_event = torch.cuda.Event(enable_timing=True)

~/.local/lib/python3.7/site-packages/triton/code_gen.py in <lambda>()
    621         # augment meta-parameters with tunable ones
    622         current = dict(meta, **config.meta)
--> 623         kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
    624         return triton.testing.do_bench(kernel_call)
    625 

~/.local/lib/python3.7/site-packages/triton/code_gen.py in __call__(self, grid, num_warps, num_stages, force_nc_cache, *wargs, **meta)
    577                 *wargs, device=device, attributes=attributes,
    578                 num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
--> 579                 constants=constants, **meta
    580             )
    581         # pack arguments

~/.local/lib/python3.7/site-packages/triton/code_gen.py in _compile(self, device, attributes, constants, num_warps, num_stages, force_nc_cache, *wargs, **meta)
    548         tt_device = _triton.driver.cu_device(device.index, False)
    549         # Compile to machine code
--> 550         mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages, force_nc_cache)
    551         if shared_mem > tt_device.max_shared_memory():
    552             raise  OutOfResources(shared_mem, tt_device.max_shared_memory(), "shared memory")

RuntimeError: CUDA: Error- invalid ptx

ptillet (Collaborator) commented Aug 5, 2021

Hey!

FP16 on Pascal GPUs is not supported. More generally, compute capability < 7.0 is not supported, though you can make it work by changing the datatype to fp32 and using block sizes that don't overflow shared memory.
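
For illustration, a minimal sketch of that workaround (not code from this issue; `matmul` is the wrapper from the matrix-multiplication tutorial, and the block-size change would go in the kernel's @triton.autotune configs):

# Hedged sketch of the Pascal workaround described above: create the inputs in
# float32 instead of float16, and shrink the tutorial kernel's autotuned block
# sizes (e.g. BLOCK_M = BLOCK_N = BLOCK_K = 32) so a tile fits in shared memory.
import torch

torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float32)
b = torch.randn((512, 512), device='cuda', dtype=torch.float32)
c = matmul(a, b, activation=None)  # `matmul` is the tutorial's wrapper (assumed in scope)
print(torch.allclose(c, torch.matmul(a, b), atol=1e-2, rtol=0))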

theevann (Author) commented Aug 5, 2021

Thank you for your quick answer!
I did not realize this GPU has compute capability < 7.0.

I tested with a V100S-PCIE-32GB GPU, which has compute capability 7.0.
It no longer errors, but the output of the multiplication is often incorrect.
triton.testing.allclose returns False, and the result of the matmul changes even with a fixed seed...

ptillet (Collaborator) commented Aug 5, 2021

Hmm, this is very odd, given that the CI runs on a V100 and reliably passes. Is this with a fresh clone of the repo?

theevann (Author) commented Aug 5, 2021

No, it is a pip install.

For some reason, I had actually changed the matrix shape, which is what caused the error.
With the initial shape, it works fine.

I changed the example code to apply the matmul a second time:

import torch
import triton
# `matmul` is the wrapper defined in the Triton matrix-multiplication tutorial

matrix_size = 512

torch.manual_seed(0)
a = torch.randn((matrix_size, matrix_size), device='cuda', dtype=torch.float16)
b = torch.randn((matrix_size, matrix_size), device='cuda', dtype=torch.float16)

c_0 = matmul(a, b, activation=None)
c_1 = torch.matmul(a, b)
print(c_0)
print(c_1)
print(triton.testing.allclose(c_0, c_1))

c_0 = matmul(a, b, activation=None)
c_1 = torch.matmul(a, b)
print(c_0)
print(c_1)
print(triton.testing.allclose(c_0, c_1))

When matrix_size is a power of 2, both the first and the second matmul match torch.matmul.
When matrix_size is not a power of 2 (e.g. 500), the first matmul matches torch.matmul but the second does not...

For any non-power-of-two matrix_size, the second triton.testing.allclose print is False.

ptillet (Collaborator) commented Aug 5, 2021

Ah yes, this is expected in the tutorial. For a more robust matmul, you can refer to https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py

Look in particular at the following lines:

        if META['EVEN_K']:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)

In the tutorial it's not there, which means that the kernel will accumulate out-of-bounds memory elements when K isn't a multiple of BLOCK_K.
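
For context, here is a rough sketch (not taken from this issue) of what that masking could look like inside the tutorial kernel's K loop; names like offs_k, a_ptrs, b_ptrs, stride_ak, stride_bk, and accumulator mirror the tutorial kernel and are assumptions here:

# Hedged sketch: masked loads inside the K loop so the last, partial BLOCK_K
# tile does not read out-of-bounds elements when K is not a multiple of BLOCK_K.
for k in range(0, K, BLOCK_K):
    k_remaining = K - k
    a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.)
    b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.)
    accumulator += tl.dot(a, b)
    a_ptrs += BLOCK_K * stride_ak
    b_ptrs += BLOCK_K * stride_bk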

theevann (Author) commented Aug 5, 2021

Thank you very much for your quick response!
Maybe add a line about this in the tutorial, as I expect others may run into this?

theevann closed this as completed Aug 5, 2021