Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add explicit multiply-reduce GEMM kernel #621

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions python/perf-kernels/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,20 @@ fp32, bf16 and f8 (both e5m2 and e4m3) datatypes.
## `03-matrix-multiplication-stream-k.py`

This script contains the GEMM kernel that implements [stream-k](https://arxiv.org/abs/2301.03598)

## `multreduce_matmul_kernel.py`

Kernel that implements GEMM with explicit multiply-reduce instructions for small block sizes. Such
small block sizes aren't natively supported by the `tl.dot` operator.

Despite being numerically correct, this kernel performed worse than a corresponding GEMM kernel that
used `tl.dot` with the minimum supported block size of $16$:

**MI300 Results for FP16:**

| trans | M | N | K | Dot GiBps | Multiply-Reduce GiBps | Speedup |
|:--------|----:|------:|------:|------------------:|------------------------:|----------:|
| TN | 1 | 8192 | 28672 | 3491.33 | 2869.87 | 0.82 |
| TN | 1 | 6144 | 6144 | 3858.22 | 2673.33 | 0.69 |
| TN | 1 | 4096 | 4096 | 2352.54 | 1680.93 | 0.71 |
| TN | 2 | 16384 | 16384 | 3412.17 | 3318.44 | 0.97 |
73 changes: 73 additions & 0 deletions python/perf-kernels/multreduce_matmul_kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import triton
import triton.language as tl


# Kernel that implements GEMM with explicit multiply-reduce instructions for small block sizes.
# Based on **tune_gemm** `matmul_kernel` from commit `cf44637` (see `triton-mlir` branch).
#
# Computes C = A @ B (+ bias broadcast over rows, if BIAS), with optional split-K
# decomposition: the K dimension is partitioned across `tl.program_id(1)` slices,
# each slice accumulating its partial product into C via `tl.atomic_add`.
@triton.jit
def multreduce_matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
                             stride_cm, stride_cn, stride_bias, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                             BLOCK_SIZE_K: tl.constexpr, SPLIT_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
                             BIAS: tl.constexpr, EVEN_K: tl.constexpr):
    # Map the flat program id onto an output tile (pid_m, pid_n). With
    # GROUP_SIZE_M > 1, tiles are visited in grouped ("super-grouping") order to
    # improve L2 reuse of the A rows.
    pid = tl.program_id(axis=0)
    pid_z = tl.program_id(1)  # split-K slice index; always 0 when SPLIT_K == 1
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    if GROUP_SIZE_M == 1:
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n
    else:
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        # Last group may be narrower than GROUP_SIZE_M.
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
    # Each split-K slice starts at its own K offset; consecutive iterations of
    # the main loop then stride by BLOCK_SIZE_K * SPLIT_K (see pointer bumps below).
    if SPLIT_K == 1:
        offs_k = tl.arange(0, BLOCK_SIZE_K)
    else:
        offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
    if BIAS:
        # Per-row bias vector, masked against M for edge tiles.
        bias_ptrs = bias_ptr + offs_am * stride_bias
        bias = tl.load(bias_ptrs, mask=offs_am < M, other=0.0)
    # Accumulate in fp32 (or int32 for int8 inputs) regardless of input dtype.
    acc_dtype = tl.float32 if a_ptr.type.element_ty != tl.int8 else tl.int32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        else:
            # BUG FIX: each loop iteration advances K by BLOCK_SIZE_K * SPLIT_K
            # (not just BLOCK_SIZE_K), so the remaining-K bound must include the
            # SPLIT_K factor. The previous bound `K - k * BLOCK_SIZE_K` was too
            # permissive and let split-K slices with pid_z > 0 load out of bounds.
            k_remaining = K - k * BLOCK_SIZE_K * SPLIT_K
            a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
        # TODO: `min_dot_size` isn't always (16, 16, 16), it's in fact architecture
        # dependent and data type dependent. Please check this source file:
        # `${TRITON_ROOT}/third_party/amd/backend/compiler.py`
        # (look for `min_dot_size` function).
        # Q: How can we implement this `if` statement more precisely?
        # A: Maybe we can use an approach similar to `EVEN_K`, in which the kernel caller
        # inspects all relevant information and provides a Boolean flag for the kernel.
        if (BLOCK_SIZE_M < 16 or BLOCK_SIZE_N < 16) or BLOCK_SIZE_K < 16:
            # Explicit multiply-reduce for small block sizes: broadcast-multiply
            # (M, K, 1) by (1, K, N) and reduce over the K axis.
            a = tl.reshape(a, (BLOCK_SIZE_M, BLOCK_SIZE_K, 1)).to(acc_dtype)
            b = tl.reshape(b, (1, BLOCK_SIZE_K, BLOCK_SIZE_N)).to(acc_dtype)
            accumulator += tl.sum(a * b, axis=1)
        else:
            # Triton dot product for other block sizes.
            accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
    c = accumulator.to(c_ptr.type.element_ty)
    if BIAS:
        # BUG FIX: with SPLIT_K > 1 every slice atomically accumulates its partial
        # result into C, so an unconditional add would apply the bias SPLIT_K
        # times. Only the pid_z == 0 slice adds it (pid_z is always 0 when
        # SPLIT_K == 1, so the single-slice behavior is unchanged).
        if pid_z == 0:
            c += bias[:, None]
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    if SPLIT_K == 1:
        tl.store(c_ptrs, c, mask=c_mask)
    else:
        # C must be zero-initialized by the caller when SPLIT_K > 1.
        tl.atomic_add(c_ptrs, c, mask=c_mask)
Loading