
llama : initial Mamba-2 support #9126

Open · wants to merge 25 commits into master

Conversation

@compilade (Collaborator) commented Aug 21, 2024

Follow-up from #8519 (comment). This should fix #7727 and fix #8519.

I've implemented the fully recurrent mode of Mamba-2, because it's very similar to Mamba-1, and also because it seems like the most appropriate mode for text generation.

This does not implement the sequentially semiseparable matrix mode, because I'm not yet sure how the block decomposition would fit within the batch and ubatch framework of llama.cpp, nor how the chunk size should be chosen. And if the recurrent mode is faster at single-user auto-regressive text generation, I'm not sure how to keep the graph node structure constant while still picking the most appropriate technique for each batch size.

If the sequentially semiseparable matrix mode is eventually implemented, it should help with prompt processing speed for large prompts.
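
For reference, the two equivalent per-head formulations from the Mamba-2 paper (notation follows the paper, not the exact ggml tensor shapes, so treat this as a summary) are the recurrent form implemented here,

$$h_t = \bar{A}_t\, h_{t-1} + \bar{B}_t\, x_t, \qquad y_t = C_t^{\top} h_t + D\, x_t, \qquad \bar{A}_t = e^{\Delta_t A}, \quad \bar{B}_t = \Delta_t B_t$$

and the semiseparable-matrix form which the chunked algorithm block-decomposes (ignoring the D skip term),

$$y = M x, \qquad M_{ij} = C_i^{\top} \Big( \prod_{k=j+1}^{i} \bar{A}_k \Big) \bar{B}_j \quad (i \ge j), \qquad M_{ij} = 0 \quad (i < j).$$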

What to expect

(mostly taken from #8519 (comment))

The state in Mamba-2 is bigger than I thought; Mamba-Codestral-7B-v0.1 takes 263.5 MiB (in F32) per sequence (e.g. with -np 1), compared to 38 MiB (also in F32) for Falcon-Mamba-7B (which is based on Mamba-1). But that size stays constant regardless of the context length. Mamba-2 is easier to implement efficiently, so the bigger state does not really impede inference speed.
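
For reference, that 263.5 MiB figure is consistent with the back-of-the-envelope calculation below; the hyperparameter values are assumptions for illustration, not numbers taken from the conversion code:

# rough per-sequence state size for Mamba-Codestral-7B-v0.1 in F32
# (n_layer, d_inner, d_state, n_group and d_conv are assumed values)
n_layer, d_inner, d_state, n_group, d_conv = 64, 8192, 128, 8, 4

ssm_state  = n_layer * d_inner * d_state * 4                                  # recurrent SSM state
conv_state = n_layer * (d_conv - 1) * (d_inner + 2 * n_group * d_state) * 4   # rolling conv window
print((ssm_state + conv_state) / 2**20)  # => 263.5 (MiB), independent of context length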

However, a big downside right now with recurrent models in llama.cpp is the lack of state rollback (which is implemented through state checkpoints in #7531, but needs to be re-adapted to #8526), so the prompt will be reprocessed a lot when using llama-server. I think using llama-cli in conversation mode does not have this problem (or maybe it only affects the bare interactive mode with --in-prefix and --in-suffix, not sure).

This initial implementation is CPU-only, but it uses SIMD for the SSM scan, so even though the state is bigger than for Mamba-1 models, in my tests the speed of Mamba2-130M is similar to or better than that of Mamba-130M (though still not that fast compared to transformer-based models with an empty context), when both are run on CPU.

The speed of Mamba-2 models seems comparable to Transformer-based models when the latter have 2k to 4k tokens in their context.

Summary of changes

  • Add support for Mamba2ForCausalLM (including the official Mamba-2 models, and Mamba-Codestral-7B-v0.1)
    • Note that config.json needs to contain "architectures": ["Mamba2ForCausalLM"], for the convert script to properly detect the architecture.
  • View Mamba-1 as having d_inner (aka 2 * n_embd) heads of size 1.
    • This simplifies the handling of shapes in ggml_ssm_scan (see the sketch after this list).
  • ggml
    • Implement Mamba-2's selective state update in ggml_ssm_scan.
      • Re-using the same operator as Mamba-1, because it's pretty much the same operation (except for how ssm_a is broadcast).
    • Fuse the operation with ssm_d into ggml_ssm_scan
      • Otherwise it would need to be transposed, because the dot-products are done head-wise.
    • Implement Mamba-2's SSM scan with GGML_SIMD.
      • This is possible because, unlike with Mamba-1, there is no element-wise expf in the state update.
    • Avoid state copies for the SSM state (both for Mamba-1 and Mamba-2) by passing state ids to ggml_ssm_scan.
      • Mamba-2 states are huge. Otherwise masking and copying took close to 10% of the CPU time according to perf.
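
As a rough illustration of the unified operator (a NumPy sketch, not the actual ggml code), here is the per-token update for one sequence and a single B/C group; with head_dim == 1 and A of shape (n_head, d_state) this is Mamba-1's update, and with head_dim > 1 and A of shape (n_head, 1), broadcast over d_state, it is Mamba-2's:

import numpy as np

def ssm_scan_step(s, x, dt, A, B, C, D):
    # s:  (n_head, head_dim, d_state)  recurrent state
    # x:  (n_head, head_dim)           input for one token
    # dt: (n_head,)                    per-head time step (assumed already softplus-ed)
    # A:  (n_head, d_state) for Mamba-1, (n_head, 1) for Mamba-2
    # B, C: (d_state,)                 state input/output projections
    # D:  (n_head,)                    skip connection, fused into the scan
    dA = np.exp(dt[:, None] * A)                                  # per-head decay
    s = s * dA[:, None, :] + np.einsum('hp,s->hps', dt[:, None] * x, B)
    y = np.einsum('hps,s->hp', s, C) + D[:, None] * x
    return y, s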

Other

Here's my favorite quote from Section 3.3 of https://arxiv.org/abs/2405.21060:

Furthermore—by a twist of fate—structured state space models and sequentially semiseparable matrices have the same acronyms, underscoring their equivalence! Conveniently we can use any of these acronyms SSM (state space model or semiseparable matrix), SSS (structured state space or sequentially semiseparable), or SS (state space or semiseparable) interchangeably to unambiguously refer to either concept.

TODO

  • Rebase onto master after merging llama : simplify Mamba with advanced batch splits #8526.
  • Avoid unnecessary moves of the state
  • Adapt the Metal kernels and the tests from ggml : add SSM Metal kernels #8546 to the updated ggml_ssm_scan
  • Remove the new GGML_MUL fast broadcast path because it's not used anymore to mask the states.
  • Maybe use a new metadata key instead of {arch}.ssm.time_step_rank for the number of heads of Mamba-2, because it's not really the rank of the time step (well, maybe kind of).
    • The meaning of the number of heads and the time-step rank is overlapping enough in Mamba-2 that I think this is fine.
  • Maybe not fuse the multiplication with ssm_d in ggml_ssm_scan?
  • Maybe split ggml_ssm_scan to separate the implementations for Mamba-1 and Mamba-2, although they do have a lot in common.
    • Seems like they can be distinguished easily enough at the time of kernel dispatch.

@compilade marked this pull request as draft August 21, 2024 21:51
@github-actions bot added the python (python script changes) and ggml (changes relating to the ggml tensor library for machine learning) labels Aug 21, 2024
* ggml : improve ggml_mul speed when masking recurrent states
* ggml : make the ggml_mul fast broadcast path more consistently formatted
@compilade changed the base branch from compilade/batch-splits to master August 21, 2024 22:02
@compilade marked this pull request as ready for review August 21, 2024 22:02
@compilade added the Review Complexity : Medium (generally requires more time to grok, but manageable by beginner to medium expertise level) label Aug 21, 2024
@ngxson (Collaborator) commented Aug 22, 2024

Hey @compilade , thanks for implementing this!

I tried converting https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1 using convert_hf_to_gguf.py, but it gives an error:

    with open(dir_model / "config.json", "r", encoding="utf-8") as f:
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
FileNotFoundError: [Errno 2] No such file or directory: 'config.json'

Nevertheless, I successfully converted a Mamba-Codestral transformers-compatible model: https://huggingface.co/Molbap/code2 (I needed to comment out the line raise NotImplementedError("BPE pre-tokenizer was not recognized - update get_vocab_base_pre()") in convert_hf_to_gguf.py)

To run the output model (remember to select the correct chat template, since the model does not come with one):

make llama-cli -j && ./llama-cli -m ../models/mcode-7.3B-Q8_0.gguf -cnv -p "You are a helpful assistant" --chat-template mistral -ngl 0

The result looks promising, but I have no idea why there are [UNK_BYTE_0xe29681...]. It seems like there is a problem with the space character:

<<SYS>>Youareahelpfulassistant<</SYS>>
> hi
[UNK_BYTE_0xe29681▁Hello]Hello![UNK_BYTE_0xe29681▁How]How[UNK_BYTE_0xe29681▁can]can[UNK_BYTE_0xe29681▁I]I[UNK_BYTE_0xe29681▁assist]assist[UNK_BYTE_0xe29681▁you]you[UNK_BYTE_0xe29681▁today]today?

Link to download GGUF: https://huggingface.co/ngxson/codestral-mamba-llamacpp-test/tree/main

@compilade (Collaborator, Author) commented Aug 22, 2024

Hey @compilade , thanks for implementing this!

I tried converting https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1 using convert_hf_to_gguf.py, but it gives an error:

    with open(dir_model / "config.json", "r", encoding="utf-8") as f:
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
FileNotFoundError: [Errno 2] No such file or directory: 'config.json'

@ngxson

The steps I took to convert Mamba-Codestral-7B-v0.1 are the following (a scripted equivalent is sketched after the list):

  1. Rename consolidated.safetensors to model.safetensors
  2. Rename params.json to config.json
  3. Add the line "architectures": ["Mamba2ForCausalLM"], in config.json
  4. Rename tokenizer.model.v3 to tokenizer.model
  5. Use convert_hf_to_gguf.py as usual.
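
For convenience, here is a scripted sketch of those steps (run from inside the downloaded model directory; the paths and the in-place JSON edit are just an example), after which convert_hf_to_gguf.py can be used as usual:

import json
from pathlib import Path

model_dir = Path(".")  # the downloaded Mamba-Codestral-7B-v0.1 directory

# steps 1, 2 and 4: rename the files to the names the convert script expects
(model_dir / "consolidated.safetensors").rename(model_dir / "model.safetensors")
(model_dir / "params.json").rename(model_dir / "config.json")
(model_dir / "tokenizer.model.v3").rename(model_dir / "tokenizer.model")

# step 3: add the architecture so the convert script detects Mamba-2
config = json.loads((model_dir / "config.json").read_text())
config["architectures"] = ["Mamba2ForCausalLM"]
(model_dir / "config.json").write_text(json.dumps(config, indent=2))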

I did not have tokenization problems in my tests. Maybe because I was using the original SentencePiece tokenizer instead of a BPE tokenizer.

That tokenizer.json in the transformers-compatible version seems to have problematic spaces. It uses the SentencePiece space escaping instead of the BPE one. Its normalizer seems to revert the escaping, but that's not handled in llama.cpp.

There are probably still problems with the SentencePiece tokenizer too, like the lack of special tokens. Control tokens seem to be identified correctly; the only difference seems to be with the 20 [REFERENCE_DOC_{n}] tokens (where n is 0 to 19), which tokenizer.json identifies as non-special added tokens (mapped to USER_DEFINED in llama.cpp), while tokenizer.model identifies them as NORMAL tokens.

I think the SentencePiece tokenizer should be preferred for this model, since it should be easier to handle without workarounds. I should change that in convert_hf_to_gguf.py. Meanwhile, either don't include tokenizer.json or rename it to something else.

The tokenizer.json of Mamba-Codestral-7B-v0.1 otherwise requires workarounds to work correctly.
@ngxson (Collaborator) commented Aug 23, 2024

Thanks for the guide! I've successfully converted the original repository to GGUF by following your steps.

For the transformers-compatible version, I will try to contact the person who made it. Hopefully it will be fixed soon.

I'm wondering if convert_hf_to_gguf.py can automatically handle the renaming of params.json, consolidated.safetensors and tokenizer.model.v3? For now, my fear is that someone who uses automated tools like gguf-my-repo will be stuck due to this issue.

(Also cc @Vaibhavs10 since he's the maintainer of gguf-my-repo.)

@Vaibhavs10 (Collaborator) left a comment

Hey @compilade / @ngxson - JFYI - the transformers weights are now merged into the main repo: https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1

If you face any issues with the conversion, could you open an issue on that repo for us to track? 🤗

@isr431 commented Aug 29, 2024

Any updates on when Codestral Mamba should be supported?

@learning-chip commented

Nice work! Just a note on the ssm_scan kernel performance: a better fused implementation by the flash-linear-attention project provides functionality equivalent to Mamba-2's original kernel (fla-org/flash-linear-attention#49) and runs 2x faster (fla-org/flash-linear-attention#50).

@molbap commented Sep 16, 2024

Hi @compilade! I worked on the repo conversion for the transformers-compatible mamba2 version; let us know if you need anything from us to move forward with this PR :)

@HanClinto (Collaborator) commented

I'm wondering if convert_hf_to_gguf.py can automatically handle the renaming of params.json, consolidated.safetensors and tokenizer.model.v3? For now, my fear is that someone who uses automated tools like gguf-my-repo will be stuck due to this issue.

(Also cc @Vaibhavs10 since he's the maintainer of gguf-my-repo.)

It sounds like having a simple fallback of expected filenames would be a reasonable thing to include here? I don't know that we want to maintain a ton of different ones, but adding a second layer of fallbacks for alternate filenames doesn't feel arduous.

@compilade (Collaborator, Author) commented

It sounds like having a simple fallback of expected filenames would be a reasonable thing to include here? I don't know that we want to maintain a ton of different ones, but adding a second layer of fallbacks for alternate filenames doesn't feel arduous.

@HanClinto

That's not really a problem anymore (at least for Mamba-Codestral) since the official repo was updated in https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1/commit/88085f9cdfa832c3aca8a0315a4520cf7558c947 to use more standard names.

What is currently blocking this is that the Metal and CUDA kernels for ggml_ssm_scan need to be updated. But before that, I want to refactor the operator to completely avoid copying Mamba-2 states, because otherwise the unnecessary copies use a non-negligible fraction of the memory bandwidth (10% of total text-generation inference time on my laptop), since Mamba-2 states are big.

@hg0428 commented Oct 1, 2024

Any updates on this?

@github-actions bot added the testing (everything test related) label Oct 1, 2024
The max index is 31, so trimming the arguments is necessary.
Whoops, this is needed for the offset in the concatenated output.
This was initially added because states were masked with ggml_mul,
but this is no longer done and so this "optimisation" is no longer
necessary, or at least not worth the additional code complexity.
This makes the weight buft detection in src/llama.cpp simpler.

* convert : transpose Mamba-2 A, D and reshape SSM_NORM

This breaks existing conversions of Mamba-2 models
to avoid some reshapes.

Not sure if it's a good idea,
but it makes the graph slightly cleaner.

* llama : more appropriate SSM_SCAN and SSM_CONV buft support checks
@gabe-l-hart mentioned this pull request Dec 12, 2024
@aallgeier commented

Very excited for this PR! Thanks @compilade!!

@Tangshengku commented

Hi @compilade, thank you for your impressive implementation. I am building support for Bi-Mamba (https://arxiv.org/abs/2411.11843) on top of your implementation. However, I first tried the mamba2-2.7b model and computed the perplexity on the wiki dataset (https://huggingface.co/datasets/ggml-org/ci/blob/main/wikitext-2-raw-v1.zip). The perplexity is pretty bad, more than 3500. So, have you ever tested the performance of your implementation before? My script is:
./llama-perplexity -m ./ckpt/mamba2-2.7B -f ./wikitext-2-raw/wiki.test.raw --n-gpu-layers 0

@compilade (Collaborator, Author) commented

@Tangshengku

Bi-Mamba seems amazing!

The perplexity is pretty bad, more than 3500. So, have you ever tested the performance of your implementation before?

I did test it when working on it, and it did work, but it has been a while. I will get back to this to find out what has broken.

@EthanFS commented Feb 26, 2025

@Tangshengku

Bi-Mamba seems amazing!

The perplexity is pretty bad, more than 3500. So, have you ever tested the performance of your implementation before?

I did test it when working on it, and it did work, but it has been a while. I will get back to this to find out what has broken.

Sounds like the issue might be related to the state rollback which @compilade previously mentioned:

However, a big downside right now with recurrent models in llama.cpp is the lack of state rollback (which is implemented through state checkpoints in #7531, but needs to be re-adapted to #8526), so the prompt will be reprocessed a lot if using llama-server.

This could be causing the high perplexity values, since the model has to reprocess previous content and repeat it with each generation. It is also the problem I am encountering now; I cannot directly use llama.cpp to evaluate accuracy because the model does not generate EOS.

@github-actions bot added the Apple Metal (https://en.wikipedia.org/wiki/Metal_(API)) label Feb 26, 2025
@compilade (Collaborator, Author) commented Feb 26, 2025

However, I first tried the mamba2-2.7b model and computed the perplexity on the wiki dataset

@Tangshengku Which model exactly is causing you problems? I can't reproduce the problem with a freshly-converted mamba2-370m (after first adding "architectures": ["Mamba2ForCausalLM"], to config.json).

Perplexity seems fine (on 8 chunks of wiki.test.raw):

$ ./bin/llama-perplexity -m /path/to/mamba2-370M-Q8_0.gguf -f /path/to/wikitext-2-raw/wiki.test.raw --chunks 8
...
[1]10.6303,[2]13.3056,[3]14.2217,[4]14.3134,[5]13.8906,[6]13.9622,[7]14.4598,[8]14.9276,
Final estimate: PPL = 14.9276 +/- 0.92728

I've also tried with mamba2-2.7b (which seems to be the model you were referring to?), and I still can't reproduce the problem.

$ ./bin/llama-perplexity -m /path/to/mamba2-2.7B-Q4_K_M.gguf -f /path/to/wikitext-2-raw/wiki.test.raw --chunks 8
...
[1]7.4471,[2]8.9320,[3]9.1891,[4]9.3914,[5]9.2716,[6]9.4961,[7]9.8894,[8]10.2744,
Final estimate: PPL = 10.2744 +/- 0.59925

I'm not sure what could be causing what you've seen. Note that I tested mamba2-370m both with the latest commit and with the latest commit before your message; both gave the same results.

I would suggest re-converting the model with convert_hf_to_gguf.py and testing its perplexity again.

For example, assuming this is run from a checkout of this branch, and that mamba2-2.7b is in /somewhere/src/mamba2-2.7b/ with its config.json modified to include the additional line "architectures": ["Mamba2ForCausalLM"], (so that convert_hf_to_gguf.py knows it's a Mamba-2 model), this results in an F16 model at /somewhere/tmp/mamba2-2.7b-F16.gguf:

$ python3 convert_hf_to_gguf.py --outtype f16 --outfile /somewhere/tmp/mamba2-2.7b-F16.gguf /somewhere/src/mamba2-2.7b/
$ ./build/bin/llama-quantize /somewhere/tmp/mamba2-2.7b-{F16,Q4_K_M}.gguf q4_k_m

@Tangshengku Alternatively, the problem may be related to GPU support... All my tests were on a CPU-only build (with AVX and AVX2).
I have not yet tested a build with CUDA (although I think it should properly fall back to CPU?).

If you find out how to fix the problem you've noticed, please do share.

Looking forward to helping you make Bi-Mamba work with llama.cpp.


Sounds like the issue might be related to state rollback

@EthanFS State rollback not being properly handled shouldn't affect perplexity; it's only relevant when partially removing tokens from the context (as opposed to clearing the context, which is handled properly).

Partial removal with recurrent models is currently handled by recomputing the context from the beginning, if I recall correctly. That should not affect perplexity, only the efficiency when rolling back.

@EthanFS commented Feb 26, 2025

@compilade Thanks for your explanation.
When using the unofficial HF Mamba-2 model or the official HF Mamba-1 model, I face a problem where the model cannot generate EOS properly; as shown below, it keeps repeating. Do you have this problem, or have I set something wrong?

./llama-cli -m ../../../../mamba-130m-hf/mamba-130M-hf-F32.gguf -n 1024 -p "### Question: What is quantum computing?\n### Answer:"
build: 4769 (34a846b5) with cc (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0 for aarch64-linux-gnu
main: llama backend init
main: load the model and apply lora adapter, if any
llama_model_loader: loaded meta data with 28 key-value pairs and 242 tensors from ../../../../mamba-130m-hf/mamba-130M-hf-F32.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = mamba
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Mamba 130m Hf
llama_model_loader: - kv   3:                           general.finetune str              = hf
llama_model_loader: - kv   4:                           general.basename str              = mamba
llama_model_loader: - kv   5:                         general.size_label str              = 130M
llama_model_loader: - kv   6:                       mamba.context_length u32              = 1048576
llama_model_loader: - kv   7:                     mamba.embedding_length u32              = 768
llama_model_loader: - kv   8:                  mamba.feed_forward_length u32              = 0
llama_model_loader: - kv   9:                 mamba.attention.head_count u32              = 0
llama_model_loader: - kv  10:                          mamba.block_count u32              = 24
llama_model_loader: - kv  11:                      mamba.ssm.conv_kernel u32              = 4
llama_model_loader: - kv  12:                       mamba.ssm.inner_size u32              = 1536
llama_model_loader: - kv  13:                       mamba.ssm.state_size u32              = 16
llama_model_loader: - kv  14:                   mamba.ssm.time_step_rank u32              = 48
llama_model_loader: - kv  15:     mamba.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  16:                       mamba.ssm.dt_b_c_rms bool             = false
llama_model_loader: - kv  17:                          general.file_type u32              = 0
llama_model_loader: - kv  18:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  19:                         tokenizer.ggml.pre str              = olmo
llama_model_loader: - kv  20:                      tokenizer.ggml.tokens arr[str,50280]   = ["<|endoftext|>", "<|padding|>", "!",...
llama_model_loader: - kv  21:                  tokenizer.ggml.token_type arr[i32,50280]   = [3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  22:                      tokenizer.ggml.merges arr[str,50009]   = ["Ġ Ġ", "Ġ t", "Ġ a", "h e", "i n...
llama_model_loader: - kv  23:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  24:                tokenizer.ggml.eos_token_id u32              = 0
llama_model_loader: - kv  25:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  26:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  27:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:  242 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = all F32
print_info: file size   = 492.61 MiB (32.00 BPW) 
load: special tokens cache size = 25
load: token to piece cache size = 0.2984 MB
print_info: arch             = mamba
print_info: vocab_only       = 0
print_info: n_ctx_train      = 1048576
print_info: n_embd           = 768
print_info: n_layer          = 24
print_info: n_head           = 0
print_info: n_head_kv        = 0
print_info: n_rot            = 0
print_info: n_swa            = 0
print_info: n_embd_head_k    = 0
print_info: n_embd_head_v    = 0
print_info: n_gqa            = 0
print_info: n_embd_k_gqa     = 0
print_info: n_embd_v_gqa     = 0
print_info: f_norm_eps       = 0.0e+00
print_info: f_norm_rms_eps   = 1.0e-05
print_info: f_clamp_kqv      = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale    = 0.0e+00
print_info: n_ff             = 0
print_info: n_expert         = 0
print_info: n_expert_used    = 0
print_info: causal attn      = 1
print_info: pooling type     = 0
print_info: rope type        = -1
print_info: rope scaling     = linear
print_info: freq_base_train  = 10000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn  = 1048576
print_info: rope_finetuned   = unknown
print_info: ssm_d_conv       = 4
print_info: ssm_d_inner      = 1536
print_info: ssm_d_state      = 16
print_info: ssm_dt_rank      = 48
print_info: ssm_dt_b_c_rms   = 0
print_info: model type       = 0.1B
print_info: model params     = 129.14 M
print_info: general.name     = Mamba 130m Hf
print_info: vocab type       = BPE
print_info: n_vocab          = 50280
print_info: n_merges         = 50009
print_info: BOS token        = 0 '<|endoftext|>'
print_info: EOS token        = 0 '<|endoftext|>'
print_info: EOT token        = 0 '<|endoftext|>'
print_info: UNK token        = 0 '<|endoftext|>'
print_info: PAD token        = 0 '<|endoftext|>'
print_info: LF token         = 187 'Ċ'
print_info: EOG token        = 0 '<|endoftext|>'
print_info: max token length = 1024
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors:   CPU_Mapped model buffer size =   492.61 MiB
....................................................
llama_init_from_model: n_seq_max     = 1
llama_init_from_model: n_ctx         = 4096
llama_init_from_model: n_ctx_per_seq = 4096
llama_init_from_model: n_batch       = 2048
llama_init_from_model: n_ubatch      = 512
llama_init_from_model: flash_attn    = 0
llama_init_from_model: freq_base     = 10000.0
llama_init_from_model: freq_scale    = 1
llama_init_from_model: n_ctx_per_seq (4096) < n_ctx_train (1048576) -- the full capacity of the model will not be utilized
llama_kv_cache_init: kv_size = 1, offload = 1, type_k = 'f32', type_v = 'f32', n_layer = 24, can_shift = 0
llama_kv_cache_init:        CPU KV buffer size =     2.67 MiB
llama_init_from_model: KV self size  =    2.67 MiB, K (f32):    0.42 MiB, V (f32):    2.25 MiB
llama_init_from_model:        CPU  output buffer size =     0.19 MiB
llama_init_from_model:        CPU compute buffer size =   102.82 MiB
llama_init_from_model: graph nodes  = 1182
llama_init_from_model: graph splits = 1
common_init_from_params: KV cache shifting is not supported for this model, disabling KV cache shifting
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
main: llama threadpool init, n_threads = 8

system_info: n_threads = 8 (n_threads_batch = 8) / 8 | CPU : NEON = 1 | ARM_FMA = 1 | FP16_VA = 1 | DOTPROD = 1 | LLAMAFILE = 1 | OPENMP = 1 | AARCH64_REPACK = 1 | 

sampler seed: 1747808005
sampler params: 
	repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
	dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = 4096
	top_k = 40, top_p = 0.950, min_p = 0.050, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, top_n_sigma = -1.000, temp = 0.800
	mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist 
generate: n_ctx = 4096, n_batch = 2048, n_predict = 1024, n_keep = 0

### Question: What is quantum computing?
### Answer: It is the ability to control the speed of information transfer.

### Question: What is quantum computing?
### Answer: It is a concept that describes the nature of quantum states, and how to describe them.

### Question: What is quantum computing?
### Answer: It is an algorithm that is based on the use of quantum measurements.

### Question: What is quantum computation?
### Answer: It is the concept that describes the nature of the speed of information transfer.

### Question: What is quantum computation?
### Answer: It is the concept that describes the nature of the speed of information transfer.

[... the same question/answer pair repeats many more times ...]

### Question: What is quantum computing?
### Answer: It is the concept that describes the nature of the

llama_perf_sampler_print:    sampling time =      84.11 ms /  1036 runs   (    0.08 ms per token, 12317.06 tokens per second)
llama_perf_context_print:        load time =     275.08 ms
llama_perf_context_print: prompt eval time =     120.83 ms /    12 tokens (   10.07 ms per token,    99.31 tokens per second)
llama_perf_context_print:        eval time =   37935.34 ms /  1023 runs   (   37.08 ms per token,    26.97 tokens per second)
llama_perf_context_print:       total time =   38279.47 ms /  1035 tokens

@compilade (Collaborator, Author) commented Feb 27, 2025

@EthanFS
I don't think these small Mamba (1 and 2) models are instruction-trained, so I wouldn't expect them to ever really "finish" their output (although there are cases where they do output the EOS token; from my small tests, it seems that when using a repetition penalty (e.g. --repeat-penalty 1.1), they are more likely to end their output at some point).

What I would suggest is to either use a repetition penalty, use a stop string, or use other models which are instruction-trained.

For example, to use a stop string, you can use -r "Your stop string":

$ ./bin/llama-cli -m /path/to/mamba2-370m-Q8_0.gguf -n 1024 -p "### Question: What is quantum computing?\n### Answer:" -r "### Question:"

I hope this helps!


@Tangshengku
Good news: I've looked at https://github.com/Tangshengku/Bi-Mamba and tried to make Bi-Mamba work, and I had to make this small script to properly convert the 2.7B Bi-Mamba model:

Script to prepare the `pytorch_model.bin` of Bi-Mamba into a proper `model.safetensors` for conversion:
import torch
from safetensors.torch import save_file

model = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True, mmap=True)

new_model = {}

for name, data in model.items():
    if ".in_proj." in name or ".out_proj." in name:
        if name.endswith(".weight"):
            # fold the binarization parameters into a dense weight:
            # W' = wscale * sign(W) + wbias
            prefix = name.removesuffix(".weight")
            wscale = model[prefix + ".wscale"]
            wbias = model[prefix + ".wbias"]
            data = wscale * torch.sign(data) + wbias
        else:
            # skip the separate .wscale and .wbias tensors (folded above)
            continue
    new_model[name] = data
    print(name, data.shape)

save_file(new_model, "model.safetensors")

NOTE: This takes around 20GiB of free RAM to run with the 10GiB F32 Bi-Mamba 2.7B model

The important bit is that only the sign of the weights is used, along with the scale and bias; otherwise the model this produces does not have good perplexity.


And so with your model (which is using the Mamba-2 architecture) I get good perplexity results:

$ ./bin/llama-perplexity -m /path/to/bimamba-2.7B-F16.gguf -f /path/to/wikitext-2-raw/wiki.test.raw --chunks 8
...
llama_model_loader: - kv   0:                       general.architecture str              = mamba2
...
[1]7.9030,[2]9.0506,[3]9.5904,[4]10.9340,[5]10.9741,[6]10.7500,[7]11.0106,[8]10.9961,
Final estimate: PPL = 10.9961 +/- 0.65000

Unfortunately, it seems like TQ1_0 and TQ2_0 are not a good fit for this model, because these types don't have an offset, only a scale.
But it will definitely be possible to make new types.

From having made TQ2_0 (in #8151), I'm pretty sure there would be no (computational) overhead in adding an f16 offset, since the operation is already partly done. It would even help with alignment. And it's already a fast type, so it's a good starting point if you want to experiment with this.

A 1-bit type with 256-element blocks, an f16 scale and an f16 offset would use 1.125 bits per weight.
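
To make the arithmetic concrete, here is a sketch of such a hypothetical block (this is not an existing ggml type, just the layout implied above):

# hypothetical 1-bit block: 256 sign bits + one f16 scale + one f16 offset
QK = 256
bits_per_block = QK * 1 + 16 + 16
print(bits_per_block / QK)  # => 1.125 bits per weight
# dequantizing weight i of a block: w[i] = scale * (+1 if bit i is set else -1) + offset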

@Tangshengku commented

@compilade Hi, thank you for your quick reply! Sorry, I found that my issue was that I accidentally used the Llama-2 tokenizer instead of the tokenizer used by the original Mamba. After switching to the correct tokenizer, I can replicate the exact perplexity results you provided for both Mamba-2.7B and Bi-Mamba on an M4 Pro CPU.

Instead of computing the w_scale and w_bias during tensor transformation, I compute the w_scale and w_bias during inference on the activations, which is mathematically equivalent to the operation on the binarized weights, like this:

 // in function llm_build_mamba2()
 ....
 // apply the binarization scale on the activations instead of on the weights
 struct ggml_tensor * cur_scale = ggml_mul(ctx, cur, model.layers[il].ssm_in_wscale);
 // bias contribution, computed from the unscaled activations
 struct ggml_tensor * bias_term = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in_wbias, cur);
 // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
 struct ggml_tensor * zxBCdt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur_scale);
 zxBCdt = ggml_add(ctx, zxBCdt, bias_term);
 ....
 ....
 // same pattern for the output projection
 struct ggml_tensor * y_scale = ggml_mul(ctx, y, model.layers[il].ssm_out_wscale);
 struct ggml_tensor * bias_term_out = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_out_wbias, y);
 cur = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_out, y_scale);
 cur = ggml_add(ctx, cur, bias_term_out);

In fact, I am quite new to llama.cpp, but I guess this operation on the activations could be beneficial for actual binary computation (hope so...). If not, we can start over. I am happy to contribute to the later development of the new type and even the GPU support (I just need more time to get familiar with this codebase) :).


@EthanFS Hi, thanks for the help. I agree with @compilade that the model shows this non-stopping pattern because the models are not instruction-tuned. You can check the generation examples in Figures 10, 11 and 12 in the appendix of our paper: https://arxiv.org/pdf/2411.11843

@compilade (Collaborator, Author) commented

Instead of computing the w_scale and w_bias during tensor transformation, I compute the w_scale and w_bias during inference on the activation, which is equivalent to the operation on the binarized weight in math

@Tangshengku

Yes, this is in line with the eventual goal of making an appropriate quantization type for binary models. The scale and bias would be applied at runtime, during matmul, without necessarily having to cast it all to F16 before the matmul. That would allow using appropriate integer SIMD on CPU too.

With such a binary type, changing the model graphs would not be necessary, which means it would work for Bi-Mamba, and also FBI-LLM and any other binarized model based on a supported model architecture.

In fact, I am quite new to llama.cpp, but I guess this operation on activation could be beneficial for actual binary computation (hope so...). If not, we can start over.

This operation would help, and that's why something similar is built into most quantization types in ggml, so that it can be accelerated by the appropriate backend without having to modify the model graph each time a new quant type is introduced.
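
A tiny sketch of why the runtime scale and offset are nearly free (illustrative only, the names are made up): for a block with signs sigma, scale s and offset b, the dot product with activations x satisfies (s * sigma + b) . x = s * (sigma . x) + b * sum(x), so only the sign part needs the weights:

import numpy as np

def binary_block_dot(sign_bits, x, s, b):
    # sign_bits: one entry per weight in the block (1 -> +1, 0 -> -1), x: activations
    sigma = np.where(sign_bits, 1.0, -1.0)
    return s * float(sigma @ x) + b * float(x.sum())  # sigma @ x is the integer-friendly part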

I am happy to contribute to the later development for the new type and even the GPU support (just need more time to get familiar with this codebase) :).

We all have to start somewhere :)
The best way to get familiar with the codebase of llama.cpp is to try to implement something. The easiest is when there's already something similar implemented.

If you want, I can start making a prototype for a binary type in ggml, but I encourage you to give it a try.

There's also something interesting about binary weights with a scale and a bias: the ideal rounding for {-1, 1} with a scale and without a bias is always the sign of the weights (which means imatrix would not do much), while when adding a bias (like in Bi-Mamba and FBI-LLM), an exhaustive search for the ideal rounding is possible. That means imatrix could work with such a type (and so it might even be usable for non-binarized models). It is similar to an exhaustive search for ternary quantization without a bias [1], although the search space has different symmetries (e.g. for 4D vectors, the scale and bias make the representable binary vectors form a 3D rhombic dodecahedron).
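
Here is a NumPy sketch of that exhaustive search (an illustration only, based on the observation that, up to a sign flip, the optimal {-1, +1} pattern for a scale + offset fit is a threshold on the sorted weights; it is not imatrix-aware and not code from this PR):

import numpy as np

def best_sign_scale_offset(w):
    # find signs in {-1,+1}^n, scale s and offset b minimizing ||w - (s*signs + b)||^2;
    # only the n+1 threshold patterns on the sorted weights need to be checked
    # (the loop could be vectorized with cumulative sums, as in compilade/rounding-experiments)
    order = np.argsort(w)
    ws = w[order]
    n = len(ws)
    best = (np.inf, 0.0, float(ws.mean()), 0)
    for k in range(n + 1):                    # the k smallest weights get -1, the rest +1
        sigma = np.ones(n)
        sigma[:k] = -1.0
        if 0 < k < n:
            sc = sigma - sigma.mean()
            s = float(sc @ (ws - ws.mean())) / float(sc @ sc)   # least-squares scale
        else:
            s = 0.0                           # all signs equal: the scale is redundant
        b = float(ws.mean() - s * sigma.mean())
        err = float(np.sum((ws - s * sigma - b) ** 2))
        if err < best[0]:
            best = (err, s, b, k)
    err, s, b, k = best
    signs = np.where(np.argsort(order) >= k, 1.0, -1.0)          # back to the original order
    return s, b, signs, err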

It's impressive that Mamba-2 with binarized weights can work (as you've shown with Bi-Mamba). At some point, the states will take more memory than the weights. I wonder how that would affect speed.

Footnotes

  1. Based on some experiments I've tried in https://github.com/compilade/rounding-experiments, both can rely on sorting the vector components and cumulative sums for an exhaustive search.

Successfully merging this pull request may close these issues:

  • Feature Request: Support Codestral Mamba
  • llama : support Mamba-2