
Add implementations of Mamba 2 into FLA #39

Merged · 8 commits · Aug 5, 2024

Conversation

DanFosing (Contributor)

Modified version of https://github.com/huggingface/transformers/tree/add_codestral_mamba2/src/transformers/models/mamba2 (huggingface/transformers#32080) to work with FLA.
I haven't tested it yet, but I'm pretty sure it will work. It could probably be made a bit faster by implementing a gated RMSNorm that uses SiLU.
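As a rough, untested sketch of what I mean by a gated RMSNorm with SiLU (just the math, not a fused kernel; the class name here is made up):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedRMSNormSiLU(nn.Module):
    # Reference (non-fused) gated RMSNorm: y = RMSNorm(x) * SiLU(gate).
    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        # normalize in fp32 for stability, then cast back to the input dtype
        dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return (self.weight * x.to(dtype)) * F.silu(gate)

A fused Triton kernel would do the normalization and the SiLU gating in one pass instead of materializing the intermediate tensors.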

@yzhangcs (Member) commented Aug 5, 2024

@DanFosing Great job! I'll run some tests soon. Thank you for the contribution.

@DanFosing (Contributor, Author) commented Aug 5, 2024

I think it should work with FLA now.

@DanFosing (Contributor, Author)

Unfortunately, because the upstream pull request into the transformers package is still in progress, there is a possibility that it may not fully work in some specific cases; that's why it needs some testing.

@yzhangcs (Member) commented Aug 5, 2024

@DanFosing Hi, it looks like there are still some errors:

  File "flash-linear-attention/fla/models/mamba2/modeling_mamba2.py", line 607, in forward
    return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "flash-linear-attention/fla/models/mamba2/modeling_mamba2.py", line 333, in cuda_kernels_forward
    rmsnorm_weight=self.norm.weight,
                   ^^^^^^^^^^^^^^^^
  File "anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1709, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'GatedRMSNorm' object has no attribute 'weight'

@yzhangcs (Member) commented Aug 5, 2024

@DanFosing You can check your implementation by running

python -m benchmarks.benchmark_training_throughput --name mamba2

@yzhangcs (Member) commented Aug 5, 2024

Also, it would be better to tidy up your code style with pre-commit so that it follows the PEP 8 guidelines.
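For reference, the usual pre-commit workflow is roughly the following (assuming the repo ships a .pre-commit-config.yaml):

pip install pre-commit
pre-commit install
pre-commit run --all-files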

@yzhangcs (Member) commented Aug 5, 2024

You can refer to https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/layers/gla.py#L219-L228 for our implementation of the gated norm. I think there is no need to wrap it in another GatedNorm.

yzhangcs merged commit 0eacb4c into fla-org:main on Aug 5, 2024
@yzhangcs (Member) commented Aug 5, 2024

@DanFosing Appreciate your hard work and quick response.

@DanFosing (Contributor, Author)

Are you sure it works? When I tried it on Kaggle I got assertion errors, but Kaggle often has some weird problems, so it may be an issue on their side (I can't test it on my PC right now).

@yzhangcs (Member) commented Aug 5, 2024

Could you paste more detailed info?

@yzhangcs (Member) commented Aug 5, 2024

@DanFosing Here is the output

$ python -m benchmarks.benchmark_training_throughput --name mamba2
Initializing mamba2 model from the config:
Mamba2Config {
  "bos_token_id": 1,
  "chunk_size": 256,
  "conv_kernel": 4,
  "eos_token_id": 2,
  "expand": 2,
  "fuse_cross_entropy": true,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.1,
  "layer_norm_epsilon": 1e-05,
  "model_type": "mamba2",
  "n_groups": 8,
  "norm_before_gate": true,
  "num_heads": 64,
  "num_hidden_layers": 48,
  "pad_token_id": 0,
  "rescale_prenorm_residual": false,
  "residual_in_fp32": true,
  "rms_norm": true,
  "state_size": 128,
  "tie_word_embeddings": false,
  "time_step_floor": 0.0001,
  "time_step_limit": [
    0.0,
    Infinity
  ],
  "time_step_max": 0.1,
  "time_step_min": 0.001,
  "time_step_rank": 128,
  "transformers_version": "4.43.3",
  "use_bias": false,
  "use_cache": true,
  "use_conv_bias": true,
  "vocab_size": 32000
}

Mamba2ForCausalLM(
  (backbone): Mamba2Model(
    (embeddings): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-47): 48 x Mamba2Block(
        (norm): RMSNorm(2048, eps=1e-05)
        (mixer): Mamba2Mixer(
          (act): SiLU()
          (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)
          (in_proj): Linear(in_features=2048, out_features=10304, bias=False)
          (norm): FusedRMSNormSwishGate(4096, eps=1e-05)
          (out_proj): Linear(in_features=4096, out_features=2048, bias=False)
        )
      )
    )
    (norm_f): RMSNorm(2048, eps=1e-05)
  )
  (lm_head): Linear(in_features=2048, out_features=32000, bias=False)
)
Number of parameters in total: 1548430336 (1.44GiB)
Allocated memory after initialization: 2.89GiB
Max memory allocated: 42.01GiB: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [01:11<00:00,  4.44s/it]
Thoughput:   32048.92 tokens/s: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:16<00:00,  1.96it/s]

@DanFosing (Contributor, Author)

Turns out it was just a problem with the T4 not supporting some things in the causal_conv1d package; the same issue occurs with Mamba-1 and Mamba-2 from mamba-ssm. I don't remember exactly what it was, but there was a workaround that fixed it by turning off some low-memory mode in PyTorch or something. By the way, torch.compile support was added to Mamba-1, but it requires some changes to MambaCache and the Mamba modeling code; I'll try to implement it when I have some time.

@learning-chip (Contributor) commented Aug 5, 2024

One question on this PR:

From the FLA paper, both Mamba-1 and Mamba-2 can be written in the gated linear attention formulation:

[table from the FLA paper: gated-linear-attention formulations of existing models, including Mamba]

So, the SSM part of Mamba-2 (excluding conv1d, normalization, ...) should permit the same "Chunkwise Parallel Form" as advocated by the FLA paper, no?

[figure from the FLA paper: the chunkwise parallel form]

Then, the modeling_mamba2.py implementation here shouldn't need to import mamba_chunk_scan_combined from the original mamba_ssm repository, right? It could use a custom kernel similar to ops/gla/chunk_fuse.py in this repo.

Otherwise, the modeling_mamba2.py in the current PR only uses the custom RMSNorm/FusedRMSNormSwishGate kernels, while the SSM part is no different from the original mamba_ssm repository. Then the performance will remain largely the same as the original repo, and you cannot tell whether the FLA formulation is faster...
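To make the question concrete, here is a naive PyTorch sketch of the chunkwise parallel form I have in mind (decay/gating omitted for clarity, and the function name is just illustrative; the real FLA kernels are written in Triton and also handle the gates):

import torch

def chunkwise_linear_attention(q, k, v, chunk_size=64):
    # Naive chunkwise-parallel causal linear attention.
    # q, k, v: (batch, heads, seq_len, dim); seq_len must be divisible by chunk_size.
    B, H, T, D = q.shape
    C = chunk_size
    q, k, v = (x.reshape(B, H, T // C, C, D) for x in (q, k, v))
    S = q.new_zeros(B, H, D, D)  # inter-chunk state: running sum of k^T v
    mask = torch.tril(torch.ones(C, C, dtype=torch.bool, device=q.device))
    out = torch.empty_like(v)
    for i in range(T // C):
        qi, ki, vi = q[:, :, i], k[:, :, i], v[:, :, i]
        inter = qi @ S  # contribution of all previous chunks
        intra = (qi @ ki.transpose(-1, -2)).masked_fill(~mask, 0) @ vi  # causal part within the chunk
        out[:, :, i] = inter + intra
        S = S + ki.transpose(-1, -2) @ vi  # update the state for the next chunk
    return out.reshape(B, H, T, D)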

@learning-chip (Contributor) commented Aug 5, 2024

The Mamba-2 blog did mention that the FLA chunkwise-parallel algorithm "turns out to be essentially equivalent to the SSD algorithm specialized to a restricted case":

[excerpt from the Mamba-2 blog post on the equivalence to the chunkwise parallel form]

Will the FLA formulation be essentially identical to mamba_chunk_scan_combined in the original Mamba-2 code, or is there still some chance to improve on the original version?

@DanFosing (Contributor, Author)

Indeed, I think it can be made faster if a kernel similar to chunk_fuse.py is used (maybe it would even be possible to just modify the GLA one, since GLA and Mamba-2 are extremely similar to each other). Unfortunately I'm not familiar with Triton, so I can't really do it myself; I think it would be best if you opened an issue, unless one of the FLA devs replies here. The current implementation is a lot like the Mamba-1 implementation: both mostly follow the original mamba-ssm ones, with some minor speedup from FusedCrossEntropyLoss and the custom RMSNorm, and both are compatible with Hugging Face transformers.

@DanFosing (Contributor, Author)

If you take a look at this paper: https://arxiv.org/pdf/2406.06484, you can also see that, in terms of recurrence and memory read-out, Mamba-2 is very similar to RetNet and linear attention:
[table from arXiv:2406.06484 comparing recurrence and memory read-out across these models]
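Roughly (my paraphrase of the recurrences, not the table verbatim), with a matrix-valued state S_t and the memory read-out o_t = S_t q_t in all three cases:

Linear attention:  S_t = S_{t-1} + v_t k_t^T
RetNet:            S_t = γ S_{t-1} + v_t k_t^T      (fixed scalar decay γ)
Mamba-2:           S_t = γ_t S_{t-1} + v_t k_t^T    (data-dependent scalar decay γ_t)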

@learning-chip (Contributor) commented Aug 6, 2024

I think it would be best if you opened an issue, unless one of the FLA devs replies here.

Will look into it!

Also pinging @yzhangcs for any suggestions -- for example, which existing Triton kernel is the best starting point to re-implement the SSM part of Mamba-2? From the table above, the RetNet kernels should be the closest ones?

The Mamba-2 blog also mentions the close connection between RetNet and Mamba-2 (under their SSD framework):

Prior examples include the original linear attention as well as the recent Retentive Network (RetNet) model [18]. These can be viewed as direct special cases of SSD.

@yzhangcs (Member) commented Aug 6, 2024

@learning-chip Hi, you may refer to simple GLA by @sustcsonglin, which provides data-dependent decay on top of RetNet.
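For intuition, here is a toy recurrent reference of what simple GLA computes (a RetNet-style update with a data-dependent scalar decay per head and step, which is also the form of Mamba-2's SSD recurrence); it is only a correctness reference, not the Triton kernel, and the function name is made up:

import torch

def simple_gla_recurrent(q, k, v, g):
    # q, k, v: (batch, heads, seq_len, dim); g: (batch, heads, seq_len), decay values in (0, 1).
    B, H, T, D = q.shape
    S = q.new_zeros(B, H, D, D)  # matrix-valued state per head
    out = torch.empty_like(v)
    for t in range(T):
        # decay the old state, then accumulate the outer product k_t v_t^T
        S = g[:, :, t, None, None] * S + k[:, :, t, :, None] @ v[:, :, t, None, :]
        out[:, :, t] = (q[:, :, t, None, :] @ S).squeeze(-2)  # read-out o_t = q_t S_t
    return out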
