[Unity][BYOC] Integrate fp16 A - int4 B GEMM kernel from FasterTransformer into CUTLASS BYOC #15111
Conversation
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot
Thank you for the excellent contribution, Masa!
This looks good to me; I have some minor questions/comments.
sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
)
lv3: R.Tensor((128,), dtype="float16") = lv[1]
lv6 = R.call_tir(
I'm wondering: in the future, once we settle on representative quantization schemes, would it make sense to introduce a relax-level decode op?
I think it's a bit hard to define what "representative Q scheme" is going to be. Even for the simple one expected by FT, the weight needs to be packed in a very specific way.
)
lv1 = lv[0]
lv2 = R.call_pure_packed(
    "cutlass.ft_preprocess_weight_int4",
If we always have to apply this preprocess, can we do the preprocess and then dump the weight?
Yes, that's possible: encode + preprocess and dump. In practice, if we apply the LiftTransformParams pass, this preprocess is already done as part of running the transform_params function (together with encode). So we are already doing what you described.
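For context, a rough sketch of that flow (not code from this PR; `mod` is a hypothetical IRModule, and the lifted function name assumes the usual LiftTransformParams convention):

```python
from tvm import relax

# Assume `mod` is a Relax IRModule whose "main" takes the fp16 weight as a trailing
# parameter (annotated via the "num_input" attr) and already contains the encode +
# cutlass.ft_preprocess_weight_int4 calls on that weight.
lifted = relax.transform.LiftTransformParams()(mod)

# The pass moves the weight-only computation (encode + FT preprocess) into a separate
# function, conventionally named "main_transform_params". Executing that function once
# at compile time, e.g. on CPU through the relax VM, yields the preprocessed weights
# that are dumped and shipped with the compiled model.
print(lifted["main_transform_params"])
```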
packed_weight.CopyToBytes(input_cpu.data(), input_cpu.size());
// Multiply cols by 2 since the "col" param in preprocess_weights refers to the number of
// columns of the unpacked weight.
fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(), rows, cols * 2,
Maybe it would be good to add some documentation about why we need this preprocess, why it happens on CPU, and whether it would be possible/better to move it to the GPU?
Note that this function is meant to be run at compile time. It needs to run on CPU since the preprocess functions are defined by FT in C++. I don't find them to be slow, so given the amount of effort GPU porting would require, I don't think it's worth it.
Added some doc
lv2 = R.call_pure_packed(
    "cutlass.ft_preprocess_weight_int4",
    lv1,
    80,
Can sm be inferred automatically? It doesn't look user-friendly.
In general one might want to compile for a different GPU than the one in the compilation environment. cutlass.ft_preprocess_weight_int4 needs to be defined in tvm_runtime, so we cannot call Target::Current() from there.
That said, if I could define a Python utility function with an optional sm argument that can be called from TVMScript, that would be best. But is it possible to call an arbitrary Python function from TVMScript?
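For illustration, a hedged sketch of what such a helper could look like on the Python side (the helper itself is hypothetical and not part of this PR; it only uses standard TVM target/device queries):

```python
import tvm

def infer_sm(default_sm=80):
    # Prefer the target in scope, if any (e.g. "nvidia/geforce-rtx-4080" carries an "arch" attr).
    target = tvm.target.Target.current(allow_none=True)
    if target is not None and "arch" in target.attrs:
        return int(str(target.attrs["arch"]).replace("sm_", ""))  # "sm_89" -> 89
    # Otherwise fall back to the locally attached GPU, if present.
    dev = tvm.cuda(0)
    if dev.exist:
        major, minor = dev.compute_version.split(".")
        return int(major) * 10 + int(minor)
    # Finally, fall back to an explicit default for cross-compilation without a local GPU.
    return default_sm
```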
Also CUTLASS users need to be very careful with the target sm anyway, so I hope it won't be a big deal in practice.
Any other comments? Otherwise, can someone merge this?
…ormer into CUTLASS BYOC (apache#15111)

* Initial FT kernel integration
* update rev
* wip
* wip
* add more check
* add residual test
* wip
* wip
* fix
* update test
* properly generate residual
* black
* update condition
* missing doc
* pylint
* add missing apache header
* format
* add some doc on preprocess
Hi @masahi, I was wondering if you had benchmarked against e.g. available GPTQ CUDA or Triton-based kernels. Thank you!
I have compared against the exllama int4 kernel for decoding. It seems the FT GEMM kernel is overkill for mat x vec multiply - the custom mat x vec kernel from exllama is much faster.
Thank you. I'm trying to connect the dots between the different opinions I hear around using the mma.sync.aligned op for GEMV (which, from what I understand of cutlass, the fastertransformer extension should come down to) versus not, and your results are in line with my current intuition. Now the only mystery remaining to me is why pytorch chooses a tensor-core-based kernel for GEMV in fp16 * fp16. Maybe there could be an opportunity for a cutlass extension based instead on https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/kernel/gemv.h or https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/kernel/gemv_batched_strided.h for decoding, along with a GEMM kernel for prefill.
@masahi what were the dimensions of the input and the GPU you tested this on? Thanks
This was the profiler output for one e2e inference run of the vicuna 7B decoder. The GPU was an RTX 4080.
It seems that the default cutlass tile configuration does not perform well on compute-bound GEMM shapes. Do we have any plans to support tile tuning for this backend?
How do you know? The kernel uses a runtime heuristic to select a tile config. For now I'm hoping that this heuristic is "good enough".
Thanks, I see. I mentioned that because I benchmarked the Llama-70b GEMM performance (cuBLAS FP16xFP16 vs. Relax.Cutlass fp16xint4) on an A6000 48G with CUDA 12.1.1 installed.
@LeiWang1999 Thank you for sharing this. This is a very interesting data point! May I ask how I can reproduce these numbers? Did you exhaustively search over the possible tile combinations?
Hi @sunggg, code to reproduce:
Glad if it helps.
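The reproduction snippet itself is not preserved in this excerpt. As a rough stand-in, here is a minimal sketch of how the cuBLAS fp16 baseline side of such a comparison could be timed with PyTorch's cuBLAS-backed matmul; the shapes are illustrative Llama-70B-style GEMMs, not necessarily the exact ones benchmarked, and the fp16xint4 side (which would use the TVM-compiled module) is omitted:

```python
import torch

def bench_cublas_fp16(m, n, k, iters=100):
    a = torch.randn(m, k, dtype=torch.float16, device="cuda")
    b = torch.randn(k, n, dtype=torch.float16, device="cuda")
    for _ in range(10):            # warm up so algorithm selection is excluded from timing
        torch.matmul(a, b)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        torch.matmul(a, b)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters   # milliseconds per GEMM

for m in (1, 16, 256, 4096):                 # small M ~ decode, large M ~ prefill
    print(m, bench_cublas_fp16(m, 8192, 8192))
```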
@LeiWang1999 yes, we have done a similar perf comparison against cuBLAS fp16, and indeed we found that the FT int4 kernel is faster only for small M. This is not surprising, since a larger M amortizes the cost of loading the large weight over each row (which may correspond to a "batch" or "token" in LLM inference), while int4 involves dequantization overhead. The NV developers also talked about this point in their GTC23 presentation. Here is one of their slides.
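As a rough back-of-the-envelope illustration of that amortization argument (numbers are illustrative, not measurements from the thread):

```python
# Weight bytes that must be streamed per multiply-accumulate in an (M, K) x (K, N) GEMM.
# At M = 1 (decode) the 4x smaller int4 weight dominates; at large M (prefill) the weight
# traffic per MAC is tiny either way, so fp16 wins by avoiding dequantization overhead.
def weight_bytes_per_mac(m, n, k, bytes_per_weight):
    return (k * n * bytes_per_weight) / (m * n * k)

for m in (1, 16, 512, 4096):
    fp16 = weight_bytes_per_mac(m, 8192, 8192, 2.0)
    int4 = weight_bytes_per_mac(m, 8192, 8192, 0.5)
    print(f"M={m:5d}  fp16: {fp16:.5f} B/MAC  int4: {int4:.5f} B/MAC")
```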
This PR integrates the kernel from https://github.com/tlc-pack/cutlass_fpA_intB_gemm. In addition to the original features in FasterTransformer, this derived implementation also supports residual fusion by tlc-pack/cutlass_fpA_intB_gemm#1.

- Each column of the (K, N) weight matrix is scaled by the corresponding element from a scale vector of shape (N,). Each scale is calculated over the K axis.
- Two int4 values are stored in one int8 element, so the packed weight has shape (K, N / 2) (a toy sketch of this layout is given after this description).
- The packed weight needs to be preprocessed by the function registered in contrib/cutlass/weight_preprocess.cc, which in turn does various preprocessings on the weight implemented in https://github.com/tlc-pack/cutlass_fpA_intB_gemm/blob/main/cutlass_kernels/cutlass_preprocessors.cc. These processes are not well documented, so we wrap them into a packed function and use it as a black box.

The test case demonstrates the overall integration flow. It is interesting to note that, although encode / decode functions are expressed as TIR, we can still use the graph-based BYOC to integrate this kernel.
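To make the layout concrete, here is a toy NumPy sketch of per-column symmetric int4 quantization with two values packed per int8. The nibble order is an illustrative assumption, and the real kernel additionally requires the ft_preprocess_weight_int4 interleaving step described above, which this sketch does not reproduce:

```python
import numpy as np

def quantize_per_column_int4(w):
    # w: fp16 weight of shape (K, N) -> packed int8 weight of shape (K, N // 2) + fp16 scales (N,)
    scale = np.maximum(np.abs(w).max(axis=0).astype(np.float32), 1e-8) / 7.0  # one scale per column, over K
    q = np.clip(np.round(w.astype(np.float32) / scale), -8, 7).astype(np.int32)
    low = q[:, 0::2] & 0x0F            # even columns -> low nibble (order is an assumption)
    high = (q[:, 1::2] & 0x0F) << 4    # odd columns  -> high nibble
    return (low | high).astype(np.uint8).view(np.int8), scale.astype(np.float16)

w = np.random.randn(64, 128).astype(np.float16)
packed, scales = quantize_per_column_int4(w)
assert packed.shape == (64, 64) and scales.shape == (128,)  # matches the shapes in the test snippet above
```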
@sunggg @yelite @vinx13 @Hzfengsy @yzh119