
R1-Style distributed GRPO #2326

Merged
merged 117 commits into pytorch:main on Feb 21, 2025
Conversation

RedTachyon
Contributor

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

After some discussions on another PR and on Discord, this is the current state of my distributed GRPO implementation. I'm still iterating on this, prioritizing checking whether it actually works.

I have some early successes, but it's too soon to proclaim victory. Soon I'll probably also adapt it to a multinode workflow when #2301 is merged (or just snatch some code from there), because RL is sufficiently resource-hungry that single-node training isn't really an option for anything even moderately serious.

Right now the repo/PR is very messy and in a research-y state while I figure out something that works. Once it does work, I'll start cleaning it up to meet OSS standards. I'm putting it here to be able to keep track of the diffs and for potential discussions.

The rest of the template will be filled in when possible/relevant:


Changelog

What are the changes made in this PR?
*

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Feb 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2326

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2ba4a97 with merge base e6cba25:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (authors need to sign the CLA before a PR can be reviewed) on Feb 1, 2025
@ianbarber ianbarber mentioned this pull request Feb 1, 2025
@musabgultekin
Contributor

musabgultekin commented Feb 3, 2025

I'm extremely interested in this.
One extra suggestion: we could technically run the frozen reference policy on another device through SGLang or vLLM. That way, we only hold the trainable policy in VRAM.
One potential drawback is that the internal implementations of these engines might introduce subtle numerical differences, which could cause issues in the advantage and loss calculation.
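
To make that concern concrete, here is a minimal sketch (not this PR's code) of where the frozen reference policy's log-probs typically enter a GRPO-style objective; the tensor shapes and function name are assumptions for illustration:

```python
import torch

def grpo_advantages_and_kl(rewards, policy_logprobs, ref_logprobs, eps=1e-6):
    """Sketch of group-relative advantages plus a per-token KL penalty.

    rewards:         (num_groups, group_size) scalar reward per completion
    policy_logprobs: (num_groups, group_size, seq_len) log-probs under the trained policy
    ref_logprobs:    (num_groups, group_size, seq_len) log-probs under the frozen reference
    """
    # Group-relative advantage: normalize each completion's reward within its group.
    mean = rewards.mean(dim=1, keepdim=True)
    std = rewards.std(dim=1, keepdim=True)
    advantages = (rewards - mean) / (std + eps)

    # Per-token KL estimate of the form exp(ref - policy) - (ref - policy) - 1,
    # as used in GRPO-style objectives; it consumes ref_logprobs directly.
    log_ratio = ref_logprobs - policy_logprobs
    per_token_kl = torch.exp(log_ratio) - log_ratio - 1.0

    return advantages, per_token_kl
```

Because the KL term consumes the reference log-probs directly, even small numerical differences between an external inference engine and the training-time forward pass would show up there.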

RedTachyon and others added 4 commits February 3, 2025 17:42
Reorganize some recipes

Add SFT dataset
Clean up reward function

New (untested) generation function

New recipe config
@akashc1
Contributor

akashc1 commented Feb 4, 2025

@RedTachyon thanks for the implementation! Please let me know if there's anything you'd like support on, I'm happy to help!

For my own use, could you please share how you set up the environment/commands to launch training?

@RedTachyon
Contributor Author

Hi everyone, glad to see people interested in contributing to this implementation! I'm happy to say that the core implementation "just works" - on a recent run I did the following process:

  1. Take base model Llama 3B
  2. Train it with SFT for 1 epoch on 1/10th of the GSM8k train set, using the R1 prompt template (note: extremely quick training, took like 3 minutes on a single node)
  3. Continue training the resulting model on the remaining 9/10ths of the GSM8k train set

At the moment, everything follows the R1 paper relatively closely: the prompts use the R1-style chain-of-thought template, and reward computation is done by XML-parsing the response and checking the answer inside the tags for an exact match (failed parse = 0 reward).
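
For reference, a minimal sketch of that kind of reward function (the tag name and exact-match rule here are illustrative assumptions, not necessarily what the recipe ships with):

```python
import re

def r1_style_reward(completion: str, gold_answer: str) -> float:
    """Parse an <answer>...</answer> span from the completion and compare it
    to the gold answer. A completion that cannot be parsed gets 0 reward."""
    match = re.search(r"<answer>(.*?)</answer>", completion, flags=re.DOTALL)
    if match is None:
        return 0.0  # failed parse = 0 reward
    return 1.0 if match.group(1).strip() == gold_answer.strip() else 0.0
```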

The base model by itself struggles with the format (and with knowing when to output <|eos|>), so its performance is just bad. The SFT-trained model follows the format well but is still weak at math, so it gets roughly a 10% success rate on GSM8k.

Continuing the training with GRPO, the model climbs up to a ~60% success rate! (and can probably go higher if it keeps running for longer)

[wandb plot: success rate on GSM8k during GRPO training]

Caveat: this estimate is based on training-set questions, without a separate eval - but most of the improvement happens before the first epoch finishes.

Which brings me to the list of things that still need to be done. I'll probably move it to the top comment on the PR later, but for now here it is:

  • A proper eval workflow, so that every once in a while we run a full eval on the test set
  • Adding proper (unit) tests, compliant with the normal torchtune testing workflow
  • Adding proper documentation to everything
  • General cleanup of messy research code
  • Refactoring of the GRPO losses (there's also a research-y question of what loss should be used - there's some ambiguity in the paper and reference implementations)
  • More modular approach to reward computation (probably as a component with the regular OmegaConf setup - see the sketch after this list)
  • Probably more frequent checkpointing, not just tied to epochs (since epochs are sloooow with RL)
  • Memory profiling and experiments on "controlled" hardware (a recipe tuned to work ~optimally on a node of 8xH100, or on a single H100, or on smaller hardware with e.g. LoRA)
  • Dataset cleanup (gsm8k is functional but could use some polish; there's also the MATH dataset, which is more challenging)
  • Maybe a single-device version? (might be memory-sensitive)
  • Try to improve generation speed by using vLLM (or something else)?
  • Probably more to be found soon
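
As a sketch of what the "reward as a component" item above could look like with torchtune's usual config-driven instantiation - the class name, module path, and config keys here are hypothetical, not part of this PR:

```python
import re

class FormatReward:
    """Hypothetical reward component: checks that the completion contains the
    expected <think>/<answer> structure, independent of answer correctness."""

    def __init__(self, think_tag: str = "think", answer_tag: str = "answer"):
        self.pattern = re.compile(
            rf"<{think_tag}>.*?</{think_tag}>\s*<{answer_tag}>.*?</{answer_tag}>",
            re.DOTALL,
        )

    def __call__(self, completion: str) -> float:
        return 1.0 if self.pattern.search(completion) else 0.0

# A recipe config could then swap reward functions like any other component, e.g.:
#
#   reward_fn:
#     _component_: my_rewards.FormatReward
#     answer_tag: answer
#
# and the recipe would build it with torchtune's standard factory:
#   self._reward_fn = config.instantiate(cfg.reward_fn)
```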

I'll start going through some of these things, and if someone wants to contribute, please mention it so that we don't duplicate the effort.

Later today I'll also push my sbatch files to run the full pipeline so that anyone can run it on slurm - for single-node experiments, regular torchtune runner should work fine.

Regarding the organization: it's probably best to keep everything under a shared PR (i.e. this one) and make other intermediate PRs into this branch. That way various people can contribute parts of the GRPO setup, we merge them into this branch, and when it's all done the maintainers can review the complete thing and merge it into main.
It might be better to move the code to a branch on this repo instead of my fork - in any case, I'm happy to adapt it to whatever the maintainers think is best (@SalmanMohammadi @felipemello1 - not sure who's the right person to bother about the management-y things).

CC: @musabgultekin @akashc1 @ianbarber

@RedTachyon
Contributor Author

@ebsmothers @SalmanMohammadi So I did a refactor and moved all the "weird" stuff into /dev. Notably, I also took my changes to the generate function into a "fork" inside /dev, which admittedly creates some (temporary) code duplication, but hopefully that's an acceptable trade-off? The differences are multi-device support and the option to skip returning logits - we can take some more time to figure out the right APIs and tests for that, but for the time being, the logits issue really hurts GRPO performance.
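
For context on the logits point, here is a minimal sketch (not the actual dev generation code) of the kind of flag involved; when the full-vocab logits for every generated token aren't needed, skipping their accumulation avoids a large (batch, new_tokens, vocab) buffer. All names are illustrative assumptions:

```python
import torch

@torch.no_grad()
def generate(model, prompt_tokens, max_new_tokens, return_logits=True):
    """Greedy decoding sketch. With return_logits=False, the per-step
    (batch, vocab) logits are never stacked into a (batch, new_tokens, vocab)
    tensor, which is what dominates memory for large vocabularies."""
    tokens = prompt_tokens
    saved_logits = []
    for _ in range(max_new_tokens):
        logits = model(tokens)[:, -1, :]              # logits for the last position
        next_token = logits.argmax(dim=-1, keepdim=True)
        if return_logits:
            saved_logits.append(logits)
        tokens = torch.cat([tokens, next_token], dim=-1)
    return tokens, (torch.stack(saved_logits, dim=1) if return_logits else None)
```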

@ebsmothers
Contributor

@RedTachyon thanks for doing that! At least at first glance this looks much easier for us to land, will give a proper review tomorrow.

@RedTachyon
Contributor Author

@ebsmothers Gentle nudge - I'm hoping to have this merged sooner rather than later, so that I can start iterating on some more experimental stuff without getting into a branching nightmare

Contributor

@ebsmothers ebsmothers left a comment


Thanks for your patience @RedTachyon! A few more small comments and I think this is good to go (now that everything is in dev I will not be as pedantic about some of the design considerations). Main thing is to make sure that everything is runnable out-of-the-box. Also if you're able to share some of the logged metrics (successes, rewards, etc) on your latest runs that would be great as well.

self._log_peak_memory_stats = False

self.fsdp_cpu_offload = cfg.get("fsdp_cpu_offload", False)
self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
Contributor


Bumping this comment - I don't believe this is actually used anywhere? If that's the case, can we just remove it?

@felipemello1
Copy link
Contributor

Hey @RedTachyon, just wanted to thank you for being so responsive and putting the effort into this. We all appreciate it :)

@RedTachyon
Contributor Author

RedTachyon commented Feb 21, 2025

Happy to contribute, and thanks for all the help in bringing this to a publishable state!

I applied the final fixes and launched another run to make sure it still learns - so far it's just about the same, and I don't really expect anything to be different. (EDIT: about 2 hours in, it's going up the same way it was going up before, so it's most likely all good)

As for some existing metrics - right now I don't have anything very pretty, but just to give a sense of what to expect, here are some wandb graphs (which I unfortunately can't share as actual wandb reports at the moment).

There are 8 curves, varying across the starting point (base model vs. SFT-initialized model) and across the number of nodes, and therefore the effective batch size: 2 nodes give an effective batch size of 16, 4 nodes give 32, and 8 nodes give 64.

Success rate vs steps: [wandb plot]

Reward vs steps: [wandb plot]

Success rate vs time: [wandb plot]

Reward vs time: [wandb plot]

Note that the hardware isn't perfectly consistent, so the time graphs should be read as a very rough estimate.

Contributor

@ebsmothers ebsmothers left a comment


Thank you so much @RedTachyon! This was a serious PR. Really appreciate your patience through the review process and we're so glad to have GRPO thanks to your efforts!

@ebsmothers ebsmothers merged commit cf0142b into pytorch:main Feb 21, 2025
17 checks passed
@RedTachyon RedTachyon mentioned this pull request Feb 22, 2025
joecummings pushed a commit to joecummings/torchtune that referenced this pull request Feb 27, 2025
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: ebsmothers <[email protected]>
Co-authored-by: salman <[email protected]>