The recipe definitely works (as in, I can run it and reach roughly a 60% success rate on GSM8K with a 3B model), but it's somewhat barebones and underoptimized. Here, I want to keep track of the most important features and improvements that I think are missing. I'll probably work through this at some point, at my own pace, but if anyone else wants to contribute, feel free to grab something from this list.
Improvements
Figure out how to move things from dev into the main repository (decide on the final APIs etc.)
A proper eval workflow, so that every once in a while we run a full eval on the test set. Alternatively, a separate evaluation recipe? (note: I already have a working separate eval recipe that follows the same paradigm as GRPO, goes through the full test dataset, computes success/reward, and saves it to a file inside the checkpoint - happy to make it a PR)
Adding proper (unit) tests, compliant with the normal torchtune testing workflow
Adding proper documentation to everything
Refactoring of the GRPO losses (there's also a research-y question of what loss should be used - there's some ambiguity in the paper and reference implementations)
More modular approach to reward computation (probably as a component with the regular OmegaConf setup)
Step-based checkpointing (Implement step based checkpointing #2384) - this is pretty important, since one epoch can be very long, leading to very infrequent checkpoints
Memory profiling and experiments on "controlled" hardware (a recipe tuned to work ~optimally on a node of 8xH100, or on a single H100, or on smaller hardware with e.g. LoRA)
Optimization of the default recipe - maybe we can get a big performance boost e.g. by doing ppo_epochs>1?
Dataset improvements (gsm8k is functional but could use some polish; there are also the MATH and DeepScaleR datasets with a similar structure that could be added)
A single-device version - should be pretty simple, but probably also slow. It might require gradient accumulation to work properly, since I tend to get bad results with small batch sizes (see the sketch after this list).
Try to improve generation speed by using vLLM (or something else)?
Probably more to be found soon
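For the single-device point above, here is a minimal gradient-accumulation sketch. The model, optimizer, and data are placeholders, not the actual recipe; the idea is just that accumulating over accum_steps micro-batches approximates a batch that is accum_steps times larger before each optimizer step.

```python
import torch
import torch.nn.functional as F

# Toy model and optimizer standing in for the policy and its optimizer.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accum_steps = 8  # effective batch size = micro-batch size * accum_steps

for step in range(32):
    x, y = torch.randn(4, 10), torch.randn(4, 1)   # one small micro-batch
    loss = F.mse_loss(model(x), y) / accum_steps   # scale so the accumulated gradient averages correctly
    loss.backward()                                # gradients accumulate across micro-batches
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```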
Bugs
Because of course I found a bug right after everything was finalized. There will likely be more, so this subsection might or might not be useful.
Right now, generate_trajectory_batched can crash when the different generations have different lengths. For example, one batch of completions generated the full 512 tokens, but another got truncated at 300 because every sequence hit a stop token. So you have tensors of shapes [16, 512] and [16, 300] and try to concatenate them along the zeroth axis, which fails because the sequence dimensions don't match. The tensors need to be padded to a consistent length first.
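A minimal sketch of the kind of padding fix this needs (the helper name and pad value are illustrative, not the actual recipe code):

```python
import torch
import torch.nn.functional as F

def pad_and_cat(batches: list[torch.Tensor], pad_id: int) -> torch.Tensor:
    """Right-pad each [batch, seq_len] tensor to the longest seq_len, then concat on dim 0."""
    max_len = max(b.shape[1] for b in batches)
    padded = [F.pad(b, (0, max_len - b.shape[1]), value=pad_id) for b in batches]
    return torch.cat(padded, dim=0)

# e.g. one batch ran the full 512 tokens, another stopped early at 300
a = torch.randint(0, 100, (16, 512))
b = torch.randint(0, 100, (16, 300))
out = pad_and_cat([a, b], pad_id=0)  # shape [32, 512]
```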
Very, very rarely, an invalid token is sampled - for example token 128011, which is an undefined special token with the standard config. When we try to decode it for the reward computation, the entire program crashes because tiktoken can't handle the unknown token. This can probably be handled by replacing undefined generated tokens with pad_id or something similar. As to why these tokens are ever sampled: the model probably gives them a very low probability, say 1e-7, but if you sample 1e7 new tokens, chances are it will happen at some point.
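One possible guard, assuming we can build the set of token ids the tokenizer actually knows how to decode (the function and variable names here are hypothetical, not existing recipe code):

```python
import torch

def replace_undecodable(tokens: torch.Tensor, valid_ids: torch.Tensor, pad_id: int) -> torch.Tensor:
    """Swap any sampled id the tokenizer can't decode (e.g. undefined special tokens
    like 128011) for pad_id, so decoding for the reward computation never crashes."""
    mask = torch.isin(tokens, valid_ids)
    return torch.where(mask, tokens, torch.full_like(tokens, pad_id))

# valid_ids would come from the tokenizer: its regular vocab plus the special
# tokens that are actually defined in the config. Placeholder values below.
valid_ids = torch.arange(0, 128_000)
completions = torch.tensor([[1, 5, 128011, 7]])                # 128011 is not decodable
safe = replace_undecodable(completions, valid_ids, pad_id=0)   # -> [[1, 5, 0, 7]]
```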
Note to maintainers - I took the liberty of creating this centralized checklist since I still have all the necessary improvements in my context window. In principle, each bullet point could be a separate issue, but that would probably be a nightmare. We can coordinate the effort around this issue and start adding the improvements, one PR at a time.
Thank you so much for creating this checklist @RedTachyon! It's great to have all these items in one place. Actually a couple of the improvements you listed are horizontal changes we've been wanting to enable across the repo anyways -- I'm thinking specifically of vLLM, eval datasets, and step-based checkpointing. Step-based checkpointing is already in progress and I think some basic eval shouldn't be too hard (I think #2238 was going in that direction and may just need some minor changes). Proper vLLM integration may be a bigger effort, but we are planning to get going on this asap. And thanks @krammnic for already working on a couple of these!