Add implementations of Mamba 2 into FLA #39
Conversation
@DanFosing Great job! I'll make some tests soon. Thank you for the contributions.
Could you add ...
I think it should work with FLA now.
Unfortunately, because it's still an ongoing pull request into the transformers package, it may not fully work in some specific cases; that's why it requires some testing.
@DanFosing Hi, looks like there are still some errors:
File "flash-linear-attention/fla/models/mamba2/modeling_mamba2.py", line 607, in forward
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "flash-linear-attention/fla/models/mamba2/modeling_mamba2.py", line 333, in cuda_kernels_forward
rmsnorm_weight=self.norm.weight,
^^^^^^^^^^^^^^^^
File "anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1709, in __getattr__
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'GatedRMSNorm' object has no attribute 'weight'
@DanFosing You can check your implementation by running:
python -m benchmarks.benchmark_training_throughput --name mamba2
Also, it would be better to beautify your code style via ...
You can refer to https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/layers/gla.py#L219-L228 for our impl of the gate norm. I think there is no need to further wrap it with another ...
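For illustration, a minimal sketch of that pattern, assuming FusedRMSNormSwishGate is called as norm(x, gate) the way the linked gla.py lines use it; the variable names below are placeholders, and the dimension 4096 is taken from the model printout later in this thread:

```python
# Sketch only: assumes `FusedRMSNormSwishGate` is exported from fla.modules and,
# as in fla/layers/gla.py, takes the tensor to normalize plus the gate tensor.
# Whether the gate is applied before or after the normalization should be checked
# against the Mamba-2 semantics; this only shows the calling convention.
from fla.modules import FusedRMSNormSwishGate

gate_norm = FusedRMSNormSwishGate(4096, eps=1e-5)

def apply_gated_norm(scan_output, gate):
    # Instead of passing rmsnorm_weight=self.norm.weight into the fused Mamba kernel,
    # normalize and gate the SSM scan output with FLA's fused module directly.
    return gate_norm(scan_output, gate)
```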
@DanFosing Appreciate your hard work and quick response.
Are you sure it works? When I tried it on Kaggle I got assertion errors, but Kaggle often has some weird problems, so it may be an issue with Kaggle itself (I can't test it on my PC right now).
Could you paste more detailed info?
@DanFosing Here is the output:
$ python -m benchmarks.benchmark_training_throughput --name mamba2
Initializing mamba2 model from the config:
Mamba2Config {
"bos_token_id": 1,
"chunk_size": 256,
"conv_kernel": 4,
"eos_token_id": 2,
"expand": 2,
"fuse_cross_entropy": true,
"head_dim": 64,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.1,
"layer_norm_epsilon": 1e-05,
"model_type": "mamba2",
"n_groups": 8,
"norm_before_gate": true,
"num_heads": 64,
"num_hidden_layers": 48,
"pad_token_id": 0,
"rescale_prenorm_residual": false,
"residual_in_fp32": true,
"rms_norm": true,
"state_size": 128,
"tie_word_embeddings": false,
"time_step_floor": 0.0001,
"time_step_limit": [
0.0,
Infinity
],
"time_step_max": 0.1,
"time_step_min": 0.001,
"time_step_rank": 128,
"transformers_version": "4.43.3",
"use_bias": false,
"use_cache": true,
"use_conv_bias": true,
"vocab_size": 32000
}
Mamba2ForCausalLM(
(backbone): Mamba2Model(
(embeddings): Embedding(32000, 2048)
(layers): ModuleList(
(0-47): 48 x Mamba2Block(
(norm): RMSNorm(2048, eps=1e-05)
(mixer): Mamba2Mixer(
(act): SiLU()
(conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)
(in_proj): Linear(in_features=2048, out_features=10304, bias=False)
(norm): FusedRMSNormSwishGate(4096, eps=1e-05)
(out_proj): Linear(in_features=4096, out_features=2048, bias=False)
)
)
)
(norm_f): RMSNorm(2048, eps=1e-05)
)
(lm_head): Linear(in_features=2048, out_features=32000, bias=False)
)
Number of parameters in total: 1548430336 (1.44GiB)
Allocated memory after initialization: 2.89GiB
Max memory allocated: 42.01GiB: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [01:11<00:00, 4.44s/it]
Throughput: 32048.92 tokens/s: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:16<00:00, 1.96it/s]
Turns out it was just a problem with the T4 not supporting some things from the causal_conv1d package; there is the same issue with Mamba 1 and 2 from mamba-ssm. I don't remember exactly what it was, but there was some workaround that fixed it if you turned off some low-memory mode in PyTorch or something. Btw, torch.compile support was added to Mamba-1, but it requires some changes in MambaCache and the Mamba modeling code; I will try to implement it when I have some time.
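For context, a rough sketch of what that torch.compile usage could eventually look like. The class names are taken from the benchmark output above, the import paths are assumed from this PR's file layout (see the traceback earlier in the thread), and whether the current cache handling is compile-friendly is exactly what the MambaCache/modeling changes would be needed for:

```python
import torch

# Import paths assumed from the traceback above; treat this as a sketch, not the PR's API.
from fla.models.mamba2.modeling_mamba2 import Mamba2ForCausalLM
from fla.models.mamba2.configuration_mamba2 import Mamba2Config

# Small config just to exercise the forward pass.
model = Mamba2ForCausalLM(Mamba2Config(num_hidden_layers=2)).cuda()
compiled_model = torch.compile(model)

input_ids = torch.randint(0, model.config.vocab_size, (1, 512), device="cuda")
logits = compiled_model(input_ids).logits  # first call triggers compilation
```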
One question on this PR: from the FLA paper, both Mamba-1 and Mamba-2 can be written in the gated linear attention formulation. So the SSM part of Mamba-2 (excluding conv1d, normalization, ...) should permit the same "Chunkwise Parallel Form" advocated by the FLA paper, no? Then, the ... Otherwise, the ...
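For concreteness, here is a naive recurrent reference of what "the SSM part of Mamba-2 as gated linear attention" means in the scalar-decay case. Everything here is my own sketch: the tensor layout is hypothetical, q/k play the roles of C/B in SSM notation, v is the input x, and a_t = exp(Δ_t·A) is a per-head scalar in (0, 1):

```python
import torch

def mamba2_ssm_as_gla_recurrent(q, k, v, a):
    """Naive O(T) recurrence: S_t = a_t * S_{t-1} + k_t^T v_t,  o_t = q_t S_t.
    Hypothetical shapes: q, k (B, T, H, K); v (B, T, H, V); a (B, T, H)."""
    B, T, H, K = q.shape
    V = v.shape[-1]
    S = q.new_zeros(B, H, K, V)            # per-head matrix-valued state
    outs = []
    for t in range(T):
        # data-dependent *scalar* decay of the whole state, plus a rank-1 update
        S = a[:, t, :, None, None] * S + k[:, t, :, :, None] * v[:, t, :, None, :]
        outs.append(torch.einsum('bhk,bhkv->bhv', q[:, t], S))
    return torch.stack(outs, dim=1)        # (B, T, H, V)
```

Because the decay multiplies the whole state by a scalar, the decay products telescope across time steps, which is exactly what makes a chunkwise parallel form applicable.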
The mamba-2 blog did mention that the FLA chunkwise parallel form "turns out to be essentially equivalent to the SSD algorithm specialized to a restricted case". Will the FLA formulation be just identical to the ...
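To make the "restricted case" concrete, below is an unoptimized chunkwise-parallel version of the same scalar-decay recurrence sketched above. This is my own reference code, not FLA's kernel nor the SSD implementation; real kernels keep the decays in log space for numerical stability and fuse everything in Triton:

```python
import torch

def chunkwise_gla_scalar_decay(q, k, v, a, chunk=64):
    """Chunkwise-parallel gated linear attention with per-head *scalar* decay
    a_t in (0, 1) -- the restricted case that Mamba-2's SSD reduces to.
    Hypothetical shapes: q, k (B, T, H, K); v (B, T, H, V); a (B, T, H).
    Matches mamba2_ssm_as_gla_recurrent above up to floating-point error."""
    B, T, H, K = q.shape
    V = v.shape[-1]
    S = q.new_zeros(B, H, K, V)
    out = torch.empty(B, T, H, V, dtype=q.dtype, device=q.device)
    for s in range(0, T, chunk):
        e = min(s + chunk, T)
        qc, kc, vc, ac = q[:, s:e], k[:, s:e], v[:, s:e], a[:, s:e]
        b = torch.cumprod(ac, dim=1)                            # cumulative in-chunk decay b_t
        # inter-chunk: contribution of the state carried in from previous chunks
        inter = torch.einsum('bthk,bhkv->bthv', qc * b[..., None], S)
        # intra-chunk: causal "attention" with relative decay b_t / b_s
        attn = torch.einsum('bthk,bshk->bhts', qc * b[..., None], kc / b[..., None])
        attn = attn.tril()                                      # keep s <= t
        intra = torch.einsum('bhts,bshv->bthv', attn, vc)
        out[:, s:e] = inter + intra
        # carry the state: S <- b_C * S + sum_t (b_C / b_t) k_t^T v_t
        bC = b[:, -1]
        decayed_k = kc * (bC[:, None] / b)[..., None]
        S = bC[..., None, None] * S + torch.einsum('bthk,bthv->bhkv', decayed_k, vc)
    return out
```

So in this scalar-decay case the FLA-style chunkwise computation and SSD compute the same quantity; the practical differences are in kernel fusion and in how the decay products are stabilized numerically.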
Indeed, I think it can be made faster if a kernel similar to chunk_fuse.py is used (maybe it would be possible to just modify the GLA one, as GLA and Mamba-2 are extremely similar to each other). Unfortunately I'm not familiar with Triton, so I can't really do it; I think it would be best if you made an issue or something, unless one of the FLA devs replies here. The current implementation is a lot like the Mamba-1 implementation; both are mostly like the original mamba-ssm ones, with just some minor speedups thanks to FusedCrossEntropyLoss and a custom RMSNorm, and they are compatible with Hugging Face transformers.
If you take a look at this paper: https://arxiv.org/pdf/2406.06484, you can also see that, in terms of recurrence and memory read-out, Mamba-2 is very similar to RetNet and linear attention:
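My own recap of the standard forms being compared (not the paper's table): with matrix-valued state S_t, query/key q_t, k_t and value v_t, all of these share the read-out o_t = q_t S_t and differ only in the state update:
- Linear attention: S_t = S_{t-1} + k_t^T v_t (no decay)
- RetNet: S_t = γ S_{t-1} + k_t^T v_t (fixed scalar decay γ)
- Mamba-2 / SSD: S_t = a_t S_{t-1} + k_t^T v_t (data-dependent scalar decay a_t)
- GLA: S_t = diag(α_t) S_{t-1} + k_t^T v_t (data-dependent vector decay α_t)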
Will look into it! Also pinging @yzhangcs for any suggestions -- like, which existing Triton kernel is the best starting point to re-implement the SSM part of Mamba-2? From the comparison above, the kernels for RetNet should be the closest ones? The mamba-2 blog also mentions the close connection between RetNet and Mamba-2 (under their SSD framework).
@learning-chip Hi, you may refer to simple gla by @sustcsonglin, which provides data-dependent decay on top of RetNet.
Modified version of https://github.com/huggingface/transformers/tree/add_codestral_mamba2/src/transformers/models/mamba2 (huggingface/transformers#32080) to work with FLA.
I haven't tested whether it works yet, but I'm pretty sure it will. It can probably be made a bit faster by implementing a gated RMSNorm that utilizes SiLU.
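For reference, an unfused sketch of what such a SiLU-gated RMSNorm computes (illustrative class name and signature; the speedup would come from fusing this into a single kernel, e.g. along the lines of FLA's FusedRMSNormSwishGate). The norm_before_gate flag mirrors the field of the same name in the config printed above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedRMSNormSketch(nn.Module):
    """Unfused reference for a SiLU-gated RMSNorm (illustrative only).
    norm_before_gate=True  -> rmsnorm(x) * silu(gate)
    norm_before_gate=False -> rmsnorm(x * silu(gate))"""

    def __init__(self, dim, eps=1e-5, norm_before_gate=True):
        super().__init__()
        self.eps = eps
        self.norm_before_gate = norm_before_gate
        self.weight = nn.Parameter(torch.ones(dim))

    def _rmsnorm(self, x):
        # scale by learned weight after normalizing by the RMS of the last dim
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x, gate):
        if self.norm_before_gate:
            return self._rmsnorm(x) * F.silu(gate)
        return self._rmsnorm(x * F.silu(gate))
```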