
Adding grpo training #1233

Open · wants to merge 77 commits into main

Conversation

Goekdeniz-Guelmez
Contributor

No description provided.

@mark-lord

mark-lord commented Feb 2, 2025

Absolute HERO! Been trying to figure this out myself the past week but made pretty much no progress whatsoever, other than to make a script that fills up all the RAM on my Mac 🤣

Is there any way to run this yet? I assume not, since at the moment it's still marked as a draft and there isn't a lora_config.yaml like in the DPO example yet (not sure if one is needed)?

@Goekdeniz-Guelmez
Contributor Author

No, not yet. I still have to implement the dataset wrapper and some other things; I'll let you know when it's done.


@Guo-astro left a comment

We may need to use expanded_prompts and expanded_answers in both the reward and the loss computation.
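
Roughly what I mean (the names and signature here are illustrative, not the PR's actual code):

# Hypothetical helper: in GRPO each prompt is sampled group_size times, so the
# prompts/answers need to be repeated to line up index-by-index with the
# generated completions before computing the rewards and the loss.
def expand_for_group(prompts, answers, group_size):
    expanded_prompts = [p for p in prompts for _ in range(group_size)]
    expanded_answers = [a for a in answers for _ in range(group_size)]
    return expanded_prompts, expanded_answers

# e.g. rewards = reward_fn(prompts=expanded_prompts, completions=completions,
#                          answer=expanded_answers)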

@Goekdeniz-Guelmez
Contributor Author

python -m mlx_lm.lora \
    --model Qwen/Qwen2.5-0.5B \
    --train \
    --data /Users/gokdenizgulmez/Desktop/test_grpo \
    --iters 5 \
    --batch-size 1 \
    --num-layers 4 \
    --val-batches 1 \
    --steps-per-report 1 \
    --adapter-path /Users/gokdenizgulmez/Desktop/test-grpo-full \
    --max-seq-length 128 \
    --grad-checkpoint \
    --training-mode grpo \
    --fine-tune-type lora \
    --beta 0.1 \
    --steps-per-eval 500 \
    --group-size 2

Output

Loading pretrained model
Fetching 7 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 124936.71it/s]
Loading datasets
Training
Trainable parameters: 0.109% (0.541M/494.033M)
Starting GRPO training with 5 reward functions..., iters: 5
[WARNING] Some prompts are longer than 128 tokens. Long prompts will be truncated.
Iter 1: Val loss 0.00000140, Val total_rewards_mean -0.359, Val total_rewards_std 0.010, Val grouped_rewards_mean -0.359, Val grouped_rewards_std 0.010, Val kl 0.000, Val reward_func_0_mean 0.000, Val reward_func_0_std 0.000, Val reward_func_1_mean 0.000, Val reward_func_1_std 0.000, Val reward_func_2_mean 0.000, Val reward_func_2_std 0.000, Val reward_func_3_mean 0.000, Val reward_func_3_std 0.000, Val reward_func_4_mean -1.794, Val reward_func_4_std 0.051, Val took 8.385s

But after that, my 32 GB of RAM gets fully used. I tried adding some memory optimisations, but the memory usage is still too high.
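
For anyone curious, one memory optimisation to try is clearing MLX's buffer cache between steps (just a sketch, not necessarily what ends up in the PR; clear_cache / set_cache_limit availability depends on your MLX version):

import mlx.core as mx

# After each generation/update step: materialize the lazy graph and drop MLX's
# buffer cache so completions from previous groups don't accumulate.
def free_memory(model):
    mx.eval(model.parameters())
    mx.metal.clear_cache()

# Optionally cap the buffer cache up front, e.g. to 2 GB:
# mx.metal.set_cache_limit(2 * 1024 * 1024 * 1024)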

@Goekdeniz-Guelmez
Contributor Author

Iter 1: Val loss -0.00000057, Val total_rewards_mean -0.387, Val total_rewards_std 0.026, Val grouped_rewards_mean -0.387, Val grouped_rewards_std 0.026, Val kl 0.000, Val r1_accuracy_reward_func_mean 0.000, Val r1_accuracy_reward_func_std 0.000, Val r1_int_reward_func_mean 0.000, Val r1_int_reward_func_std 0.000, Val r1_strict_format_reward_func_mean 0.000, Val r1_strict_format_reward_func_std 0.000, Val r1_soft_format_reward_func_mean 0.000, Val r1_soft_format_reward_func_std 0.000, Val r1_count_xml_mean -1.937, Val r1_count_xml_std 0.128, Val took 8.314s

Still uses too much memory.

@Goekdeniz-Guelmez
Contributor Author

Goekdeniz-Guelmez commented Feb 3, 2025

So I tried using TRL and it used the same amount of RAM, so it's not an error on my side.

@mark-lord

🚀

Would you be able to share the datasets you used for the training? Will give it a go on my machine as soon as I can 🙌

@Goekdeniz-Guelmez
Contributor Author

Will do that tomorrow 🤝

@Guo-astro

🚀

Would you be able to share the datasets you used for the training? Will give it a go on my machine as soon as I can 🙌

I created a quick one only for testing the code

https://huggingface.co/datasets/Goastro/mlx-grpo-dataset
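
If you want to build a similar one, a prompt/answer-style JSONL along these lines should work (the field names and rows here are illustrative; check the dataset card for the exact schema):

import json

# Hypothetical rows -- not copied from the dataset above.
rows = [
    {"prompt": "What is 7 * 8?", "answer": "56"},
    {"prompt": "A garden has 25 roses and 75 other flowers. What percentage are not roses?", "answer": "75"},
]

with open("train.jsonl", "w") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")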

@Goekdeniz-Guelmez
Contributor Author

python -m mlx_lm.lora \
    --model Qwen/Qwen2.5-0.5B \
    --train \
    --data /Users/gokdenizgulmez/Desktop/test_grpo \
    --iters 5 \
    --batch-size 1 \
    --num-layers 8 \
    --val-batches 1 \
    --steps-per-report 1 \
    --adapter-path /Users/gokdenizgulmez/Desktop/test-grpo-full \
    --max-seq-length 255 \
    --grad-checkpoint \
    --training-mode grpo \
    --fine-tune-type lora \
    --beta 0.1 \
    --steps-per-eval 500 \
    --group-size 2 \
    --max-completion-length 6

Output:

Loading pretrained model
Fetching 7 files: 100%|███████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 72853.92it/s]
Loading datasets
Training
Trainable parameters: 0.109% (0.541M/494.033M)
Fetching 7 files: 100%|███████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10955.27it/s]
Starting GRPO training with 5 reward functions..., iters: 5
Iter 1: Val loss 0.00000000, Val total_rewards_mean -0.354, Val total_rewards_std 0.012, Val grouped_rewards_mean -0.354, Val grouped_rewards_std 0.012, Val kl 0.000, Val r1_accuracy_reward_func_mean 0.000, Val r1_accuracy_reward_func_std 0.000, Val r1_int_reward_func_mean 0.000, Val r1_int_reward_func_std 0.000, Val r1_strict_format_reward_func_mean 0.000, Val r1_strict_format_reward_func_std 0.000, Val r1_soft_format_reward_func_mean 0.000, Val r1_soft_format_reward_func_std 0.000, Val r1_count_xml_mean -1.769, Val r1_count_xml_std 0.060, Val took 26.298s
Iter 1: Train loss -0.00001353, Total rewards mean -0.306, Total rewards std 0.001, Grouped rewards mean -0.306, Grouped rewards std 0.001, KL 0.000, r1_accuracy_reward_func mean 0.000, r1_accuracy_reward_func std 0.000, r1_int_reward_func mean 0.000, r1_int_reward_func std 0.000, r1_strict_format_reward_func mean 0.000, r1_strict_format_reward_func std 0.000, r1_soft_format_reward_func mean 0.000, r1_soft_format_reward_func std 0.000, r1_count_xml mean -1.532, r1_count_xml std 0.005, Learning Rate 1.000e-05, It/sec 0.079, Tokens/sec 25.072, Peak mem 7.254 GB
Iter 2: Train loss 0.00055540, Total rewards mean -0.572, Total rewards std 0.001, Grouped rewards mean -0.572, Grouped rewards std 0.001, KL 0.006, r1_accuracy_reward_func mean 0.000, r1_accuracy_reward_func std 0.000, r1_int_reward_func mean 0.000, r1_int_reward_func std 0.000, r1_strict_format_reward_func mean 0.000, r1_strict_format_reward_func std 0.000, r1_soft_format_reward_func mean 0.000, r1_soft_format_reward_func std 0.000, r1_count_xml mean -2.861, r1_count_xml std 0.005, Learning Rate 1.000e-05, It/sec 0.121, Tokens/sec 36.164, Peak mem 7.254 GB
Iter 3: Train loss 0.00070858, Total rewards mean -0.842, Total rewards std 0.003, Grouped rewards mean -0.842, Grouped rewards std 0.003, KL 0.013, r1_accuracy_reward_func mean 0.000, r1_accuracy_reward_func std 0.000, r1_int_reward_func mean 0.000, r1_int_reward_func std 0.000, r1_strict_format_reward_func mean 0.000, r1_strict_format_reward_func std 0.000, r1_soft_format_reward_func mean 0.000, r1_soft_format_reward_func std 0.000, r1_count_xml mean -4.210, r1_count_xml std 0.013, Learning Rate 1.000e-05, It/sec 0.110, Tokens/sec 31.790, Peak mem 7.254 GB
Iter 4: Train loss 0.00070563, Total rewards mean -1.161, Total rewards std 0.005, Grouped rewards mean -1.161, Grouped rewards std 0.005, KL 0.020, r1_accuracy_reward_func mean 0.000, r1_accuracy_reward_func std 0.000, r1_int_reward_func mean 0.000, r1_int_reward_func std 0.000, r1_strict_format_reward_func mean 0.000, r1_strict_format_reward_func std 0.000, r1_soft_format_reward_func mean 0.000, r1_soft_format_reward_func std 0.000, r1_count_xml mean -5.806, r1_count_xml std 0.024, Learning Rate 1.000e-05, It/sec 0.105, Tokens/sec 36.961, Peak mem 7.899 GB
Iter 5: Val loss 0.00057772, Val total_rewards_mean -0.345, Val total_rewards_std 0.005, Val grouped_rewards_mean -0.345, Val grouped_rewards_std 0.005, Val kl 0.006, Val r1_accuracy_reward_func_mean 0.000, Val r1_accuracy_reward_func_std 0.000, Val r1_int_reward_func_mean 0.000, Val r1_int_reward_func_std 0.000, Val r1_strict_format_reward_func_mean 0.000, Val r1_strict_format_reward_func_std 0.000, Val r1_soft_format_reward_func_mean 0.000, Val r1_soft_format_reward_func_std 0.000, Val r1_count_xml_mean -1.726, Val r1_count_xml_std 0.025, Val took 22.624s
Iter 5: Train loss 0.00059050, Total rewards mean -1.399, Total rewards std 0.006, Grouped rewards mean -1.399, Grouped rewards std 0.006, KL 0.026, r1_accuracy_reward_func mean 0.000, r1_accuracy_reward_func std 0.000, r1_int_reward_func mean 0.000, r1_int_reward_func std 0.000, r1_strict_format_reward_func mean 0.000, r1_strict_format_reward_func std 0.000, r1_soft_format_reward_func mean 0.000, r1_soft_format_reward_func std 0.000, r1_count_xml mean -6.994, r1_count_xml std 0.029, Learning Rate 1.000e-05, It/sec 0.156, Tokens/sec 39.539, Peak mem 7.899 GB
Saved final weights to /Users/gokdenizgulmez/Desktop/test-grpo-full/adapters.safetensors.

@mark-lord

mark-lord commented Feb 4, 2025

🥳🥳🥳

Working on my machine too! Not to mention it's plug-and-play with QLoRA as well, which I don't think TRL even has 😁 And I've already used it to get an 'aha' moment out of Phi-14b and do some knowledge injection 🚀 [Edit: I did not get it to work properly - see later in the conversation]

@mark-lord

This really motivates me to know that my efforts are appreciated and that there's a clear desire within the community for these enhancements

Just wanted to pop up and express my support again 😁 The efforts are very much appreciated!!! (Been dealing with some personal issues lately so I haven't been nearly as active in the community as I'd like, but I've been keeping an eye on this repo every day regardless, pahahahaha.) Thanks for the awesome work @Goekdeniz-Guelmez 😁

@Goekdeniz-Guelmez
Contributor Author

Goekdeniz-Guelmez commented Feb 26, 2025

You only ran 250 iterations with batch size 1, which is likely insufficient for meaningful changes in model behavior, especially for a 3B parameter model. Can you also show me the logs from the training? If the rewards go up, that means the model is learning. Is the adapter path correct? Also, is the system prompt you used correct? The default system prompt is: "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>." Or try training again but with the pretrained version; if you do so, use --use-prompt instead of --use-chat-template. I'll try it out too with your settings when I'm home.
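
For illustration, extracting the answer out of that format boils down to something like this (a minimal sketch, not necessarily the exact helper in the PR):

import re

# Pull the text between <answer> tags out of a completion that follows the
# <think> ... </think><answer> ... </answer> format described above.
ANSWER_RE = re.compile(r"<answer>\s*(.*?)\s*</answer>", re.DOTALL)

def extract_answer(completion):
    match = ANSWER_RE.search(completion)
    return match.group(1) if match else None

# extract_answer("<think>2 + 2 = 4</think><answer> 4 </answer>")  ->  "4"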

@wangcheng0825

You only ran 250 iterations with batch size 1, which is likely insufficient for meaningful changes in model behavior, especially for a 3B parameter model. Can you also show me the logs from the training? If the rewards go up, that means the model is learning. Is the adapter path correct? Also, is the system prompt you used correct? The default system prompt is: "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>." Or try training again but with the pretrained version; if you do so, use --use-prompt instead of --use-chat-template. I'll try it out too with your settings when I'm home.

Thanks @Goekdeniz-Guelmez, it was a problem with my system prompt. I tried the default system prompt you posted above and the answer looks correct, so I've deleted my question. Thank you very much for your reply.

==========
<think> John feeds each horse 20 pounds of food twice a day, so each horse consumes 20 * 2 = 40 pounds of food per day. With 25 horses, the total daily food consumption is 25 * 40 = 1000 pounds. Over 60 days, the total food consumption is 1000 * 60 = 60000 pounds. Since John buys half-ton bags of food, each bag contains 1000 pounds of food. Therefore, the number of bags needed is 60000 / 1000 = 60 bags. </think>
<answer> 60 </answer>
==========
Prompt: 146 tokens, 990.388 tokens-per-sec
Generation: 153 tokens, 49.867 tokens-per-sec
Peak memory: 12.345 GB

@SfcFromSx

(screenshot of the trainer code, 2025-02-26)
Maybe this place needs an indentation?

@Goekdeniz-Guelmez
Contributor Author

Thanks!! It should NOT be indented because it should execute regardless of whether the weights were provided or defaulted.
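
Illustratively (this is not the actual code from the screenshot, just the control-flow point):

def load_reference_weights(path):
    # stand-in for the real loading call
    print(f"loading {path}")

def setup(weights_path=None):
    if weights_path is None:
        weights_path = "adapters.safetensors"  # hypothetical default
    # Intentionally NOT indented under the `if`: it must run whether the path
    # was provided explicitly or fell back to the default.
    load_reference_weights(weights_path)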

@deathcoder

@Goekdeniz-Guelmez I am testing on the latest commit (first of all, again, amazing improvements). I was running training for the mlx-community/Qwen2.5-7B-Instruct-8bit model and after 80 steps, just after it saved the adapters, it got stuck: GPU usage dropped to 0 and I couldn't Ctrl-C out of it. In the end I had to kill the process, and I got this message after I did that:

Iter 80: Train loss -0.002, Total rewards mean 6.854, Total rewards std 0.954, Grouped rewards mean 6.854, Grouped rewards std 0.954, KL 0.217, r1_chess_reward_func mean 6.854, r1_chess_reward_func std 0.954, r1_strict_format_reward_func mean 0.000, r1_strict_format_reward_func std 0.000, r1_soft_format_reward_func mean 0.000, r1_soft_format_reward_func std 0.000, Learning Rate 1.000e-05, It/sec 0.016, Tokens/sec 14.979, Peak mem 31.438 GB
Iter 80: Saved adapter weights to adapters/chess_small/adapters.safetensors and adapters/chess_small/0000080_adapters.safetensors.

/Users/admin/devtools/miniconda3/envs/mlx-grpo/lib/python3.12/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

I'm saving adapters every 10 steps, so this wasn't the first time they were being saved... Not really sure what else I can add about this; unfortunately it didn't print a stacktrace.

@Goekdeniz-Guelmez
Contributor Author

@deathcoder Probably has something to do with the memory handling and clearing. I'll look into it when I'm home.

@Vi-cs

Vi-cs commented Feb 27, 2025

Hello,

Thanks a lot for all the work, it's a pleasure to be able to play with GRPO locally!!

After a few tests it seems to work perfectly fine for very short prompts, but I struggle with prompts of ~1000 tokens, even with small models like mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-8bit.

My args (the dataset prompts are ~1000 tokens):

python -m mlx_lm.lora \
    --model mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-8bit \
    --train \
    --data /dataset/ \
    --iters 100 \
    --batch-size 1 \
    --num-layers -1 \
    --val-batches 1 \
    --steps-per-report 1 \
    --adapter-path /Users/test4/ \
    --max-seq-length 2000 \
    --max-completion-length 1000 \
    --training-mode grpo \
    --fine-tune-type lora \
    --beta 0.1 \
    --steps-per-eval 500 \
    --group-size 2 \
    --use-prompt

With the model Qwen/Qwen2.5-0.5B I get ~5 tokens/sec, so with a group size of 2 generating 500 tokens each = 1000 tokens generated => ~200 sec per iteration.
With mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-8bit, iteration 1 is still not completed after 45 minutes.

I have an M4 Max with 128 GB: my GPU usage peaks at 10% sometimes, but stays very low. The memory used is close to 40 GB with the 1.5B model.

@Goekdeniz-Guelmez, do you have any idea how to improve the performance?

@deathcoder

@Vi-cs have you tried reducing num-layers? You are tuning all layers with -1. Also make sure you are on the latest commit. I never tried with -1, but in my tests with 8 layers I get much higher speeds than that:

  • on 32B-8bit I get 7-8 tok/s
  • on 7B-8bit it's closer to 20 tok/s

@Vi-cs

Vi-cs commented Feb 27, 2025

My commit was a few days old, but I pulled the latest commit just to be sure, and I changed to --num-layers 4.
Still not able to complete the first iteration :/
The GPU usage remains very low.

Edit:
I relaunched the same command and got the first iteration in a few minutes, then nothing for 20 min.

Iter 1: Val loss 0.000, Val total_rewards_mean 0.062, Val total_rewards_std 0.062, Val grouped_rewards_mean 0.062, Val grouped_rewards_std 0.062, Val kl 0.000, Val r1_accuracy_reward_func_mean 0.000, Val r1_accuracy_reward_func_std 0.000, Val r1_int_reward_func_mean 0.000, Val r1_int_reward_func_std 0.000, Val r1_strict_format_reward_func_mean 0.000, Val r1_strict_format_reward_func_std 0.000, Val r1_soft_format_reward_func_mean 0.000, Val r1_soft_format_reward_func_std 0.000, Val r1_count_xml_mean 0.062, Val r1_count_xml_std 0.062, Val took 7.940s
Iter 1: Train loss 0.000, Total rewards mean 0.125, Total rewards std 0.125, Grouped rewards mean 0.125, Grouped rewards std 0.125, KL 0.000, r1_accuracy_reward_func mean 0.000, r1_accuracy_reward_func std 0.000, r1_int_reward_func mean 0.000, r1_int_reward_func std 0.000, r1_strict_format_reward_func mean 0.000, r1_strict_format_reward_func std 0.000, r1_soft_format_reward_func mean 0.000, r1_soft_format_reward_func std 0.000, r1_count_xml mean 0.125, r1_count_xml std 0.125, Learning Rate 1.000e-05, It/sec 0.023, Tokens/sec 11.888, Peak mem 40.836 GB
python -m mlx_lm.lora \
    --model mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-8bit \
    --train \
    --data vi-c/test \
    --iters 100 \
    --batch-size 1 \
    --num-layers 4 \
    --val-batches 1 \
    --steps-per-report 1 \
    --adapter-path /Users/viviencuisinier/Github/mlx-examples/test4/ \
    --max-seq-length 2000 \
    --max-completion-length 1000 \
    --training-mode grpo \
    --fine-tune-type lora \
    --beta 0.1 \
    --steps-per-eval 500 \
    --group-size 2 \
    --use-prompt

@deathcoder

deathcoder commented Feb 27, 2025

Not sure what it is on your side that is slowing you down. I just ran the exact same command you sent; the only difference is the dataset:

Iter 1: Val loss 0.000, Val total_rewards_mean 0.050, Val total_rewards_std 0.050, Val grouped_rewards_mean 0.050, Val grouped_rewards_std 0.050, Val kl 0.000, Val r1_chess_reward_func_mean 0.050, Val r1_chess_reward_func_std 0.050, Val r1_strict_format_reward_func_mean 0.000, Val r1_strict_format_reward_func_std 0.000, Val r1_soft_format_reward_func_mean 0.000, Val r1_soft_format_reward_func_std 0.000, Val took 7.870s
Iter 1: Train loss 0.000, Total rewards mean 0.000, Total rewards std 0.000, Grouped rewards mean 0.000, Grouped rewards std 0.000, KL 0.000, r1_chess_reward_func mean 0.000, r1_chess_reward_func std 0.000, r1_strict_format_reward_func mean 0.000, r1_strict_format_reward_func std 0.000, r1_soft_format_reward_func mean 0.000, r1_soft_format_reward_func std 0.000, Learning Rate 1.000e-05, It/sec 0.244, Tokens/sec 56.339, Peak mem 5.457 GB
Iter 2: Train loss 0.000, Total rewards mean 0.000, Total rewards std 0.000, Grouped rewards mean 0.000, Grouped rewards std 0.000, KL 0.000, r1_chess_reward_func mean 0.000, r1_chess_reward_func std 0.000, r1_strict_format_reward_func mean 0.000, r1_strict_format_reward_func std 0.000, r1_soft_format_reward_func mean 0.000, r1_soft_format_reward_func std 0.000, Learning Rate 1.000e-05, It/sec 0.251, Tokens/sec 55.909, Peak mem 5.457 GB

and training is still going; I'm now on iteration 14 and I only launched it just before I started writing this message

edit: just to confirm you are actually on the latest commit, do you have the validation sample details in your logs?
For me it looks like this at the very start of the process:

Starting GRPO training with 3 reward functions..., iters: 100

=== Validation Sample Details ===

📝 Generation:
 ```<answer>2</answer>``.<|im_end|>
</think>

The analyzing program started by considering UCI notation moves.

Moving from Player 1's captures, it moves Option 2 from Player 2 including the Queen back into check, to return to Player 1's line.

The成熟 male pawn component doubles up, gave Player 2 a pawn supported by a undefended bishop.

But the other bishop when under a sometimes prompted positional control.

But ultimately, Player 2 counter-pumpes and controls the game with deep development of promote capturing a way for Player 1 to accept a successful Minimal One Move```

<answer>2</answer>

==========


✅ Answer:
['c8b7']

==========


🔍 Extracted Answer:
2

==============================

@wxjiao

wxjiao commented Feb 28, 2025

Args:

python -m mlx_lm.lora \
    --model Qwen/Qwen2.5-3B-Instruct \
    --train \
    --data /Users/gokdenizgulmez/Library/Mobile\ Documents/com\~apple\~CloudDocs/Datastes/MLX/test_grpo \
    --iters 100 \
    --batch-size 1 \
    --num-layers 8 \
    --val-batches 1 \
    --steps-per-report 1 \
    --adapter-path /Users/gokdenizgulmez/Library/Mobile\ Documents/com\~apple\~CloudDocs/Datastes/MLX/test-grpo-full \
    --max-seq-length 1024 \
    --grad-checkpoint \
    --training-mode grpo \
    --fine-tune-type lora \
    --beta 0.1 \
    --steps-per-eval 50 \
    --test \
    --test-batches 1 \
    --group-size 2 \
    --max-completion-length 512 \
    --use-chat-template

@Goekdeniz-Guelmez Thanks for the nice work! I wonder if the current code supports GRPO training with --batch-size > 1?

@Vi-cs

Vi-cs commented Feb 28, 2025

Thanks @deathcoder.

The dataset is this one: https://huggingface.co/datasets/vi-c/test/

python -m mlx_lm.lora \
    --model mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-8bit \
    --train \
    --data vi-c/test \
    --iters 100 \
    --batch-size 1 \
    --num-layers 4 \
    --val-batches 1 \
    --steps-per-report 1 \
    --adapter-path /Users/test4/ \
    --max-seq-length 2000 \
    --max-completion-length 1000 \
    --training-mode grpo \
    --fine-tune-type lora \
    --beta 0.1 \
    --steps-per-eval 500 \
    --group-size 2 \
    --use-prompt

Also, I think I am on the latest commit based on the log:

Starting GRPO training with 5 reward functions..., iters: 100

=== Validation Sample Details ===

📝 Generation:
获得感:.eml中的copied to CV中的相关信息,具体包括Alexandra和Bill的收入情况,专业技能和经验和公司经历等。

符合这个JSON schema的contact对象信息如下:

{
"domicile": "CH",
"birth_country": "CH",
"short_description": "Alexandra",
"email": "[email protected]",
"phone": "555-123-4567",
"birth_date": "1985-03-20",
"zip_code": "80512",
"occupation": "工商而乘",
"职业学位": null,
"长期兼职职位": null,
"短期兼职职位": "行政 Assistant",
"first_name": "Bill",
"last_name": "",
"nationalities": ["CH", "US", "IT"],
"type": "person",
"asset": {
"cash": { "value": "CHF 6.55M", "currency": "CHF", "description": "CHF cash portfolio" },
"division投资": { "value": "CHF 2.35M", "currency": "CHF", "description": "投资组合成果" },
"real estate": { "value": "CHF 2.5M", "currency": "CHF", "description": "房地产购买" },
"其他": { "value": "CHF 2.2M", "currency": "CHF", "description": "其他资产" }
},
"source_of_wealth": "通过分配收益,依据题目中的公司系统开发`

{
  "domicile": "CH",
  "birth_country": "CH",
  "short_description": "Alexandra",
  "email": "[email protected]",
  "phone": "555-123-4567",
  "birth_date": "1985-03-20",
  "zip_code": "80512",
  "occupation": "工商而乘",
  "职业学位": null,
  "长期兼职职位": null,
  "短期兼职职位": "行政 Assistant",
  "first_name": "Bill",
  "last_name": "",
  "nationalities": ["CH", "US", "IT"],
  "type": "person",
  "asset": {
    "cash": { "value": "CHF 6.55M", "currency": "CHF", "description": "CHF cash portfolio" },
    "division投资": { "value": "CHF 2.35M", "currency": "CHF", "description": "投资组合成果" },
    "real estate": { "value": "CHF 2.5M", "currency": "CHF", "description": "房地产购买" },
    "其他": { "value": "CHF 2.2M", "currency": "CHF", "description": "其他资产" }
  },
  "source_of_wealth": "通过分配收益,依据题目中的公司系统开发"
}

==========

✅ Answer:
{
"contacts": [
{
"id": null,
"domicile": "CHE",

@Goekdeniz-Guelmez
Contributor Author

@Vi-cs with --use-prompt you're not applying the chat template, so try using --use-chat-template instead; you can take a look at the LORA.md documentation file. @wxjiao thanks, and yes, this should also work with batching.

@Vi-cs

Vi-cs commented Feb 28, 2025

Hi @Goekdeniz-Guelmez, I am using an R1 distill model, which is not an instruct model. It doesn't behave correctly with --use-chat-template. I set --use-prompt on purpose and it works fine (I tested this with Unsloth GRPO).

@Goekdeniz-Guelmez
Contributor Author

Goekdeniz-Guelmez commented Feb 28, 2025

@Vi-cs You're using the wrong dataset! Your dataset doesn't match the default reward functions. The reward functions are looking for specific XML tags and formatted answers, but your dataset contains JSON with Chinese text instead. If you really need to, you have to create new reward functions and a prompt that work with the JSON data, and train your model via code. The dataset is not suited for GRPO training, since GRPO needs structured data with clear evaluation criteria to optimize against, which your mixed-language JSON data doesn't provide. Look into the LORA.md documentation to understand the dataset format that should be used; your dataset is more suited for basic SFT training than for GRPO.

Bad dataset:

(screenshot)

Good dataset: Goastro/mlx-grpo-dataset

(screenshot)

@Vi-cs

Vi-cs commented Feb 28, 2025

Totally agree!

The reward functions provided only reward the model if the completion contains a strict XML structure, an int inside the answer tag, and the correct int.
These functions did not match my dataset, so I wrote new ones.

The training doesn't work with my reward functions.

Just to make sure the issue is not related to my custom reward functions, I used the ones provided (which don't make sense with my dataset, but should not break the training).

Still, the training doesn't work.

@Goekdeniz-Guelmez do you see a reason why the training would not manage to compute iterations with my dataset?

@deathcoder any chance you could try running my args (with my dataset), to see if on your side you can process at least a few iterations?

Thanks!

@Goekdeniz-Guelmez
Contributor Author

The main issue with long prompts is that transformer models process text with roughly quadratic complexity, meaning a 10x longer prompt can take ~100x longer to process. This is amplified in GRPO because we generate multiple completions per prompt (group_size) and compare them all against each other. Each generated token depends on processing all previous tokens, creating a strictly sequential process, especially with JSON data that requires complete structural accuracy. Also, how did you train the model using your reward functions, and what do your reward functions look like?

My idea for reward functions would be a json_structure_reward that rewards proper JSON formatting and structure, and a financial_data_reward that rewards accuracy of the financial data compared to the reference; the chat template needs to change too, to kick-start the JSON output format, and you also have to write a function to extract the model's answer correctly. Also, since you're using --use-prompt, the model is literally generating up to the max tokens because this (as I said before) does not apply the chat template, so the model doesn't generate the EOS token. Basically it just keeps going until it hits the token limit instead of stopping naturally. That would definitely explain why it's taking so long: it's generating way more tokens than needed for each completion AND accumulating them. With your custom reward functions you can't just call the training in the terminal; you have to do it from a Python file, something like:

def train_grpo(
    model: nn.Module,
    ref_model: Optional[nn.Module],
    tokenizer,
    optimizer,
    train_dataset,
    val_dataset,
    reward_funcs: Optional[List[RewardFunctions]] = [
        json_structure_reward,
        financial_data_reward
    ],
    args: GRPOTrainingArgs = GRPOTrainingArgs(),
    loss_fn: callable = grpo_loss,
    iterate_batches: callable = iterate_grpo_batches,
    training_callback: TrainingCallback = None,
):
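
For example, rough sketches of what those two reward functions could look like (only the names come from above; the bodies and the signature here are assumptions, just one way to write them):

import json

def json_structure_reward(prompts, completions, answer, **kwargs):
    # 1.0 if the completion parses as JSON, else 0.0.
    scores = []
    for completion in completions:
        try:
            json.loads(completion)
            scores.append(1.0)
        except (json.JSONDecodeError, TypeError):
            scores.append(0.0)
    return scores

def financial_data_reward(prompts, completions, answer, **kwargs):
    # Fraction of reference key/value pairs reproduced exactly in the completion.
    scores = []
    for completion, ref_text in zip(completions, answer):
        try:
            got = json.loads(completion)
            ref = json.loads(ref_text)
        except (json.JSONDecodeError, TypeError):
            scores.append(0.0)
            continue
        if not isinstance(got, dict) or not isinstance(ref, dict) or not ref:
            scores.append(0.0)
            continue
        hits = sum(1 for k, v in ref.items() if got.get(k) == v)
        scores.append(hits / len(ref))
    return scores

These would then be passed in reward_funcs as shown in the signature above.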

@Vi-cs

Vi-cs commented Feb 28, 2025

That is exactly what I have done: new reward functions (added in grpo_reward_functions.py and imported into grpo_trainer.py) to parse the output and validate the format and content of the generated JSON.

OK, I didn't understand that the consequence of --use-prompt is that the model doesn't generate an EOS. When running the same code with Unsloth GRPO (same reward functions, same dataset), the model stops right after the JSON generation. I don't understand why yet.

I think I will create a dataset with increasing prompt lengths to get a better feeling for the limit where it starts to become really slow.

@Vi-cs

Vi-cs commented Feb 28, 2025

Seems to be related to the model used.

This works perfectly:

python -m mlx_lm.lora \
    --model mlx-community/Qwen2.5-1.5B-Instruct-bf16 \
    --train \
    --data Goastro/mlx-grpo-dataset \
    --iters 100 \
    --batch-size 1 \
    --num-layers 4 \
    --val-batches 1 \
    --steps-per-report 1 \
    --adapter-path /test4/ \
    --max-seq-length 2000 \
    --max-completion-length 1000 \
    --training-mode grpo \
    --fine-tune-type lora \
    --beta 0.1 \
    --steps-per-eval 500 \
    --group-size 2 \
    --use-chat-template

This doesn't work:

python -m mlx_lm.lora \
    --model mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-8bit \
    --train \
    --data Goastro/mlx-grpo-dataset \
    --iters 100 \
    --batch-size 1 \
    --num-layers 4 \
    --val-batches 1 \
    --steps-per-report 1 \
    --adapter-path /test4/ \
    --max-seq-length 2000 \
    --max-completion-length 1000 \
    --training-mode grpo \
    --fine-tune-type lora \
    --beta 0.1 \
    --steps-per-eval 500 \
    --group-size 2 \
    --use-chat-template

Any idea @Goekdeniz-Guelmez?

@deathcoder

I think it's what @Goekdeniz-Guelmez already said before: the instruct model will recognize the pattern it has been finetuned on, output the EOS token and stop generating, which speeds up iterations a lot compared to a model that will just keep going to the full completion length. If you want to optimize training for R1-Distill you need to check DeepSeek's recommendations on usage in their model card https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

I believe Qwen wants ChatML prompts, so that would typically look like this:

<|im_start|>user
Hi there!<|im_end|>
<|im_start|>assistant
Nice to meet you!<|im_end|>
<|im_start|>user
Can I ask a question?<|im_end|>

and if you use the correct format, the model should output the EOS token...
I think at the moment we don't log the prompt in the validation step; maybe that could be a small improvement to make it more visible whether we are using the correct format.
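
A quick way to check what actually gets fed to the model is to print the template output (a sketch, assuming an HF-style tokenizer with a chat template, which the Qwen-based models here ship with):

from transformers import AutoTokenizer

# Print the fully templated prompt so the ChatML markers / EOS handling are visible.
tokenizer = AutoTokenizer.from_pretrained("mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-8bit")
messages = [{"role": "user", "content": "Can I ask a question?"}]
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))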

@Goekdeniz-Guelmez
Contributor Author

Goekdeniz-Guelmez commented Feb 28, 2025

Yes, it also depends on the model. What you should do is what DeepSeek did for R1: they first cold-started the model, i.e. first SFT-finetuned on a small amount of that dataset with the custom format, and THEN used GRPO. You CANNOT just train a model with RL without pointing it in the right direction first. The DeepSeek distill is not trained to output JSON formats, and system prompting won't work either, whereas Qwen is a generalist model, so a system prompt will guide it to output the correct thing. So I tried it with the dataset and it works for me.

=== Validation Sample Details ===

📝 Generation:
Okay, so I've got this problem here: there are 25 roses, 40 tulips, and 35 daisies in a garden. I need to find out what percentage of the flowers are not roses. Hmm, let's see. First, I think I should figure out the total number of flowers in the garden. That makes sense because if I know how many flowers there are in total and how many are roses, I can subtract the roses to find out how many are not roses.

Alright, so I have 25 roses, 40 tulips, and 35 daisies. I should add those up to get the total number of flowers. Let's do the math: 25 plus 40 is 65, and then 65 plus 35 gives me 100. Okay, so there are 100 flowers in total. That simplifies things a bit because percentages often deal with a base of 100.

Now, since the total number of flowers is 100, the number of flowers that are not roses will be the total minus the roses. So that would be 100 minus 25. Let me subtract that: 100 minus 25 is 75. So, there are 75 flowers that are not roses.

But the question asks for the percentage, not the absolute number. So I need to convert 75 into a percentage of the total, which is 100. To find the percentage, I can set up a fraction where 75 is the numerator and 100 is the denominator. So that's 75 divided by 100.

Calculating that, 75 divided by 100 equals 0.75. Now, to convert this decimal to a percentage, I multiply by 100. So, 0.75 times 100 is 75%. Therefore, 75% of the flowers in the garden are not roses.

Wait, let me double-check to make sure I didn't make a mistake. Total flowers: 25 + 40 + 35 = 100. Not roses: 40 + 35 = 75. So, 75 out of 100 is indeed 75%. Yeah, that seems right. I think I've got it.
</think>

The total number of flowers is 25 roses + 40 tulips + 35 daisies = 100 flowers. The number of flowers that are not roses is 40 + 35 = 75. To find the percentage, we calculate (75/100) × 100% = 75%. 

<answer> 75% </answer>

==========


✅ Answer:
75

==========


🔍 Extracted Answer:
75%

===================================

Iter 1: Val loss 0.000, Val total_rewards_mean 0.188, Val total_rewards_std 0.188, Val grouped_rewards_mean 0.188, Val grouped_rewards_std 0.188, Val kl 0.000, Val r1_accuracy_reward_func_mean 0.000, Val r1_accuracy_reward_func_std 0.000, Val r1_int_reward_func_mean 0.000, Val r1_int_reward_func_std 0.000, Val r1_strict_format_reward_func_mean 0.000, Val r1_strict_format_reward_func_std 0.000, Val r1_soft_format_reward_func_mean 0.000, Val r1_soft_format_reward_func_std 0.000, Val r1_count_xml_mean 0.188, Val r1_count_xml_std 0.188, Val took 23.846s
Iter 1: Train loss 0.000, Total rewards mean 1.428, Total rewards std 1.428, Grouped rewards mean 1.428, Grouped rewards std 1.428, KL 0.000, r1_accuracy_reward_func mean 1.000, r1_accuracy_reward_func std 1.000, r1_int_reward_func mean 0.250, r1_int_reward_func std 0.250, r1_strict_format_reward_func mean 0.000, r1_strict_format_reward_func std 0.000, r1_soft_format_reward_func mean 0.000, r1_soft_format_reward_func std 0.000, r1_count_xml mean 0.178, r1_count_xml std 0.178, Learning Rate 1.000e-05, It/sec 0.021, Tokens/sec 17.021, Peak mem 6.777 GB
Iter 2: Train loss -0.000, Total rewards mean 1.606, Total rewards std 1.606, Grouped rewards mean 1.606, Grouped rewards std 1.606, KL 0.000, r1_accuracy_reward_func mean 1.000, r1_accuracy_reward_func std 1.000, r1_int_reward_func mean 0.250, r1_int_reward_func std 0.250, r1_strict_format_reward_func mean 0.000, r1_strict_format_reward_func std 0.000, r1_soft_format_reward_func mean 0.000, r1_soft_format_reward_func std 0.000, r1_count_xml mean 0.356, r1_count_xml std 0.356, Learning Rate 1.000e-05, It/sec 0.016, Tokens/sec 15.136, Peak mem 7.429 GB
Iter 3: Train loss 0.000, Total rewards mean 1.784, Total rewards std 1.784, Grouped rewards mean 1.784, Grouped rewards std 1.784, KL 0.000, r1_accuracy_reward_func mean 1.000, r1_accuracy_reward_func std 1.000, r1_int_reward_func mean 0.250, r1_int_reward_func std 0.250, r1_strict_format_reward_func mean 0.000, r1_strict_format_reward_func std 0.000, r1_soft_format_reward_func mean 0.000, r1_soft_format_reward_func std 0.000, r1_count_xml mean 0.534, r1_count_xml std 0.534, Learning Rate 1.000e-05, It/sec 0.017, Tokens/sec 15.119, Peak mem 7.553 GB

@Goekdeniz-Guelmez
Contributor Author

Using the same dataset with Qwen, we see that it generates the correct format with the correct tokens:

<think> We can find the total number of pencils that Arnel has by first adding the number of pencils the friends received to the number of pencils Arnel kept, then dividing the total by the number of people whose pencils we need to take into account (Arnel and his five friends). To avoid confusion in the process of division by 6, we will find (8 * 6) + 10 to get the total number of pencils which will also give us the number of pencils per box. </think> <answer>(8 * 6) + 10 </answer>

==========


✅ Answer:
5

==========


🔍 Extracted Answer:
(8 * 6) + 10

===================================
