Skip to content

Commit

Permalink
Fix quant for v7 & Add q4_k/q5_k quants support
Browse files Browse the repository at this point in the history
Signed-off-by: Molly Sophia <[email protected]>
  • Loading branch information
MollySophia committed Jan 9, 2025
1 parent 49c4754 commit 2818c44
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 8 deletions.
2 changes: 2 additions & 0 deletions extras/quantize.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ bool QueryPerformanceCounter(uint64_t* lpPerformanceCount);
static enum ggml_type type_from_string(const char * string) {
if (strcmp(string, "Q4_0") == 0) return GGML_TYPE_Q4_0;
if (strcmp(string, "Q4_1") == 0) return GGML_TYPE_Q4_1;
if (strcmp(string, "Q4_K") == 0) return GGML_TYPE_Q4_K;
if (strcmp(string, "Q5_0") == 0) return GGML_TYPE_Q5_0;
if (strcmp(string, "Q5_1") == 0) return GGML_TYPE_Q5_1;
if (strcmp(string, "Q5_K") == 0) return GGML_TYPE_Q5_K;
if (strcmp(string, "Q8_0") == 0) return GGML_TYPE_Q8_0;
return GGML_TYPE_COUNT;
}
Expand Down
2 changes: 2 additions & 0 deletions python/rwkv_cpp/rwkv_cpp_shared_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
# Names of the quantized ggml formats accepted by rwkv_quantize_model_file.
# Fix: the previous annotation, Tuple[str, str, str, str, str], pinned the
# length to exactly five elements, but Q4_K/Q5_K brought the count to seven.
# A variadic homogeneous tuple stays correct as formats are added or removed.
QUANTIZED_FORMAT_NAMES: Tuple[str, ...] = (
    'Q4_0',
    'Q4_1',
    'Q4_K',
    'Q5_0',
    'Q5_1',
    'Q5_K',
    'Q8_0'
)

Expand Down
32 changes: 28 additions & 4 deletions rwkv_file_format.inc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ enum rwkv_type {
TYPE_Q5_0,
TYPE_Q5_1,
TYPE_Q8_0,
TYPE_Q8_1,
TYPE_Q2_K,
TYPE_Q3_K,
TYPE_Q4_K,
TYPE_Q5_K,
TYPE_Q6_K,
TYPE_Q8_K,
TYPE_COUNT
};

Expand All @@ -29,6 +36,13 @@ static const enum ggml_type rwkv_type_to_ggml[TYPE_COUNT + 1] = {
GGML_TYPE_Q5_0, /* Q5_0 */
GGML_TYPE_Q5_1, /* Q5_1 */
GGML_TYPE_Q8_0, /* Q8_0 */
GGML_TYPE_Q8_1, /* Q8_1 */
GGML_TYPE_Q2_K, /* Q2_K */
GGML_TYPE_Q3_K, /* Q3_K */
GGML_TYPE_Q4_K, /* Q4_K */
GGML_TYPE_Q5_K, /* Q5_K */
GGML_TYPE_Q6_K, /* Q6_K */
GGML_TYPE_Q8_K, /* Q8_K */
GGML_TYPE_COUNT /* COUNT */
};

Expand All @@ -42,10 +56,13 @@ static const enum rwkv_type rwkv_type_from_ggml[GGML_TYPE_COUNT + 1] = {
TYPE_Q5_0, /* Q5_0 */
TYPE_Q5_1, /* Q5_1 */
TYPE_Q8_0, /* Q8_0 */
TYPE_COUNT, /* Q8_1 */
TYPE_COUNT, /* I8 */
TYPE_COUNT, /* I16 */
TYPE_COUNT, /* I32 */
TYPE_Q8_1, /* Q8_1 */
TYPE_Q2_K, /* Q2_K */
TYPE_Q3_K, /* Q3_K */
TYPE_Q4_K, /* Q4_K */
TYPE_Q5_K, /* Q5_K */
TYPE_Q6_K, /* Q6_K */
TYPE_Q8_K, /* Q8_K */
TYPE_COUNT, /* COUNT */
};

Expand All @@ -60,6 +77,13 @@ static const char * rwkv_type_to_string[TYPE_COUNT + 1] = {
"Q5_0",
"Q5_1",
"Q8_0",
"Q8_1",
"Q2_K",
"Q3_K",
"Q4_K",
"Q5_K",
"Q6_K",
"Q8_K",
"unknown"
};

Expand Down
21 changes: 17 additions & 4 deletions rwkv_quantize.inc
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
// Returns true if the tensor with the given name should be quantized.
// The embedding and output head matrices are always kept in full precision,
// as are tensors whose names contain any of the patterns below (the small
// RWKV v7 low-rank tensors: v1/v2, g1/g2, a1/a2, w1/w2, r_k) — presumably
// because quantizing them costs far more quality than it saves in size.
// Fix: take the name by const reference instead of by value — the original
// copied a std::string on every call (clang-tidy:
// performance-unnecessary-value-param) — and replace the long chain of
// repeated .find() != npos comparisons with a loop over a pattern table.
static bool rwkv_tensor_needs_quant(const std::string & name) {
    if (name == "emb.weight" || name == "head.weight") {
        return false;
    }

    static const char * const skip_patterns[] = {
        "att.v1", "att.v2",
        "att.g1", "att.g2",
        "att.a1", "att.a2",
        "att.w1", "att.w2",
        "att.r_k"
    };

    for (const char * pattern : skip_patterns) {
        if (name.find(pattern) != std::string::npos) {
            return false;
        }
    }

    return true;
}

// API function.
bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const char * type_name) {
global_last_error = RWKV_ERROR_NONE;
Expand Down Expand Up @@ -122,10 +136,9 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const
// In RWKV v5, time_decay and time_first/time_faaaa are 3D tensors, so they are not quantized.
if ((header.data_type == TYPE_FP32 || header.data_type == TYPE_FP16) &&
header.dim_count == 2 &&
name != "emb.weight" &&
name != "head.weight"
rwkv_tensor_needs_quant(name)
) {
RWKV_MSG("quantizing... ");
RWKV_MSG("-> %6s ", rwkv_type_to_string[rwkv_type_from_ggml[out_type]]);

size_t nelements = (size_t) header.size0 * (size_t) header.size1 * (size_t) header.size2;

Expand All @@ -137,7 +150,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const
header.data_type = rwkv_type_from_ggml[out_type];
data = out_buf;

RWKV_MSG("size = %8.2f MB -> %8.2f MB | hist: ", orig_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
RWKV_MSG("size = %8.2f MB -> %8.2f MB", orig_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);

RWKV_MSG("\n");
} else {
Expand Down

0 comments on commit 2818c44

Please sign in to comment.