Skip to content

Commit

Permalink
Always enable PackGQA if PagedKV to reduce compilation and bin size
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Jan 10, 2025
1 parent a84a237 commit 40fa35a
Show file tree
Hide file tree
Showing 135 changed files with 75 additions and 720 deletions.
12 changes: 6 additions & 6 deletions hopper/flash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,8 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
SPLIT_SWITCH(params.num_splits > 1, Split, [&] {
PAGEDKV_SWITCH(params.page_table, PagedKV, [&] {
PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] {
// Always enable PackGQA for Sm8x to reduce compilation
static constexpr bool PackGQA = PackGQA_ || Arch < 90;
// Always enable PackGQA for Sm8x or PagedKV to reduce compilation
static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKV;
SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] {
if (!params.is_e4m3) {
if (params.is_bf16) {
Expand Down Expand Up @@ -369,9 +369,9 @@ void run_mha_fwd_combine(Flash_fwd_params &params, cudaStream_t stream) {
}

inline bool get_pack_gqa(Flash_fwd_params const& params) {
// Always enable PackGQA for Sm8x to reduce compilation and binary size.
// Has almost no effect on speed.
if (params.arch < 90) { return true; }
// Always enable PackGQA for Sm8x or PagedKV to reduce compilation and binary size.
// Has little effect on speed.
if (params.arch < 90 || params.page_table) { return true; }
#ifdef FLASHATTENTION_DISABLE_PACKGQA
return false;
#else
Expand Down Expand Up @@ -838,7 +838,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits.");
#endif
#ifdef FLASHATTENTION_DISABLE_PACKGQA
TORCH_CHECK(params.arch < 90 || !params.pack_gqa, "This flash attention build does not support pack_gqa.");
TORCH_CHECK(!params.pack_gqa || params.arch < 90 || params.page_table, "This flash attention build does not support pack_gqa.");
#endif
#ifdef FLASHATTENTION_DISABLE_PAGEDKV
TORCH_CHECK(!paged_KV, "This flash attention build does not support paged KV.");
Expand Down
8 changes: 5 additions & 3 deletions hopper/generate_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,12 @@ class Kernel:
def template(self) -> str:
if self.direction == "fwd":
if self.sm == 90:
# Always enable PackGQA for PagedKV to reduce compilation
packgqa = self.packgqa or self.paged_kv
return KERNEL_IMPL_TEMPLATE_FWD_SM90.format(
ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim,
SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(),
SOFTCAP=str(self.softcap).lower(), PACKGQA=str(self.packgqa).lower()
SOFTCAP=str(self.softcap).lower(), PACKGQA=str(packgqa).lower()
)
else:
# Always enable PackGQA for Sm8x to reduce compilation
Expand Down Expand Up @@ -126,9 +128,9 @@ def filename(self) -> str:

def get_all_kernels() -> List[Kernel]:
for dtype, head_dim, split, paged_kv, softcap, packgqa, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SPLIT, PAGEDKV, SOFTCAP, PACKGQA, SM):
# We always enable PackGQA for Sm8x so we should just pass in packgqa=False
# We always enable PackGQA for Sm8x and PagedKV so we should just pass in packgqa=False
# to avoid the `_packgqa` in the filename.
if sm < 90 and packgqa:
if packgqa and (sm < 90 or (sm >= 90 and paged_kv)):
continue
if sm >= 90 or dtype in DTYPE_MAP_FWD_SM8x:
yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd")
Expand Down

This file was deleted.

2 changes: 1 addition & 1 deletion hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, true, true, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, true, true, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

This file was deleted.

2 changes: 1 addition & 1 deletion hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, true, true, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, true, true, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

This file was deleted.

2 changes: 1 addition & 1 deletion hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::half_t, 128, false, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::half_t, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::half_t, 128, false, true, true, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::half_t, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::half_t, 128, true, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::half_t, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::half_t, 128, true, true, true, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::half_t, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

This file was deleted.

2 changes: 1 addition & 1 deletion hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, true, true, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, true, true, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

This file was deleted.

2 changes: 1 addition & 1 deletion hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, true, true, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
Loading

0 comments on commit 40fa35a

Please sign in to comment.