
[Unity][BYOC] Integrate fp16 A - int4 B GEMM kernel from FasterTransformer into CUTLASS BYOC #15111

Merged: 18 commits into apache:unity on Jun 20, 2023

Conversation

@masahi (Member) commented Jun 15, 2023

This PR integrates the kernel from https://github.com/tlc-pack/cutlass_fpA_intB_gemm. In addition to the original features in FasterTransformer, this derived implementation also supports residual fusion via tlc-pack/cutlass_fpA_intB_gemm#1.

  • It only supports int4 weights for now, while the kernel itself supports int8 as well.
  • The kernel expects a simple symmetric quantization scheme where each column of the (K, N) weight matrix is scaled by the corresponding element of a scale vector of shape (N,); each scale is computed over the K axis.
  • The weight matrix needs to be packed into int8 storage with shape (K, N / 2).
  • The packed weight needs to be preprocessed by a special function defined in contrib/cutlass/weight_preprocess.cc, which in turn applies various preprocessing steps to the weight, implemented in https://github.com/tlc-pack/cutlass_fpA_intB_gemm/blob/main/cutlass_kernels/cutlass_preprocessors.cc. These steps are not well documented, so we wrap them in a packed function and use it as a black box.

The test case demonstrates the overall integration flow. It is interesting to note that, although the encode / decode functions are expressed in TIR, we can still use the graph-based BYOC to integrate this kernel.
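To make the quantization scheme above concrete, here is a minimal numpy sketch of the symmetric per-column quantization and (K, N / 2) packing. The nibble order is an assumption of this sketch; the real layout shuffles required by the kernel are produced by the cutlass.ft_preprocess_weight_int4 step shown in the snippets below.

import numpy as np

def quantize_and_pack_int4(weight_fp16):
    """Illustrative only: weight_fp16 is (K, N) float16; returns a (K, N // 2)
    packed int8 weight and a (N,) float16 scale vector (one scale per column,
    computed over the K axis)."""
    k, n = weight_fp16.shape
    assert n % 2 == 0
    w = weight_fp16.astype(np.float32)
    # Symmetric per-column scale; int4 values live in [-8, 7].
    scales = np.maximum(np.max(np.abs(w), axis=0) / 7.0, 1e-5)
    q = np.clip(np.round(w / scales), -8, 7).astype(np.int8)
    # Pack two int4 values per int8 byte. Putting the even column in the low
    # nibble is an assumption here; the FT preprocess function owns the real layout.
    q_u = q.astype(np.uint8) & 0x0F
    packed = (q_u[:, 0::2] | (q_u[:, 1::2] << 4)).view(np.int8)
    return packed, scales.astype(np.float16)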

@sunggg @yelite @vinx13 @Hzfengsy @yzh119

@tvm-bot (Collaborator) commented Jun 15, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@sunggg (Contributor) left a comment

Thank you for the excellent contribution, Masa!
This looks good to me; I have some minor questions/comments.

sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
)
lv3: R.Tensor((128,), dtype="float16") = lv[1]
lv6 = R.call_tir(
Contributor:

I'm wondering: in the future, once we settle on representative quantization schemes, would it make sense to introduce a Relax-level decode op?

masahi (Member, Author):

I think it's a bit hard to define what "representative Q scheme" is going to be. Even for the simple one expected by FT, the weight needs to be packed in a very specific way.

)
lv1 = lv[0]
lv2 = R.call_pure_packed(
"cutlass.ft_preprocess_weight_int4",
Contributor:

If we always have to apply this preprocess, can we do the preprocess and then dump the weight?

masahi (Member, Author):

Yes, that's possible: encode + preprocess and dump. In practice, if we use the LiftTransformParams pass, this preprocess is already done as part of running the transform_params function (together with encode). So we are already doing what you described.
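For reference, a minimal sketch of that flow with the unity-branch API; the lifted function name and the exact build/run calls below are assumptions of this sketch, not code from this PR.

import tvm
from tvm import relax

# `mod` is a Relax module whose parameters flow through the encode TIR
# functions and the cutlass.ft_preprocess_weight_int4 call; `params` is the
# list of raw fp16 weights (both defined elsewhere).
mod = relax.transform.LiftTransformParams()(mod)

# Run the lifted weight-only function once, offline, on CPU. The name
# "main_transform_params" assumes the entry function is called "main".
ex = relax.build(mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
new_params = vm["main_transform_params"](params)

# `new_params` now holds the encoded + FT-preprocessed weights and can be
# dumped to disk, so the preprocess never has to run again at deploy time.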

packed_weight.CopyToBytes(input_cpu.data(), input_cpu.size());
// multiply cols by 2 since the "col" param in preprocess_weights refers to the number of
// columns of the unpacked weight.
fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(), rows, cols * 2,
Contributor:

Maybe it would be good to add some documentation about why we need this preprocess, why it happens on CPU, and whether it is possible/better to move it to GPU?

masahi (Member, Author) commented Jun 16, 2023:

Note that this function is meant to be run at compile time. It needs to run on CPU since the preprocess functions are defined by FT in C++. I don't find them to be slow, so given the amount of effort required by GPU porting, I don't think it's worth it.

masahi (Member, Author):

Added some doc

lv2 = R.call_pure_packed(
"cutlass.ft_preprocess_weight_int4",
lv1,
80,
Member:

Can sm be inferred automatically? It doesn't look friendly to users.

masahi (Member, Author) commented Jun 19, 2023:

In general one might want to compile for a different GPU than the one in the compilation environment. cutlass.ft_preprocess_weight_int4 needs to be defined in tvm_runtime so we cannot do Target::Current() from there.

That said, if I could define a Python utility function with an optional sm argument that can be called from TVMScript, that would be best. But is it possible to call an arbitrary Python function from TVMScript?
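A rough sketch of such a helper, assuming the packed function takes just the weight and the sm value (the remaining arguments in the TVMScript snippet above are truncated, so this signature is an assumption):

import tvm

def ft_preprocess_weight_int4(packed_weight, sm=None):
    """Hypothetical wrapper: infer sm from the current target when not given."""
    if sm is None:
        target = tvm.target.Target.current(allow_none=True)
        # CUDA targets expose their architecture as e.g. "sm_80".
        sm = int(target.arch.split("_")[1]) if target is not None and target.arch else 80
    f = tvm.get_global_func("cutlass.ft_preprocess_weight_int4")
    return f(packed_weight, sm)

Whether such a wrapper can be invoked from TVMScript is exactly the open question above; the sketch only illustrates the sm-inference part.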

masahi (Member, Author):

Also CUTLASS users need to be very careful with the target sm anyway, so I hope it won't be a big deal in practice.

@masahi (Member, Author) commented Jun 20, 2023

Any other comments? Otherwise, can someone merge this?

@vinx13 merged commit 84ce484 into apache:unity on Jun 20, 2023
junrushao pushed a commit to junrushao/tvm that referenced this pull request Jun 22, 2023
[Unity][BYOC] Integrate fp16 A - int4 B GEMM kernel from FasterTransformer into CUTLASS BYOC (apache#15111)

* Initial FT kernel integration

* update rev

* wip

* wip

* add more check

* add residual test

* wip

* wip

* fix

* update test

* properly generate residual

* black

* update condition

* missing doc

* pylint

* add missing apache header

* format

* add some doc on preprocess
@fxmarty commented Jun 26, 2023

Hi @masahi, I was wondering if you had benchmarked against e.g. the available GPTQ CUDA or Triton-based kernels. Thank you!

@masahi (Member, Author) commented Jun 30, 2023

I have compared against the exllama int4 kernel for decoding. It seems the FT GEMM kernel is overkill for mat x vec multiply - the custom mat x vec kernel from exllama is much faster.

FT fpA_intB GEMM (this PR):

Time (%)  Total Time (ns)  Instances  Avg (ns)   Med (ns)   Min (ns)  Max (ns)  StdDev (ns)  Name
--------  ---------------  ---------  ---------  ---------  --------  --------  -----------  ----
    59.6    3,032,906,823     81,824   37,066.2   26,882.0    26,528    78,275     12,595.2  void cutlass::Kernel<cutlass::gemm::kernel::GemmFpAIntB<cutlass::gemm::threadblock::DqMmaMultistage…
    30.7    1,561,633,485     32,704   47,750.5   47,601.5    26,657   217,961     20,785.9  void cutlass::Kernel<cutlass::gemm::kernel::GemmFpAIntB<cutlass::gemm::threadblock::DqMmaMultistage…
     4.6      233,826,480     16,416   14,243.8   14,016.0     4,257    22,881      4,835.9  void attention_kernel_batched_impl<AttentionKernel<cutlass::half_t, cutlass::arch::Sm80, (bool)1, (…
     1.2       59,620,306        513  116,218.9  116,196.0   114,597   120,773        660.2  void cutlass::Kernel<cutlass::gemm::kernel::GemmFpAIntB<cutlass::gemm::threadblock::DqMmaMultistage…

exllama q4_matmul:

Time (%)  Total Time (ns)  Instances  Avg (ns)   Med (ns)   Min (ns)  Max (ns)  StdDev (ns)  Name
--------  ---------------  ---------  ---------  ---------  --------  --------  -----------  ----
    77.0    1,563,943,842     57,344   27,273.0   18,111.0    16,351    42,174     11,590.3  void q4_matmul_kernel<(bool)1, (bool)1, (bool)0>(const __half *, const unsigned int *, __half *, co…

@fxmarty commented Jun 30, 2023

Thank you. I'm trying to connect the dots between the different opinions I hear about using the mma.sync.aligned op for GEMV or not (which, from my understanding of CUTLASS, is what the FasterTransformer extension comes down to), and your results are in line with my current intuition.

Now the only mystery remaining to me is why PyTorch chooses a tensor-core-based kernel for fp16 x fp16 GEMV.

Maybe there could be an opportunity to base a CUTLASS extension on https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/kernel/gemv.h or https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/kernel/gemv_batched_strided.h for decoding, along with a GEMM kernel for prefill.

@cyang49 commented Aug 22, 2023

I have compared against the exllama int4 kernel for decoding. It seems the FT GEMM kernel is overkill for mat x vec multiply - the custom mat x vec kernel from exllama is much faster.


@masahi what were the dimensions of the input and which GPU did you test this on? Thanks

@masahi (Member, Author) commented Aug 22, 2023

This was the profiler output on one e2e inference on vicuna 7B decoder. The GPU was RTX 4080.

@LeiWang1999 (Contributor):

It seems that the default CUTLASS tile config does not perform well on compute-bound GEMM shapes. Do we have any plans to support tile tuning for this backend?

@masahi (Member, Author) commented Sep 7, 2023

It seems that the default CUTLASS tile config does not perform well on compute-bound GEMM shapes.

How do you know? The kernel uses a runtime heuristic to select a tile config. For now I'm hoping that this heuristic is "good enough".

https://github.com/tlc-pack/cutlass_fpA_intB_gemm/blob/main/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h#L818-L826
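For illustration only, the general shape of such a shape-based heuristic might look like the following; the thresholds and tile sizes here are invented and are not the actual logic behind the link above.

def select_tile_config(m, n, k):
    # Invented example values: small M favours tiles that keep many thread
    # blocks busy, large M favours bigger tiles that reuse the dequantized
    # weight across more rows of A.
    if m <= 16:
        return dict(block_m=16, block_n=128, block_k=64, stages=4)
    if m <= 64:
        return dict(block_m=32, block_n=128, block_k=64, stages=4)
    return dict(block_m=128, block_n=128, block_k=64, stages=3)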

@LeiWang1999 (Contributor) commented Sep 7, 2023

Thanks, I see. I mentioned that because I benchmarked the Llama-70B GEMM shapes (cuBLAS fp16 x fp16 vs. Relax CUTLASS fp16 x int4) on an A6000 48G with CUDA 12.1.1 installed.

M N K cublas (ms) cutlass-fpA-intB (ms) speedup
1 1024 8192 0.030588 0.046896935 0.652236035
1 8192 8192 0.199339 0.087285042 2.28376673
1 8192 28672 1.055949 0.187683105 5.626232851
1 28672 8192 0.672551 0.192594528 3.492055583
16 1024 8192 0.036352 0.095558167 0.380417524
16 8192 8192 0.19456 0.086522102 2.248674049
16 8192 28672 0.666054 0.187945366 3.54386972
16 28672 8192 0.669696 0.19159317 3.495406297
32 1024 8192 0.037856 0.101447105 0.373159996
32 8192 8192 0.196592 0.101852417 1.930165321
32 8192 28672 0.664686 0.319719315 2.078966443
32 28672 8192 0.67072 0.279974937 2.395642936
64 1024 8192 0.05376 0.170302391 0.31567378
64 8192 8192 0.199168 0.175619125 1.134090585
64 8192 28672 0.675594 0.57182312 1.181474212
64 28672 8192 0.681984 0.283193588 2.408190141
128 1024 8192 0.078336 0.051283836 1.527498838
128 8192 8192 0.238592 0.178599358 1.335906254
128 8192 28672 0.739888 0.575089455 1.286561606
128 28672 8192 0.714752 0.521111488 1.371591367
1024 1024 8192 0.28323 0.171804428 1.648558319
1024 8192 8192 1.158315 1.086783409 1.065819275
1024 8192 28672 4.166997 4.37412262 0.952647604
1024 28672 8192 3.709171 3.643035889 1.01815373
4096 1024 8192 0.597701 0.690698624 0.865357094
4096 8192 8192 4.334677 6.313800812 0.686540065
4096 8192 28672 16.15016 23.39763641 0.690247328
4096 28672 8192 14.95114 22.15902805 0.674720217
8192 1024 8192 1.200128 1.547074318 0.775740341
8192 8192 8192 8.550848 12.74256706 0.671045949
8192 8192 28672 30.88835 44.36366558 0.696253236
8192 28672 8192 30.19989 44.58220005 0.677397932
16384 1024 8192 2.349141 3.225851059 0.728223751
16384 8192 8192 16.97864 25.81973076 0.657583766
16384 8192 28672 61.27206 91.48414135 0.669756127
16384 28672 8192 61.03514 90.86410999 0.671718869

@sunggg (Contributor) commented Sep 7, 2023

@LeiWang1999 Thank you for sharing this. This is a very interesting data point! May I ask how I can reproduce these numbers? Did you exhaustively search over the possible tile combinations?

@LeiWang1999 (Contributor):

Hi @sunggg, here is the code to reproduce:

Glad if it helps.

@masahi (Member, Author) commented Sep 7, 2023

@LeiWang1999 yes, we have done a similar perf comparison against cuBLAS fp16, and indeed we found that the FT int4 kernel is faster only for small M. This is not surprising, since a larger M amortizes the cost of loading the large weight over more rows (which may correspond to "batch" or "token" in LLM inference) and int4 involves dequantization overhead.

The NV developers also talked about this point in their GTC23 presentation. Here is one of their slides.

[screenshot of the NVIDIA GTC23 slide]
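A rough back-of-the-envelope sketch of the amortization argument (illustrative numbers, not measurements): the weight traffic per layer is fixed, so arithmetic intensity grows roughly linearly with M until the kernel becomes compute bound, at which point the int4 dequantization overhead starts to dominate the comparison against cuBLAS fp16.

def arithmetic_intensity(m, n, k, weight_bytes=0.5, act_bytes=2):
    # (M, K) fp16 activations x (K, N) int4 weight -> (M, N) fp16 output.
    flops = 2 * m * n * k
    bytes_moved = n * k * weight_bytes + (m * k + m * n) * act_bytes
    return flops / bytes_moved

for m in (1, 16, 128, 4096):
    print(m, round(arithmetic_intensity(m, 8192, 8192), 1))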
