Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace mamba2 mamba_chunk_scan_combined triton kernel by simple_gla triton kernel #49

Merged
merged 6 commits into from
Aug 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 73 additions & 92 deletions fla/ops/simple_gla/chunk.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang

from typing import Tuple
from typing import Optional, Tuple

import torch
import triton
Expand All @@ -11,23 +11,14 @@
from fla.utils import contiguous


@torch.jit.script
def normalize_output(q, k, o):
k = k.transpose(-2, -1)
k = k.cumsum(-1)
k = k.transpose(-2, -1)
z = (q * k).sum(-1, keepdim=True)
return o / (z + 1e-5)


@triton.jit
def chunk_simple_gla_fwd_kernel_h(
k,
v,
h,
g,
initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]
final_state, # final state of the chunk [B, H, D_head_K, D_head_V]
h0,
ht,
s_qk_h,
s_qk_t,
s_qk_d,
Expand All @@ -36,7 +27,6 @@ def chunk_simple_gla_fwd_kernel_h(
s_vo_d,
s_h_h,
s_h_t,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
Expand All @@ -53,17 +43,13 @@ def chunk_simple_gla_fwd_kernel_h(
b_h = tl.zeros([BK, BV], dtype=tl.float32)

if USE_INITIAL_STATE:
p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V,
(K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_h0 = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

for i_t in range(NT):
p_k = tl.make_block_ptr(
k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(
v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V,
(K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
# [BK, BT]
Expand All @@ -72,13 +58,12 @@ def chunk_simple_gla_fwd_kernel_h(
b_v = tl.load(p_v, boundary_check=(0, 1))
# [BK, BV]
b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)
b_h *= tl.math.exp2(b_g_last)
b_h *= tl.exp(b_g_last)
b_g = tl.load(g + i_bh * T + i_t * BT + tl.arange(0, BT))
b_h += tl.dot(b_k, (b_v * tl.math.exp2(b_g_last - b_g)[:, None]).to(b_k.dtype), allow_tf32=False)
b_h += tl.dot(b_k, (b_v * tl.exp(b_g_last - b_g)[:, None]).to(b_k.dtype), allow_tf32=False)

if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(
final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))


Expand All @@ -99,7 +84,6 @@ def chunk_simple_gla_fwd_kernel_o(
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
Expand All @@ -115,12 +99,9 @@ def chunk_simple_gla_fwd_kernel_o(
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_s = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(
q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(
k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V,
(K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
Expand All @@ -135,16 +116,14 @@ def chunk_simple_gla_fwd_kernel_o(

p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
b_g = tl.load(p_g)
b_o = b_o * tl.math.exp2(b_g)[:, None]
b_s = b_s * tl.math.exp2(b_g[:, None] - b_g[None, :])
b_o = b_o * tl.exp(b_g)[:, None]
b_s = b_s * tl.exp(b_g[:, None] - b_g[None, :])
b_s = tl.where(m_s, b_s, 0)

p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V),
(s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale
p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V),
(s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


Expand All @@ -163,7 +142,6 @@ def chunk_simple_gla_bwd_kernel_dh(
s_h_h,
s_h_t,
scale,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
Expand All @@ -177,22 +155,18 @@ def chunk_simple_gla_bwd_kernel_dh(
# [BK, BV]
b_dh = tl.zeros([BK, BV], dtype=tl.float32)
for i_t in range(NT - 1, -1, -1):
p_q = tl.make_block_ptr(
q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_do = tl.make_block_ptr(
do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V,
(K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
# [BK, BT]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale * tl.math.exp2(tl.load(g + i_bh * T +
i_t * BT + tl.arange(0, BT)))[None, :]).to(b_q.dtype)
b_q = (b_q * scale * tl.exp(tl.load(g + i_bh * T + i_t * BT + tl.arange(0, BT)))[None, :]).to(b_q.dtype)
# [BT, V]
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BK, BV]
b_dh *= tl.math.exp2(tl.load(g + i_bh * T + i_t * BT + BT - 1))
b_dh *= tl.exp(tl.load(g + i_bh * T + i_t * BT + BT - 1))
b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)


Expand All @@ -217,8 +191,6 @@ def chunk_simple_gla_bwd_kernel_dqkv(
s_h_h,
s_h_t,
scale,
B: tl.constexpr,
H: tl.constexpr,
T: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
Expand All @@ -231,35 +203,28 @@ def chunk_simple_gla_bwd_kernel_dqkv(
n_bh = tl.num_programs(2)
o_i = tl.arange(0, BT)

p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T),
(s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K),
(s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))

b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_s = tl.dot(b_k, b_q, allow_tf32=False)
p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
b_g = tl.load(p_g)
b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)
mask = tl.math.exp2(b_g[None, :] - b_g[:, None])
mask = tl.exp(b_g[None, :] - b_g[:, None])
mask = tl.where(o_i[:, None] <= o_i[None, :], mask * scale, 0)
b_s = b_s * mask

b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_ds = tl.zeros([BT, BT], dtype=tl.float32)
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(
v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t),
(i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
p_do = tl.make_block_ptr(
do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V),
(s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V),
(s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
Expand All @@ -273,21 +238,19 @@ def chunk_simple_gla_bwd_kernel_dqkv(
b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale
b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
# [BT, BV]
b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * tl.math.exp2(-b_g + b_g_last)[:, None] + \
tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * tl.exp(-b_g + b_g_last)[:, None]
b_dv += tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

b_dq = b_dq * tl.math.exp2(b_g)[:, None]
b_dk = b_dk * tl.math.exp2(-b_g + b_g_last)[:, None]
b_dq = b_dq * tl.exp(b_g)[:, None]
b_dk = b_dk * tl.exp(-b_g + b_g_last)[:, None]
b_ds = b_ds * tl.trans(mask)
b_ds = b_ds.to(b_k.dtype)
# [BT, BK]
b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K),
(s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K),
(s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

Expand All @@ -297,20 +260,17 @@ class SimpleGLAFunction(torch.autograd.Function):
@staticmethod
@custom_fwd
@contiguous
def forward(ctx, q, k, v, g, initial_state, output_final_state):
def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):
B, H, T, K, V = *q.shape, v.shape[-1]
BT = 64
BK, BV = min(64, triton.next_power_of_2(K)), min(
64, triton.next_power_of_2(V))
BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
num_stages = 1
num_warps = 4 if BK == 64 else 2
scale = K ** -0.5
num_stages = 1

BT = 64
assert T % BT == 0, 'sequence length must be divisible by BT'
g = g.reshape(B, H, -1, BT)
g = g.cumsum(-1) * 1.44269504
g = g.cumsum(-1)
g = g.reshape(B, H, -1)

final_state = None
Expand All @@ -324,7 +284,7 @@ def forward(ctx, q, k, v, g, initial_state, output_final_state):
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2),
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=output_final_state,
num_warps=num_warps,
Expand All @@ -338,28 +298,29 @@ def forward(ctx, q, k, v, g, initial_state, output_final_state):
v.stride(1), v.stride(2), v.stride(3),
h.stride(1), h.stride(2),
scale,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
num_warps=num_warps,
num_stages=num_stages
)

ctx.save_for_backward(q, k, v, h, g)
ctx.scale = scale
return o.to(q.dtype), final_state

@staticmethod
@custom_bwd
@contiguous
def backward(ctx, do, d_ht=None):
def backward(ctx, do, dht=None):
q, k, v, h, g = ctx.saved_tensors

B, H, T, K, V = *q.shape, v.shape[-1]
BT = 64
BK, BV = min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(K)), min(
32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))
BK = min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(K))
BV = min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))
NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
num_stages = 1
num_warps = 4 if BK == 64 else 2
scale = K ** -0.5
num_stages = 1
scale = ctx.scale

dh = q.new_empty(B, H, NT * K, V)
grid = (NK, NV, B * H)
Expand All @@ -369,7 +330,7 @@ def backward(ctx, do, d_ht=None):
v.stride(1), v.stride(2), v.stride(3),
dh.stride(1), dh.stride(2),
scale,
H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
num_warps=num_warps,
num_stages=num_stages
)
Expand All @@ -385,7 +346,7 @@ def backward(ctx, do, d_ht=None):
v.stride(1), v.stride(2), v.stride(3),
dh.stride(1), dh.stride(2),
scale,
B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
num_warps=num_warps,
num_stages=num_stages
)
Expand All @@ -405,11 +366,31 @@ def chunk_simple_gla(
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor, # log decay
scale: Optional[float] = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if initial_state is not None:
initial_state = initial_state.detach()
r"""
Args:
q (torch.Tensor):
queries of shape `(B, H, T, K)`
k (torch.Tensor):
keys of shape `(B, H, T, K)`
v (torch.Tensor):
values of shape `(B, H, T, V)`
g (torch.Tensor):
Forget gates of shape `(B, H, T)` applied to keys.
Compared to GLA, the gating is head-wise instead of elementwise.
scale (Optional[int]):
Scale factor for the attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `(B, H, K, V)`. Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
"""
if scale is None:
scale = k.shape[-1] ** -0.5
g = g.float()
o, final_state = SimpleGLAFunction.apply(q, k, v, g, initial_state, output_final_state)
o, final_state = SimpleGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state)
return o, final_state
Loading
Loading