
add cutlass support for blackwell fp8 gemm #13798

Open — wants to merge 10 commits into main

Conversation

@kushanam (Contributor) commented Feb 25, 2025

This PR adds CUTLASS Blackwell GEMM support for FP8.
A couple of notes:
1. `sm100_fp8_config_default` is the only supported config for now, with static tile and cluster shapes. Subsequent PRs will add optimized configs for other shapes, as well as 2x-SM GEMM.
2. Added default constructors for the c2x and c3x kernels; some build environments treat the missing-constructor warning as an error.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI will run, which covers a small but essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Feb 25, 2025
@tlrmchlsmth (Collaborator) left a comment:

Thanks for the contribution @kushanam! Looks good overall, left a few comments/questions

Comment on lines 76 to 83:

```diff
-  using StrideC = typename Gemm::StrideC;
+  using StrideC = typename Gemm::GemmKernel::StrideC;

   StrideA a_stride{lda, cute::Int<1>{}, 0};
   StrideB b_stride{ldb, cute::Int<1>{}, 0};
-  StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}};
+  // StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}};
+  StrideC c_stride =
+      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(ldc, 1, 0));
```
Collaborator:

Could you explain this change?
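For context on the semantics: a packed stride of the form `{ld, 1, 0}` addresses a row-major matrix with leading dimension `ld`, where the trailing `0` broadcasts a single matrix across the batch dimension. A plain-C++ sketch of the addressing, with invented helper names (this is not the actual CUTLASS `make_cute_packed_stride` implementation):

```cpp
#include <array>
#include <cstdint>

// Hypothetical helpers (invented names; not the CUTLASS API) showing what a
// packed row-major stride {ld, 1, 0} means: element (m, k) of batch b lives
// at offset m*ld + k*1 + b*0, so the batch stride of 0 reuses one matrix
// for every batch index.
using Stride3 = std::array<std::int64_t, 3>;

inline Stride3 make_packed_stride(std::int64_t ld) {
  return {ld, 1, 0};  // {row stride, column stride, batch stride}
}

inline std::int64_t linear_offset(const Stride3& s, std::int64_t m,
                                  std::int64_t k, std::int64_t b) {
  return m * s[0] + k * s[1] + b * s[2];
}
```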

Comment on lines 105 to 103:

```cpp
using ElementD = ElementD_;
using LayoutD = cutlass::layout::ColumnMajor;
static constexpr int AlignmentD =
    128 / cutlass::sizeof_bits<ElementD_>::value;
```
Collaborator:

We've been setting AlignmentD to 4 to reduce the alignment requirement of these kernels. Can this be 4 instead of 8? Also, do you know what the performance considerations are?

@kushanam (Contributor, author) commented Feb 25, 2025:

Generally, 128-bit alignment (i.e., 8 elements for 16-bit data types) is required for best TMA performance. 4 might work, but performance will suffer. This is the same between Hopper and Blackwell.
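The arithmetic behind that rule can be sanity-checked in isolation. A plain-C++ stub (an assumption standing in for the real `cutlass::sizeof_bits` trait, which reports an element's width in bits):

```cpp
// Minimal stand-in for cutlass::sizeof_bits<T>::value (assumption: the real
// trait reports the element width in bits).
template <int Bits>
struct sizeof_bits_stub {
  static constexpr int value = Bits;
};

// Elements per 128-bit (16-byte) access, matching the expression
// "128 / cutlass::sizeof_bits<ElementD_>::value" from the diff above.
template <typename SizeofBits>
constexpr int alignment_elems() {
  return 128 / SizeofBits::value;
}

static_assert(alignment_elems<sizeof_bits_stub<16>>() == 8,
              "16-bit output (half/bf16): 8-element alignment");
static_assert(alignment_elems<sizeof_bits_stub<8>>() == 16,
              "8-bit output (fp8): 16-element alignment");
```

So for a 16-bit output type, 8 elements fill one 128-bit access, which is why lowering the alignment to 4 halves the achievable access width.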

Collaborator:

At first glance, this looks very similar to csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp. Why can't it be the same code?

Contributor (author):

Blackwell doesn't like EpilogueDescriptor. On the other hand, the only use of EpilogueDescriptor in scaled_mm_epilogues_c3x seems to be the tile shapes, so alternatively we could get rid of it altogether and keep everything in the same file.

Collaborator:

Yep that's the only use.

> we could get rid of it altogether and keep everything in the same file

I think that's the right move. Could you make that change in this PR?

Contributor (author):

For sure!

Contributor (author):

done.

CMakeLists.txt (Outdated) — Comment on lines 302 to 310:

```diff
 cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a;10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
 if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
   set(SRCS
     "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
     "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
     "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
     "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
-    "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu")
+    "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu"
+    "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu")
```
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we'll have to guard against compilation of scaled_mm_sm100_fp8.cu when CUDA < 12.8
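One way to add that guard (a sketch only; the variable names and arch list here are assumptions, following the existing SCALED_MM_3X pattern in this CMakeLists.txt):

```cmake
# Sketch: compile the sm100 source only when the toolkit can target
# Blackwell, i.e. CUDA >= 12.8.
cuda_archs_loose_intersection(SCALED_MM_SM100_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_SM100_ARCHS)
  list(APPEND SRCS "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu")
endif()
```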


mergify bot commented Feb 27, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @kushanam.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 27, 2025
@tlrmchlsmth (Collaborator) left a comment:

Thanks, looks great to me now! Could you merge in the changes from latest main?

Signed-off-by: Tyler Michael Smith <[email protected]>
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) March 2, 2025 20:12
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 2, 2025
Comment on lines 75 to 94:

```cpp
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  int M = a.size(0), N = b.size(1), K = a.size(1);
  TORCH_CHECK(
      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
      "Currently, block scaled fp8 gemm is not implemented for Blackwell");

  // Standard per-tensor/per-token/per-channel scaling
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
              "Currently, only fp8 gemm is implemented for Blackwell");
  vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
}
```
Collaborator:

We might need to ifdef this out when CUDA < 12.8

Comment on lines 127 to 134:

```diff
 #if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
-  if (version_num >= 90) {
+  if (version_num >= 90 && version_num < 100) {
     cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
     return;
+  } else if (version_num >= 100) {
+    cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
+    return;
   }
```
Collaborator:

This is causing linker errors in the CI. Need to guard against calling cutlass_scaled_mm_sm100 when CUDA < 12.8

Something like this:

```cpp
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X

#if defined CUDA_VERSION && CUDA_VERSION >= 12800
  // CUDA >= 12.8: the sm100 kernels are compiled in, so dispatch to them
  // on Blackwell (SM 10.x).
  if (version_num >= 90 && version_num < 100) {
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
    return;
  } else if (version_num >= 100) {
    cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
    return;
  }
#else
  // CUDA < 12.8: cutlass_scaled_mm_sm100 is not built, so only the
  // sm90 path is available.
  if (version_num >= 90) {
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif
```

@tlrmchlsmth (Collaborator) left a comment:

There are some linker errors when CUDA < 12.8 that need to be addressed. (I left some inline comments)

This made me realize we need to add guards/fallbacks/warnings when running on a Blackwell GPU but using a kernel compiled with CUDA < 12.8, as currently we will try to run the non-forward-compatible kernels for sm90a. That doesn't need to happen in this PR but @kushanam do you have any thoughts there?
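The runtime guard being suggested could take a shape like the following plain-C++ sketch (illustrative only; the function and flag names are invented, not vLLM's actual dispatch code):

```cpp
#include <stdexcept>
#include <string>

// Sketch of a dispatch-time guard (invented names): refuse to run the
// non-forward-compatible sm90a kernels on a Blackwell (SM 10.x) device
// when the binary was built without the sm100 kernels (CUDA < 12.8).
inline std::string select_scaled_mm_kernel(int version_num,
                                           bool built_with_sm100) {
  if (version_num >= 100) {
    if (!built_with_sm100) {
      throw std::runtime_error(
          "Blackwell GPU detected but no sm100 kernels were compiled; "
          "rebuild with CUDA >= 12.8");
    }
    return "cutlass_scaled_mm_sm100";
  }
  if (version_num >= 90) {
    return "cutlass_scaled_mm_sm90";
  }
  return "cutlass_scaled_mm_pre_sm90";
}
```

An explicit error like this would replace the current behavior of silently attempting to launch sm90a kernels on Blackwell.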

Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
auto-merge was automatically disabled March 4, 2025 02:51

Head branch was pushed to by a user without write access

pathorn pushed a commit to deepinfra/vllm that referenced this pull request Mar 4, 2025
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) March 4, 2025 03:13
Labels: ci/build, ready (ONLY add when PR is ready to merge/full CI is needed)