Skip to content

Commit

Permalink
Merge pull request #1 from tlc-pack/residual-fusion
Browse files Browse the repository at this point in the history
Add support for residual fusion
  • Loading branch information
masahi authored Jun 14, 2023
2 parents 80856ad + 2356986 commit 323dd4b
Show file tree
Hide file tree
Showing 8 changed files with 792 additions and 54 deletions.

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion cutlass_kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib64)

set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode arch=compute_80,code=sm_80")

# fpA_intB_gemm: fp16-activation / int-weight GEMM entry points plus the
# CUTLASS heuristic/preprocessor helpers and shared logging/CUDA utilities.
# (The per-dtype fpA_intB_gemm_fp16_int{4,8}.cu sources were folded into
# fpA_intB_gemm.cu; listing them again here would redefine the target.)
add_library(fpA_intB_gemm SHARED fpA_intB_gemm.cu cutlass_heuristic.cc cutlass_preprocessors.cc ${CMAKE_CURRENT_SOURCE_DIR}/../utils/logger.cc ${CMAKE_CURRENT_SOURCE_DIR}/../utils/cuda_utils.cc)
37 changes: 35 additions & 2 deletions cutlass_kernels/fpA_intB_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,51 @@ void gemm_fp16_int4(const half* A,
C, m, n, k, workspace_ptr, workspace_bytes, stream);
}

void gemm_fp16_int4_bias(const half* A,
ActivationType get_activation(const std::string& activation_name) {
if (activation_name == "identity") return ActivationType::Identity;
if (activation_name == "relu") return ActivationType::Relu;
if (activation_name == "silu") return ActivationType::Silu;
// todo: more
return ActivationType::Identity;
}

// fp16-activation / int4-weight GEMM with fused bias add and a named
// activation (resolved via get_activation; unknown names degrade to
// Identity). Work runs asynchronously on `stream`; `workspace_ptr` must
// point to at least `workspace_bytes` of device-accessible scratch
// (NOTE(review): sizing presumably comes from getWorkspaceSize — confirm).
void gemm_fp16_int4_bias_act(const half* A,
                             const uint4b_t* B,
                             const half* weight_scales,
                             const half* biases,
                             half* C,
                             const std::string& activation,
                             int m, int n, int k, char* workspace_ptr,
                             size_t workspace_bytes,
                             cudaStream_t stream) {
  CutlassFpAIntBGemmRunner<half, uint4b_t> runner;

  runner.gemm_bias_act(A, B, weight_scales, biases, C, m, n, k,
                       get_activation(activation), workspace_ptr,
                       workspace_bytes, stream);
}

// Bias-only fp16 x int4 GEMM: convenience wrapper that forwards to the
// bias+activation entry point with the "identity" activation, so no
// activation function is applied to the output.
void gemm_fp16_int4_bias(const half* A,
                         const uint4b_t* B,
                         const half* weight_scales,
                         const half* biases,
                         half* C,
                         int m, int n, int k, char* workspace_ptr,
                         size_t workspace_bytes,
                         cudaStream_t stream) {
  gemm_fp16_int4_bias_act(A, B, weight_scales, biases, C, "identity",
                          m, n, k, workspace_ptr, workspace_bytes, stream);
}

// fp16 x int4 GEMM with bias, activation, and residual fusion. Unlike the
// plain bias/act entry points, the activation / binary_op / unary_op names
// are forwarded verbatim to the runner, which interprets them.
void gemm_fp16_int4_bias_act_residual(
    const half *A, const uint4b_t *B, const half *weight_scales,
    const half *biases, const half *residual, half *C,
    const std::string& activation, const std::string& binary_op,
    const std::string& unary_op, int m, int n, int k,
    char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream) {
  CutlassFpAIntBGemmRunner<half, uint4b_t> runner;

  runner.gemm_bias_act_residual(A, B, weight_scales, biases, residual, C,
                                m, n, k, activation, binary_op, unary_op,
                                workspace_ptr, workspace_bytes, stream);
}


} // namespace fastertransformer
15 changes: 13 additions & 2 deletions cutlass_kernels/fpA_intB_gemm.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <string>
#include <cuda_runtime.h>
#include "cutlass/numeric_types.h"
#include "cutlass/half.h"
Expand All @@ -23,7 +24,17 @@ void gemm_fp16_int4(const half *A, const uint4b_t *B, const half *weight_scales,
// fp16-activation / int4-weight GEMM with fused bias add; equivalent to
// gemm_fp16_int4_bias_act with the "identity" activation.
void gemm_fp16_int4_bias(const half *A, const uint4b_t *B,
                         const half *weight_scales, const half *biases, half *C,
                         int m, int n, int k, char *workspace_ptr,
                         size_t workspace_bytes, cudaStream_t stream);

// fp16-activation / int4-weight GEMM with fused bias and a named activation
// ("identity", "relu", "silu"; unrecognized names fall back to identity —
// see get_activation in fpA_intB_gemm.cu).
void gemm_fp16_int4_bias_act(const half *A, const uint4b_t *B,
const half *weight_scales, const half *biases,
half *C, const std::string& activation, int m,
int n, int k, char *workspace_ptr,
size_t workspace_bytes, cudaStream_t stream);

// Residual-fusion variant: GEMM + bias + activation, combined with `residual`
// according to `binary_op` and `unary_op`. The op names are passed through as
// strings to the underlying runner. NOTE(review): the accepted op-name
// vocabulary is defined by the kernel implementation, not visible here —
// confirm before documenting specific values.
void gemm_fp16_int4_bias_act_residual(
    const half *A, const uint4b_t *B, const half *weight_scales,
    const half *biases, const half *residual, half *C, const std::string& activation, const std::string& binary_op,
    const std::string& unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);

} // namespace fastertransformer
9 changes: 9 additions & 0 deletions cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ class CutlassFpAIntBGemmRunner {
const size_t workspace_bytes,
cudaStream_t stream);

// GEMM with bias, activation, and residual fusion; `activation`, `binary_op`
// and `unary_op` select the fused epilogue by name at runtime.
// NOTE(review): op-name strings are interpreted by the implementation, which
// is not visible in this header — confirm the supported values there.
void gemm_bias_act_residual(const T *A, const WeightType *B,
const T *weight_scales, const T *biases,
const T *residual, T *C, int m, int n, int k,
const std::string& activation, const std::string& binary_op,
const std::string& unary_op,
char *workspace_ptr,
const size_t workspace_bytes,
cudaStream_t stream);

// Returns desired workspace size in bytes.
int getWorkspaceSize(const int m, const int n, const int k);

Expand Down
21 changes: 0 additions & 21 deletions cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_fp16_int4.cu

This file was deleted.

21 changes: 0 additions & 21 deletions cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_fp16_int8.cu

This file was deleted.

Loading

0 comments on commit 323dd4b

Please sign in to comment.