Adding grpo training #1233
base: main
Conversation
Absolute HERO! Been trying to figure this out myself the past week but made pretty much no progress whatsoever, other than a script that fills up all the RAM on my Mac 🤣 Is there any way to run this yet? I assume not, since at the mo it's still marked as a draft and there isn't a lora_config.yaml like in the DPO example yet (not sure if it's needed)?
No, not yet. I still have to implement the dataset wrapper and some other stuff; I'll tell you when it's done.
Possibly need to use expanded_prompts and expanded_answers in both the reward and the loss.
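For context, a minimal, hypothetical sketch of what such an expansion usually looks like in GRPO (the variable names follow the review comment; nothing here is the PR's actual code):

```python
# Hedged sketch, not the PR's actual code: in GRPO each prompt is sampled
# `group_size` times, so the prompt/answer lists are usually expanded to stay
# aligned one-to-one with the generated completions.
group_size = 2
prompts = ["What is 2 + 2?", "What is 3 + 3?"]
answers = ["4", "6"]

expanded_prompts = [p for p in prompts for _ in range(group_size)]
expanded_answers = [a for a in answers for _ in range(group_size)]

# Both now have len(prompts) * group_size entries, so a reward call such as
# reward_fn(expanded_prompts, completions, expanded_answers) and the loss see
# aligned (prompt, completion, answer) triples.
```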
python -m mlx_lm.lora \
--model Qwen/Qwen2.5-0.5B \
--train \
--data /Users/gokdenizgulmez/Desktop/test_grpo \
--iters 5 \
--batch-size 1 \
--num-layers 4 \
--val-batches 1 \
--steps-per-report 1 \
--adapter-path /Users/gokdenizgulmez/Desktop/test-grpo-full \
--max-seq-length 128 \
--grad-checkpoint \
--training-mode grpo \
--fine-tune-type lora \
--beta 0.1 \
--steps-per-eval 500 \
--group-size 2

Output:
But after that my 32 GB of RAM gets fully used. I tried to add some memory optimisations but the memory usage is still too high.
Still uses too much memory.
So I tried using TRL and the same amount of RAM was used, so the problem isn't on my side.
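For anyone poking at the memory usage themselves, a hedged sketch of the kind of housekeeping one might try with MLX between steps (generic MLX calls only, not the optimisations actually attempted in this PR):

```python
# Hedged sketch of generic MLX memory housekeeping between training steps.
import mlx.core as mx

mx.metal.set_cache_limit(2 * 1024**3)  # cap the Metal buffer cache at ~2 GB

def run_step(step_fn, *args):
    loss = step_fn(*args)
    mx.eval(loss)            # force evaluation so intermediates can be freed
    mx.metal.clear_cache()   # release cached Metal buffers between steps
    return loss
```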
🚀 Would you be able to share the datasets you used for the training? Will give it a go on my machine as soon as I can 🙌
Will do that tomorrow 🤝
I created a quick one only for testing the code |
python -m mlx_lm.lora \
--model Qwen/Qwen2.5-0.5B \
--train \
--data /Users/gokdenizgulmez/Desktop/test_grpo \
--iters 5 \
--batch-size 1 \
--num-layers 8 \
--val-batches 1 \
--steps-per-report 1 \
--adapter-path /Users/gokdenizgulmez/Desktop/test-grpo-full \
--max-seq-length 255 \
--grad-checkpoint \
--training-mode grpo \
--fine-tune-type lora \
--beta 0.1 \
--steps-per-eval 500 \
--group-size 2 \
--max-completion-length 6

Output:
Just wanted to pop up and express my support again 😁 The efforts are very much appreciated!!! (Been dealing with some personal issues lately so I haven't been nearly as active in the community as I'd like, but I have been keeping an eye on this repo every day regardless pahahahaha) Thanks for the awesome work @Goekdeniz-Guelmez 😁
You only ran 250 iterations with batch size 1, which is likely insufficient for meaningful changes in model behavior, especially for a 3B parameter model. Can you also show me the logs from the training? If the rewards go up, that means the model is learning. Is the adapter path correct? Also, is the system prompt you used correct? The default system prompt is:
Thanks @Goekdeniz-Guelmez, it's a problem with my system prompt. I tried to use
Thanks!! It should NOT be indented, because it should execute regardless of whether the weights were provided or defaulted.
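To illustrate the point with a self-contained (and purely hypothetical) example, the validation sits outside the `if` so it executes in both cases:

```python
# Entirely hypothetical illustration of the review point above: the check is
# NOT indented under the `if`, so it runs whether the caller passed weights
# or the defaults were filled in.
from typing import List, Optional

DEFAULT_REWARD_FUNC_NAMES = ["format_reward", "accuracy_reward"]  # placeholder names

def resolve_reward_weights(weights: Optional[List[float]] = None) -> List[float]:
    if weights is None:
        weights = [1.0] * len(DEFAULT_REWARD_FUNC_NAMES)  # default: equal weights
    # De-indented on purpose: executes for provided *and* defaulted weights.
    if len(weights) != len(DEFAULT_REWARD_FUNC_NAMES):
        raise ValueError("Number of weights must match number of reward functions")
    return weights
```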
@Goekdeniz-Guelmez I am testing on the latest commit (first of all, again, amazing improvements). I was running training for
I'm saving adapters every 10 steps, so this wasn't the first time they were being saved... Not really sure what else I can add about this; unfortunately it didn't print a stack trace.
@deathcoder Probably has something to do with the memory handling and clearing; I'll look into it when I'm home.
Hello, thanks a lot for all the work, it is a pleasure to be able to play with GRPO locally!! After a few tests, it seems to work perfectly fine for very short prompts, but I struggle with prompts of ~1000 tokens, even with small models like mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-8bit. My args:
With the model Qwen/Qwen2.5-0.5B, I get 5 tokens/sec, so with a group-size of 2 generating 500 tokens each, that's 1000 tokens generated => 200 sec per iteration. I have an M4 Max with 128 GB: my GPU usage peaks at 10% sometimes, but stays very low. The memory used is close to 40 GB with the 1.5B. @Goekdeniz-Guelmez, do you have any idea how to improve the performance?
@Vi-cs have you tried reducing the num-layers? You are tuning all layers with -1. Also make sure you are on the latest commit. I never tried with -1, but in my tests with 8 layers I get much higher speeds than that.
My commit was a few days old but I pulled the latest commit just to be sure. And I changed to --num-layers 4. Edit:
Not sure what it is on your side that is slowing you down; I just ran the exact same command you sent, the only difference is the dataset:
and training is still going. I'm now on iteration 14 and I just launched it before I started writing this message. Edit: just to confirm you are actually on the latest commit, do you have the validation sample details in your logs?
@Goekdeniz-Guelmez Thanks for the nice work! I wonder if the current code supports GRPO training with
Thanks @deathcoder. The dataset is this one: https://huggingface.co/datasets/vi-c/test/.
Also, I think I am on the latest commit based on the log:
Hi @Goekdeniz-Guelmez, I am using an R1 distill model which is not an instruct model. It doesn't behave correctly with --use-chat-template. I set --use-prompt on purpose and it works fine (I tested this on Unsloth GRPO).
@Vi-cs You're using the wrong dataset! Your dataset doesn't match the built-in reward functions. The reward functions are looking for specific XML tags and formatted answers, but your dataset contains JSON with Chinese text instead. If you really need to, you'll have to create new reward functions and a prompt that work with the JSON data, and train your model via code. The dataset is not suited for GRPO training, since GRPO needs structured data with clear evaluation criteria to optimize against, which your mixed-language JSON data doesn't provide. Look into the LORA.md documentation to understand the dataset format that should be used; your dataset is more suited to basic SFT training than to GRPO.
Bad dataset: (screenshot)
Good dataset: Goastro/mlx-grpo-dataset
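For reference, a rough sketch of the kind of record that fits the built-in reward functions described above (the field names and tags here are assumptions, not taken from the PR; check LORA.md for the real format):

```python
# Hypothetical illustration only; LORA.md in this PR is the authoritative
# source for the exact field names and tag format.
dataset_record = {
    "prompt": "Natalia sold 48 clips in April and half as many in May. How many in total?",
    "answer": "72",
}

# The built-in reward functions reportedly score completions that wrap the
# reasoning and the final integer answer in XML tags, roughly along these lines:
expected_completion = "<think>48 + 48 / 2 = 72</think><answer>72</answer>"
```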
Totally agree! The provided reward functions only reward the model if the completion contains a strict XML structure, an int inside the answer tag, and the correct int. The training doesn't work with my reward functions. Just to make sure the issue is not related to my custom reward functions, I used the ones provided (which don't make sense with my dataset, but should not break the training). Still, the training doesn't work. @Goekdeniz-Guelmez do you see a reason why the training would not succeed in computing iterations with my dataset? @deathcoder Any chance you could try to execute my args (with my dataset), to see if on your side you process at least a few iterations? Thanks!
The main issue with long prompts is that transformer models process text with quadratic complexity, meaning 10x longer prompts take roughly 100x longer to process. The problem is the same in GRPO because we generate multiple completions per prompt (group_size) and compare them all against each other. Each generated token depends on processing all previous tokens, creating a strictly sequential process, especially with JSON data that requires complete structural accuracy. Also, how did you train the model using your reward functions, and what do your reward functions look like? My idea would be reward functions like json_structure_reward, which rewards proper JSON formatting and structure, and financial_data_reward, which rewards the accuracy of the financial data compared to the reference. The chat template needs to change too, to kick-start the JSON output format, and you also have to write a function to extract the model's answer correctly. Also, since you're using

def train_grpo(
    model: nn.Module,
    ref_model: Optional[nn.Module],
    tokenizer,
    optimizer,
    train_dataset,
    val_dataset,
    reward_funcs: Optional[List[RewardFunctions]] = [
        json_structure_reward,
        financial_data_reward
    ],
    args: GRPOTrainingArgs = GRPOTrainingArgs(),
    loss_fn: callable = grpo_loss,
    iterate_batches: callable = iterate_grpo_batches,
    training_callback: TrainingCallback = None,
):
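To make that concrete, here is a hedged sketch of what those two suggested reward functions could look like (the signature, prompts/completions/answers returning a list of float scores, is an assumption, not the PR's actual RewardFunctions interface):

```python
import json
from typing import List

# Hedged sketch of the two reward functions suggested above. The call
# signature is an assumption; adapt it to whatever grpo_reward_functions.py
# actually expects.

def json_structure_reward(prompts: List[str], completions: List[str],
                          answers: List[str]) -> List[float]:
    """Reward 1.0 for completions that parse as valid JSON, else 0.0."""
    scores = []
    for completion in completions:
        try:
            json.loads(completion)
            scores.append(1.0)
        except (json.JSONDecodeError, TypeError):
            scores.append(0.0)
    return scores

def financial_data_reward(prompts: List[str], completions: List[str],
                          answers: List[str]) -> List[float]:
    """Reward the fraction of reference JSON fields reproduced exactly."""
    scores = []
    for completion, answer in zip(completions, answers):
        try:
            got, ref = json.loads(completion), json.loads(answer)
            matched = sum(1 for k, v in ref.items() if got.get(k) == v)
            scores.append(matched / max(len(ref), 1))
        except (json.JSONDecodeError, TypeError, AttributeError):
            scores.append(0.0)
    return scores
```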
That is exactly what I have done: new reward functions (added in grpo_reward_functions.py and imported into grpo_trainer.py) to parse the output and validate the format and content of the outputted JSON. OK, I hadn't understood that the consequence of --use-prompt is that the model doesn't generate an EOS. When running the same code on Unsloth GRPO (same reward functions, same dataset), the model stops right after the JSON generation; I don't understand why yet. I think I will create a dataset with increasing prompt lengths to get a better feeling for the limit where it starts to be really slow.
Seems to be related to the model used. This works perfectly:
This doesn't work:
Any idea @Goekdeniz-Guelmez?
I think it's what @Goekdeniz-Guelmez already said before: the instruct model will recognize the pattern it has been fine-tuned on, output the EOS token and stop generating, which speeds up iterations a lot compared to a model that just keeps going to the full completion length. If you want to optimize training for R1-Distill you need to check DeepSeek's recommendations on usage in their model card https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B. I believe Qwen wants ChatML prompts, so that would typically look like this:
and if you use the correct format, the model should output the EOS token...
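For reference, a small sketch of how a ChatML-style prompt is normally built with the tokenizer's chat template (the model name and system prompt below are just placeholders; the exact template depends on the model you load):

```python
# Sketch of producing a ChatML-style prompt via the standard Hugging Face
# chat-template API.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
messages = [
    {"role": "system", "content": "Answer inside <answer>...</answer> tags."},
    {"role": "user", "content": "What is 2 + 2?"},
]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)
# Roughly:
# <|im_start|>system
# Answer inside <answer>...</answer> tags.<|im_end|>
# <|im_start|>user
# What is 2 + 2?<|im_end|>
# <|im_start|>assistant
```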
Yes, it also depends on the model. What you should do is what DeepSeek did for R1: they first cold-started the model, i.e. SFT fine-tuned it on a small amount of that dataset with the custom format, and THEN used GRPO. You CANNOT just train a model with RL without pointing it in the right direction first. The DeepSeek distill is not trained to output JSON formats, and system prompting won't work either, whereas Qwen is a generalist model, so a system prompt would guide the model to output the correct stuff. So I tried it with the datasets and it works for me.
Using the same dataset with Qwen, we see that it generates the correct format with the correct tokens: