Add implementations of Mamba 2 into FLA #34
Hello! Do you plan to add Mamba 2 to your repo? If so, any estimate on when we can expect it?

Comments
@DanFosing Thanks for your attention. Yes, we indeed plan to add the kernels & models into FLA.
@DanFosing We finally have the simple GLA / Gated RetNet kernel, which is compatible with and significantly faster than Mamba2, thanks to the great work by you and @learning-chip (#39, #49 and #50).
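For context, a rough usage sketch of the chunked simple GLA kernel with Mamba2-style per-head scalar decays; the import path, tensor layout, and call signature here are assumptions rather than the exact FLA API:

```python
import torch
from fla.ops.simple_gla import chunk_simple_gla  # import path assumed

B, H, T, D = 2, 4, 1024, 64
q = torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16)
k = torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16)
# per-head, per-timestep log decay (the Mamba2-style scalar gate), kept in float32
g = torch.nn.functional.logsigmoid(torch.randn(B, H, T, device='cuda'))

o = chunk_simple_gla(q, k, v, g)  # assumed call; some versions also return a final state
```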
@yzhangcs Did the recent changes to simple GLA maybe break backward()? The benchmark runs fine for me, but when training I get an error.
Thanks for catching the bug! Just fixed.
Works great, thanks!!! One more quick note: when using torch_simple_gla I get NaNs after a few iterations... with chunk_simple_gla I also get NaNs fairly quickly, though it takes a while longer, around 325 iterations. Training with my simple implementation, this doesn't happen. Simple implementation pseudocode (sketched below):
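A minimal sketch of what such a naive recurrence might look like, assuming per-head scalar log-gates g of shape (B, H, T) and q/k/v of shape (B, H, T, D); this is illustrative, not the commenter's original code:

```python
import torch

def torch_simple_gla_naive(q, k, v, g):
    """Naive recurrent reference: S_t = exp(g_t) * S_{t-1} + k_t^T v_t, o_t = q_t S_t."""
    B, H, T, D = q.shape
    scale = D ** -0.5
    S = q.new_zeros(B, H, D, v.shape[-1])        # recurrent outer-product state
    outputs = []
    for t in range(T):
        # decay the state by exp(g_t), then accumulate the rank-1 update k_t^T v_t
        S = S * g[:, :, t].exp()[..., None, None] \
            + k[:, :, t, :, None] * v[:, :, t, None, :]
        # read out with the scaled query
        outputs.append(torch.einsum('bhd,bhde->bhe', q[:, :, t] * scale, S))
    return torch.stack(outputs, dim=2)           # (B, H, T, D_v)
```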
Update: it works fine as long as I clamp the g (w_log) values to -5 or so... I guess you must be using the original GLA method of computing this via relative changes to q, k, so there's a precision limit.
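A sketch of that clamping workaround, again assuming a chunk_simple_gla(q, k, v, g) style call with g holding the per-head log-decays:

```python
import torch
from fla.ops.simple_gla import chunk_simple_gla  # import path assumed

def forward_with_clamped_gates(q, k, v, g_raw, g_min=-5.0):
    # clamp the log-decays (w_log) so exp(g) does not underflow below what the
    # chunked kernel can represent accurately (the precision limit guessed at above)
    g = torch.clamp(g_raw, min=g_min)
    return chunk_simple_gla(q, k, v, g)
```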
I also get an error about num_warps being an unrecognized argument when using torch.compile on chunk_simple_gla.