R1-Style distributed GRPO #2326
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2326
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 2ba4a97 with merge base e6cba25.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
I'm extremely interested in this.
- Reorganize some recipes
- Add SFT dataset
- SFT recipe for GSM8k
- Clean up reward function
- New (untested) generation function
- New recipe config
@RedTachyon thanks for the implementation! Please let me know if there's anything you'd like support on; I'm happy to help! For my own use, could you please share how you set up the environment and the commands you use to launch training?
- Manual resharding (?)
- Mostly-working 3B GRPO config
- SFT recipe for gsm8k
Hi everyone, glad to see people interested in contributing to this implementation! I'm happy to say that the core implementation "just works": on a recent run I SFT-trained a base model on GSM8k to teach it the output format, then continued training it with GRPO.

At the moment, everything follows the R1 paper relatively closely: the model is expected to produce a chain of thought and a final answer wrapped in XML-style tags, and reward computation is done by parsing those tags out of the response and checking the answer for an exact match against the ground truth (a failed parse gives 0 reward). A rough sketch of this parsing logic is included below.

The base model by itself struggles with the format (and with knowing when to output <|eos|>), so its performance is just bad. The SFT-trained model follows the format well but is fairly weak at math, so it gets about a 10% success rate on GSM8k. Continuing the training with GRPO, the model climbs up to a ~60% success rate (and can probably go higher if it keeps running for longer).

Caveat: this estimate is based on the training-set questions, without a separate eval, but a majority of the improvement happens before the first epoch finishes.

Which brings me to the list of things that still need to be done; I'll probably move it to the top comment on the PR later.
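To make the reward scheme concrete, here is a minimal sketch of the parse-and-match reward described above. The tag name, helper names, and reward values are illustrative assumptions rather than the exact code in this PR, and a simple regex stands in for proper XML parsing:

```python
import re
from typing import Optional

# Illustrative tag name; the recipe's actual format tags may differ.
ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def extract_answer(completion: str) -> Optional[str]:
    """Return the text inside the <answer> tags, or None if parsing fails."""
    match = ANSWER_RE.search(completion)
    return match.group(1).strip() if match is not None else None


def answer_reward(completion: str, ground_truth: str) -> float:
    """Exact-match reward: correct, well-formatted answer -> 1.0, everything else -> 0.0."""
    answer = extract_answer(completion)
    if answer is None:
        # A response that can't be parsed (missing or malformed tags) gets zero reward.
        return 0.0
    return 1.0 if answer == ground_truth.strip() else 0.0
```

In practice one might also add a smaller reward component for getting the format right even when the answer is wrong, but the sketch only covers the exact-match-or-zero behaviour described in this comment.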
I'll start going through some of these things, and if someone wants to contribute, please mention it so that we don't duplicate the effort. Later today I'll also push my sbatch files to run the full pipeline so that anyone can run it on Slurm; for single-node experiments, the regular torchtune runner should work fine.

Regarding the organization: it's probably best to keep everything under a shared PR (i.e. this one) and make other intermediate PRs into this branch. That way various people can contribute parts of the GRPO setup, we merge them into this branch, and when it's all done the maintainers can review the complete thing and merge it into main.
@ebsmothers @SalmanMohammadi So I did a refactor and moved all "weird" stuff into /dev. Notably, I also took my changes to the…
@RedTachyon thanks for doing that! At least at first glance this looks much easier for us to land, will give a proper review tomorrow.
@ebsmothers Gentle nudge - I'm hoping to have this merged sooner rather than later, so that I can start iterating on some more experimental stuff without getting into a branching nightmare.
Thanks for your patience @RedTachyon! A few more small comments and I think this is good to go (now that everything is in dev I won't be as pedantic about some of the design considerations). The main thing is to make sure that everything is runnable out of the box. Also, if you're able to share some of the logged metrics (successes, rewards, etc.) from your latest runs, that would be great as well.
```python
self._log_peak_memory_stats = False

self.fsdp_cpu_offload = cfg.get("fsdp_cpu_offload", False)
self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
```
Bumping this comment: I don't believe this is actually used anywhere. If it isn't, can we just remove it?
Co-authored-by: ebsmothers <[email protected]>
Hey @RedTachyon, just wanted to thank you for being so responsive and putting the effort into this. We all appreciate it :)
Happy to contribute, and thanks for all the help in bringing this to a publishable state! I applied the final fixes and launched another run to make sure it still learns; so far it's just about the same, and I don't really expect anything to be different. (EDIT: about 2 hours in, it's going up the same way it was before, so it's most likely all good.)

As for existing metrics: right now I don't have anything very pretty, but just to give a sense of what to expect, here are some wandb graphs (which I unfortunately can't share as actual wandb reports at the moment). There are 8 curves, varying across whether training starts from the base model or from an SFT-initialized model, and across the number of nodes (so effectively the batch size: 2 nodes give an effective batch size of 16, 4 nodes give 32, and 8 nodes give 64). Note that the hardware isn't super consistent, so the time graphs are meant as a very rough estimate.
Thank you so much @RedTachyon! This was a serious PR. Really appreciate your patience through the review process and we're so glad to have GRPO thanks to your efforts!
Co-authored-by: Felipe Mello <[email protected]> Co-authored-by: ebsmothers <[email protected]> Co-authored-by: salman <[email protected]>
Context
What is the purpose of this PR?
After some discussions on another PR and on Discord, this is the current state of my distributed GRPO implementation. I'm still iterating on this, with the immediate priority being to check whether it actually works.
I have some early successes, but it's too soon to proclaim victory. Soon I'll probably also adapt it to a multinode workflow when #2301 is merged (or just snatch some code from there), because RL is sufficiently resource-hungry that single-node training isn't really an option for anything even moderately serious.
Right now the repo/PR is very messy and in a research-y state, aimed at finding something that works. Once it does work, I'll start cleaning it up to meet OSS standards. I'm putting it here to keep track of the diffs and to enable discussion.
The rest will be filled in when possible/relevant:
Changelog
What are the changes made in this PR?
*
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- run pre-commit hooks and linters (install them first via `pre-commit install`)
- run unit tests via `pytest tests`
- run recipe tests via `pytest tests -m integration_test`
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.