
Add implementations of Mamba 2 into FLA #34

Closed

DanFosing opened this issue Jul 20, 2024 · 7 comments
@DanFosing (Contributor)

Hello! Do you plan to add Mamba 2 to your repo? If so, any estimate on when we can expect it?

@yzhangcs (Member)

@DanFosing Thanks for your interest. Yes, we do plan to add the kernels & models to fla. @sustcsonglin is working on it; please stay tuned.

@yzhangcs changed the title from "Mamba 2" to "Add implementations of Mamba 2 into FLA" on Jul 21, 2024
@DanFosing (Contributor, Author)

@yzhangcs I opened pull request #39 implementing Mamba 2 (a modeling.py file based on Mamba Codestral).

yzhangcs added a commit that referenced this issue on Aug 5, 2024: "Add implementations of Mamba 2 into FLA"

@yzhangcs closed this as completed on Aug 5, 2024
@yzhangcs (Member) commented Aug 18, 2024

@DanFosing We finally have the simple GLA / Gated RetNet kernel, which is compatible with and significantly faster than Mamba2, thanks to the great work by you and @learning-chip (#39, #49, and #50).
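As a rough usage sketch of how Mamba2's per-head scalar decay maps onto simple GLA's gate (the import path, shapes, and signature below are assumptions based on the fla repo, not confirmed in this thread; check the current source):

```python
import torch
# assumed import path; may have moved in later versions of fla
from fla.ops.simple_gla import chunk_simple_gla

B, H, T, K, V = 2, 4, 512, 64, 64
q = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
# simple GLA takes one log-decay scalar per head per timestep, so a
# Mamba2-style per-head decay maps onto g directly; keep g in float32
g = -torch.rand(B, H, T, device='cuda', dtype=torch.float32)

# depending on the version, this may also return a final recurrent state
o = chunk_simple_gla(q, k, v, g)
```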

@SmerkyG commented Aug 20, 2024

@yzhangcs Did the recent changes to simple GLA maybe break the backward()? The benchmark runs fine for me but when training I get `RuntimeError: function SimpleGLAFunctionBackward returned an incorrect number of gradients (expected 7, got 6)`
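For context, this RuntimeError comes from PyTorch's custom-autograd contract: backward() must return exactly one value per forward() input, with None for inputs that need no gradient. A minimal, hypothetical reproduction of the same error class (not the fla code):

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):  # two inputs -> backward must return two values
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # returning only grad_out * ctx.factor here would raise
        # "returned an incorrect number of gradients (expected 2, got 1)";
        # the non-tensor input `factor` still needs a None placeholder
        return grad_out * ctx.factor, None

x = torch.randn(4, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2., 2.])
```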

@sustcsonglin (Collaborator)

> @yzhangcs Did the recent changes to simple GLA maybe break the backward()? The benchmark runs fine for me but when training I get `RuntimeError: function SimpleGLAFunctionBackward returned an incorrect number of gradients (expected 7, got 6)`

Thanks for catching the bug! Just fixed.

@SmerkyG commented Aug 20, 2024

Works great, thanks!!!

One more quick note: when using torch_simple_gla I get NaNs after a few iterations... with chunk_simple_gla I also get NaNs fairly quickly, though it takes a while longer, around 325 iterations. Training with my simple implementation, this doesn't happen.

simple implementation pseudocode:

```python
import torch

def segsum(w_log):  # w_log: (B, H, L, 1) per-step log decays
    w_log_cumsum = torch.cumsum(w_log, dim=-2)  # (B, H, L, 1)
    # inner tril keeps the upper triangle at exp(0) = 1; outer tril zeroes it
    return torch.exp((w_log_cumsum - w_log_cumsum.mT).tril()).tril()  # (B, H, L, L)

def simple_gla_attn(q, k, v, w_log):  # q, k: (B, H, L, D); v: (B, H, L, Dv)
    att = (q * q.size(-1) ** -0.5) @ k.mT  # scaled scores, (B, H, L, L)
    att = att * segsum(w_log)  # segsum handles zeroing the upper-right tri
    return att @ v  # (B, H, L, Dv)
```

Update: it works fine as long as I clamp the g (w_log) values to about -5... I guess you must be using the original GLA method of computing this via relative changes to q and k, so there's a precision limit.
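A minimal sketch of that clamp workaround, reusing the simple_gla_attn helper from the snippet above (the shapes are illustrative, and the -5 floor is the empirical value mentioned, not a derived constant):

```python
B, H, L, D = 2, 4, 128, 64
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
# raw log decays lie in (-inf, 0); flooring them at about -5 bounds
# exp()'s dynamic range, which avoided the NaNs reported above
w_log = torch.clamp(torch.randn(B, H, L, 1).sigmoid().log(), min=-5.0)
out = simple_gla_attn(q, k, v, w_log)  # (B, H, L, D)
```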

@SmerkyG commented Aug 21, 2024

I also get an error with num_warps being an unrecognized argument when using torch.compile on chunk_simple_gla:

```
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/codecache.py", line 934, in _run_from_cache
[rank0]:     return compiled_graph.compiled_artifact(inputs)
[rank0]:   File "/tmp/torchinductor_recursal/uz/cuz4byetjs4tgjziqso7bh6qlsoayncgrauaw7aei7y56e5t2re7.py", line 334, in call
[rank0]:     chunk_simple_gla_fwd_kernel_h_0.run(k=buf2, v=buf3, h=buf0, g=reinterpret_tensor(buf1, (32, 12, 512), (6144, 512, 1), 0), h0=arg3_1, ht=None, s_qk_h=32768, s_qk_t=64, s_qk_d=1, s_vo_h=32768, s_vo_t=64, s_vo_d=1, s_h_h=32768, s_h_t=64, T=512, K=64, V=64, BT=64, BK=64, BV=64, NT=8, USE_INITIAL_STATE=True, STORE_FINAL_STATE=False, num_warps=4, num_stages=1, grid=grid_wrapper_for_chunk_simple_gla_fwd_kernel_h_0, stream=stream0)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/_inductor/triton_heuristics.py", line 670, in run
[rank0]:     return launcher(
[rank0]: TypeError: launcher() got an unexpected keyword argument 'num_warps'
```
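One possible workaround, sketched under assumptions (the import path is taken from the fla repo, torch.compiler.disable requires PyTorch 2.1+, and this is not a fix confirmed in the thread), is to keep torch.compile from tracing into the Triton kernel wrapper, so Inductor never re-launches it with meta-arguments like num_warps:

```python
import torch
from fla.ops.simple_gla import chunk_simple_gla  # assumed import path

# wrap the op so torch.compile treats it as an opaque eager call,
# leaving the Triton launch parameters (num_warps etc.) untouched
chunk_simple_gla_eager = torch.compiler.disable(chunk_simple_gla)
```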
