add cutlass support for blackwell fp8 gemm #13798
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Thanks for the contribution @kushanam! Looks good overall, left a few comments/questions
-  using StrideC = typename Gemm::StrideC;
+  using StrideC = typename Gemm::GemmKernel::StrideC;

   StrideA a_stride{lda, cute::Int<1>{}, 0};
   StrideB b_stride{ldb, cute::Int<1>{}, 0};
-  StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}};
+  // StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}};
+  StrideC c_stride =
+      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(ldc, 1, 0));
Could you explain this change?
  using ElementD = ElementD_;
  using LayoutD = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentD =
      128 / cutlass::sizeof_bits<ElementD_>::value;
We've been setting AlignmentD to 4 to reduce the alignment requirement of these kernels. Can this be 4 instead of 8? Also, do you know what the performance considerations are here?
Generally, 128-bit alignment (i.e. 8 elements for 16-bit data types) is required for best TMA performance; 4 might work, but perf will suffer. It's the same between Hopper and Blackwell.
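For reference, a minimal sketch of the alignment arithmetic under discussion. The 128 is in bits, so the element-count alignment depends on the output dtype; this is illustrative only and assumes nothing beyond CUTLASS's numeric_types.h (the template name is made up):

// Alignment in elements implied by a 128-bit (16-byte) access, computed the
// same way as AlignmentD in the diff above.
#include "cutlass/numeric_types.h"

template <typename Element>
constexpr int alignment_for_128b = 128 / cutlass::sizeof_bits<Element>::value;

static_assert(alignment_for_128b<cutlass::half_t> == 8);         // fp16 output
static_assert(alignment_for_128b<cutlass::bfloat16_t> == 8);     // bf16 output
static_assert(alignment_for_128b<float> == 4);                   // fp32 output
static_assert(alignment_for_128b<cutlass::float_e4m3_t> == 16);  // fp8 output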
At first glance, this looks very similar to csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp. Why can't it be the same code?
Blackwell doesn't like EpilogueDescriptor. On the other hand, the only use for EpilogueDescriptor in scaled_mm_epilogues_c3x seems to be tile shapes, so alternatively we could get rid of it altogether and keep everything under the same file.
Yep, that's the only use.

> we could get rid of it altogether and keep everything under the same file

I think that's the right move. Could you make that change in this PR?
For sure!
done.
CMakeLists.txt (Outdated)
 cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a;10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
 if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
   set(SRCS
     "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
     "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
     "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
     "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
-    "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu")
+    "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu"
+    "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu")
Here we'll have to guard against compilation of scaled_mm_sm100_fp8.cu when CUDA < 12.8.
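For illustration, one possible shape for that guard, sketched against the file list in the diff above (the exact predicate and message are assumptions, not the final implementation):

# Only compile the sm100 kernel when the CUDA toolkit can target Blackwell.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
  list(APPEND SRCS
       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu")
else()
  message(STATUS "Not building scaled_mm_sm100_fp8.cu: CUDA 12.8+ required")
endif()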
This pull request has merge conflicts that must be resolved before it can be merged.
Thanks, looks great to me now! Could you merge in the changes from latest main?
Signed-off-by: Tyler Michael Smith <[email protected]>
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  int M = a.size(0), N = b.size(1), K = a.size(1);
  TORCH_CHECK(
      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
      "Currently, block scaled fp8 gemm is not implemented for Blackwell");

  // Standard per-tensor/per-token/per-channel scaling
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
              "Currently, only fp8 gemm is implemented for Blackwell");
  vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
}
We might need to ifdef this out when CUDA < 12.8
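A minimal sketch of what that could look like, reusing the CUDA_VERSION check shown later in this thread; the stub under #else (kept so the symbol still links) and its error message are assumptions, not the final implementation:

#if defined CUDA_VERSION && CUDA_VERSION >= 12800
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias) {
  // ... body as above ...
}
#else
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(false,
              "cutlass_scaled_mm_sm100 requires a build with CUDA >= 12.8");
}
#endif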
 #if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
-  if (version_num >= 90) {
+  if (version_num >= 90 && version_num < 100) {
     cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
     return;
+  } else if (version_num >= 100) {
+    cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
+    return;
   }
This is causing linker errors in the CI. We need to guard against calling cutlass_scaled_mm_sm100 when CUDA < 12.8. Something like this:
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
#if defined CUDA_VERSION && CUDA_VERSION >= 12800
  if (version_num >= 90 && version_num < 100) {
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
    return;
  } else if (version_num >= 100) {
    cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
    return;
  }
#else
  if (version_num >= 90) {
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif
There are some linker errors when CUDA < 12.8 that need to be addressed (I left some inline comments).
This made me realize we need to add guards/fallbacks/warnings for running on a Blackwell GPU with a kernel compiled using CUDA < 12.8, as currently we will try to run the non-forward-compatible sm90a kernels. That doesn't need to happen in this PR, but @kushanam do you have any thoughts there?
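To make the concern concrete, a hypothetical sketch of such a runtime guard (get_device_sm and can_run_sm100 are placeholder names, not existing vLLM APIs):

#include <cuda.h>          // CUDA_VERSION
#include <cuda_runtime.h>  // cudaGetDeviceProperties
#include <iostream>

// Compute capability as an integer, e.g. 90 for Hopper, 100 for Blackwell.
static int get_device_sm(int device) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device);
  return prop.major * 10 + prop.minor;
}

// True only if this binary was actually built with Blackwell kernels.
static bool can_run_sm100(int device) {
#if defined CUDA_VERSION && CUDA_VERSION >= 12800
  return get_device_sm(device) >= 100;
#else
  if (get_device_sm(device) >= 100) {
    std::cerr << "WARNING: Blackwell GPU detected, but this build used "
                 "CUDA < 12.8; the sm90a kernels are not forward-compatible.\n";
  }
  return false;
#endif
}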
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Head branch was pushed to by a user without write access
This PR adds CUTLASS Blackwell GEMM support for fp8.
A couple of notes:
1. sm100_fp8_config_default is the only supported config for now, with static tile and cluster shapes. Subsequent PRs will add optimized configs for other shapes, and 2xSM GEMM will also be added.
2. Added default constructors for the c2x and c3x kernels, since some build environments treat the missing-constructor warning as an error.
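For note 2, an illustrative example of the kind of change meant (the struct name here is made up):

// Some build environments treat the missing-constructor warning as an error
// (per the PR notes); explicitly defaulting the constructor avoids it
// without changing behavior.
struct ExampleEpilogueArgs {
  float alpha = 1.0f;
  float beta = 0.0f;
  ExampleEpilogueArgs() = default;
};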