Adding grpo training #1233

Open
Goekdeniz-Guelmez wants to merge 77 commits into main from adding-GRPO-training
Changes from all commits
Commits
77 commits
5e0ae83
initial commit, gn
Goekdeniz-Guelmez Jan 28, 2025
b1e573d
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Jan 29, 2025
93370ff
updates and fixing the KL div lines
Goekdeniz-Guelmez Jan 30, 2025
6c58aa9
updates
Goekdeniz-Guelmez Jan 31, 2025
80bcf68
grpo_trainer should be done
Goekdeniz-Guelmez Jan 31, 2025
a57d553
update
Goekdeniz-Guelmez Jan 31, 2025
243c962
update lora.py
Goekdeniz-Guelmez Jan 31, 2025
d034ca3
adding function for R1
Goekdeniz-Guelmez Feb 3, 2025
734d6f4
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 3, 2025
a3ed632
dataset wrapper done
Goekdeniz-Guelmez Feb 3, 2025
41ff536
Merge branch 'adding-GRPO-training' of https://github.com/Goekdeniz-G…
Goekdeniz-Guelmez Feb 3, 2025
23d75cd
starting first training test run
Goekdeniz-Guelmez Feb 3, 2025
1d9e480
first working prototype, will try training out at home
Goekdeniz-Guelmez Feb 3, 2025
05d921b
optims
Goekdeniz-Guelmez Feb 3, 2025
40bca77
fixes
Goekdeniz-Guelmez Feb 3, 2025
06f9c29
print func name
Goekdeniz-Guelmez Feb 3, 2025
54e295e
fix name funcs
Goekdeniz-Guelmez Feb 3, 2025
ca32424
updates
Goekdeniz-Guelmez Feb 3, 2025
7173840
first successful training run
Goekdeniz-Guelmez Feb 4, 2025
bd1a42e
adding args into dataset handling
Goekdeniz-Guelmez Feb 4, 2025
7b01414
better create_dataset
Goekdeniz-Guelmez Feb 4, 2025
0a09a93
fix cache handling
Goekdeniz-Guelmez Feb 5, 2025
2a8e6f6
update
Goekdeniz-Guelmez Feb 5, 2025
d84ad0c
fix testing
Goekdeniz-Guelmez Feb 5, 2025
a33cad8
updates
Goekdeniz-Guelmez Feb 5, 2025
35a2d99
small fix
Goekdeniz-Guelmez Feb 5, 2025
0a19522
updates
Goekdeniz-Guelmez Feb 5, 2025
bcfa55d
updates
Goekdeniz-Guelmez Feb 5, 2025
94dcd0f
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 6, 2025
9ba6146
fix
Goekdeniz-Guelmez Feb 9, 2025
39e9469
freeze ref model
Goekdeniz-Guelmez Feb 9, 2025
5417990
fix
Goekdeniz-Guelmez Feb 9, 2025
a527cdb
fix: prevent gradients from flowing through the reference model's logits
Goekdeniz-Guelmez Feb 9, 2025
0071252
rebase loss calculation
Goekdeniz-Guelmez Feb 9, 2025
0dac286
Merge branch 'main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 10, 2025
d9da35f
nits
Goekdeniz-Guelmez Feb 10, 2025
f88e897
removing helper functions
Goekdeniz-Guelmez Feb 10, 2025
e5aa2c3
nits
Goekdeniz-Guelmez Feb 10, 2025
b7bc811
nits
Goekdeniz-Guelmez Feb 10, 2025
88ca747
nits
Goekdeniz-Guelmez Feb 10, 2025
e96afe9
updates
Goekdeniz-Guelmez Feb 11, 2025
e80bf95
fix
Goekdeniz-Guelmez Feb 11, 2025
35ecc17
fix
Goekdeniz-Guelmez Feb 11, 2025
978deab
small fix
Goekdeniz-Guelmez Feb 11, 2025
5aeefc8
update new iterate batches function + nits
Goekdeniz-Guelmez Feb 12, 2025
c42e858
Merge branch 'adding-GRPO-training' of https://github.com/Goekdeniz-G…
Goekdeniz-Guelmez Feb 12, 2025
e33d9d5
updates
Goekdeniz-Guelmez Feb 12, 2025
3823154
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 12, 2025
a7273f6
small fix
Goekdeniz-Guelmez Feb 12, 2025
8179b99
quick prompting fix
Goekdeniz-Guelmez Feb 12, 2025
65a49dd
nits
Goekdeniz-Guelmez Feb 13, 2025
baeb9f1
redundancy fix + nits
Goekdeniz-Guelmez Feb 14, 2025
5ec4790
removing comments + adding temperature + reward weighting
Goekdeniz-Guelmez Feb 15, 2025
6a6bd53
removing print and switching some variables in the math
Goekdeniz-Guelmez Feb 15, 2025
1eea135
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 17, 2025
541f0be
fix generation cutoff in evaluation
Goekdeniz-Guelmez Feb 17, 2025
11c8991
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 19, 2025
2f20107
little faster generation + prints out an example generation in validat…
Goekdeniz-Guelmez Feb 21, 2025
6086137
Huge speed improvement in validation mode.
Goekdeniz-Guelmez Feb 21, 2025
710bc14
training mode working too got from 2 toks/sec to 30 toks/sec with raw…
Goekdeniz-Guelmez Feb 21, 2025
c51b0a2
fix
Goekdeniz-Guelmez Feb 21, 2025
79de353
nits
Goekdeniz-Guelmez Feb 22, 2025
235348c
generation speed improvement in training too from 3 t/s to 15 t/s
Goekdeniz-Guelmez Feb 22, 2025
d653371
nits
Goekdeniz-Guelmez Feb 22, 2025
d9c4c6e
clean up and readding temperature argument
Goekdeniz-Guelmez Feb 22, 2025
9705ed9
fix wrong generation in train
Goekdeniz-Guelmez Feb 22, 2025
c0bd89a
add usage in LORA.md
Goekdeniz-Guelmez Feb 22, 2025
bd5f081
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 22, 2025
e4eac9c
adding custom system message integration in dataset, more optimization…
Goekdeniz-Guelmez Feb 24, 2025
53185c7
last update, gn
Goekdeniz-Guelmez Feb 24, 2025
ef6ff92
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 25, 2025
fab2dc2
small fix
Goekdeniz-Guelmez Feb 26, 2025
f27ed26
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 27, 2025
a04eb02
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 28, 2025
15d5327
batching fix
Goekdeniz-Guelmez Feb 28, 2025
80e10a5
Merge branch 'main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 28, 2025
925e114
updates
Goekdeniz-Guelmez Feb 28, 2025
28 changes: 28 additions & 0 deletions llms/mlx_lm/LORA.md
@@ -18,6 +18,7 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:

- [Run](#Run)
- [Fine-tune](#Fine-tune)
- [GRPO](#GRPO)
- [Evaluate](#Evaluate)
- [Generate](#Generate)
- [Fuse](#Fuse)
@@ -84,6 +85,33 @@ ignore the prompt and compute loss for just the completion by passing
datasets. For `chat` datasets the final message in the message list is
considered the completion. See the [dataset section](#Data) for more details.

### Group Relative Policy Optimization (GRPO)

To fine-tune a model with GRPO, which optimizes the policy using multiple sampled responses per prompt, use:

```shell
mlx_lm.lora \
--model <path_to_model> \
--train \
--data <path_to_data> \
--training-mode grpo \
--group-size 4
```

GRPO-specific arguments (a combined example follows this list):

- `--group-size`: Number of responses generated per prompt (default: 4)
- `--beta`: KL penalty coefficient for policy optimization (default: 0.1)
- `--epsilon`: Small constant added for numerical stability (default: 1e-4)
- `--max-completion-length`: Maximum length of the generated completions (default: 512)
- `--reference-model-path`: Path to reference model weights. If not specified, the training model is also used as the reference
- `--temperature`: Sampling temperature for generations. Higher values increase randomness (default: 1.0)
- `--reward-weights`: Optional list of weights for the reward functions. Must match the number of reward functions. If not specified, all rewards are weighted equally with 1.0

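For illustration, a fuller invocation might combine several of these options (the paths are placeholders and the specific values are arbitrary):

```shell
mlx_lm.lora \
    --model <path_to_model> \
    --train \
    --data <path_to_data> \
    --training-mode grpo \
    --group-size 4 \
    --max-completion-length 512 \
    --beta 0.1 \
    --temperature 0.8 \
    --reward-weights "[0.5, 1.0]"
```
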
The GRPO training method generates multiple responses for each prompt and optimizes the policy using relative rewards between responses. This approach helps improve response quality by learning from comparisons between different completions.

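As a rough sketch of the underlying idea (not the code used by this trainer), the group-relative advantage of each sampled response can be computed by normalizing rewards within its group; `eps` below plays the role of the `--epsilon` stability constant:

```python
import numpy as np

def group_relative_advantages(rewards, group_size, eps=1e-4):
    """Illustrative only: normalize rewards within each group of responses."""
    # Consecutive `group_size` rewards are assumed to belong to the same prompt.
    grouped = np.asarray(rewards, dtype=np.float32).reshape(-1, group_size)
    mean = grouped.mean(axis=1, keepdims=True)
    std = grouped.std(axis=1, keepdims=True)
    # Responses that beat their group's average receive a positive advantage.
    return ((grouped - mean) / (std + eps)).reshape(-1)

# Example: two prompts with four sampled responses each.
print(group_relative_advantages([1.0, 0.0, 0.5, 1.0, 0.2, 0.2, 0.9, 0.1], group_size=4))
```
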
Note that GRPO requires more compute resources than standard LoRA training since it generates multiple responses per prompt. Consider reducing batch size or using gradient checkpointing if running into memory issues.

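For example, a memory-conscious run might look like the following (this assumes the existing `--batch-size` and `--grad-checkpoint` flags of `lora.py`; adjust to your setup):

```shell
mlx_lm.lora \
    --model <path_to_model> \
    --train \
    --data <path_to_data> \
    --training-mode grpo \
    --group-size 4 \
    --batch-size 1 \
    --grad-checkpoint
```
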
### Evaluate

To compute test set perplexity use:
215 changes: 178 additions & 37 deletions llms/mlx_lm/lora.py
@@ -1,20 +1,21 @@
# Copyright © 2024 Apple Inc.

from pathlib import Path
import argparse
import types
import math
import os
import re
import types
from pathlib import Path

import mlx.nn as nn
import mlx.optimizers as optim
import mlx.nn as nn
import numpy as np
import yaml

from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import (
build_schedule,
linear_to_lora_layers,
@@ -42,6 +43,7 @@
CONFIG_DEFAULTS = {
"model": "mlx_model",
"train": False,
"training_mode": "normal",
"fine_tune_type": "lora",
"data": "data/",
"seed": 0,
@@ -63,6 +65,17 @@
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
"mask_prompt": False,

# GRPO args
"reference_model_path": None,
"group_size": 4,
"beta": 0.1,
"epsilon": 1e-4,
"max_completion_length": 512,
"use_chat_template": False,
"use_prompt": False,
"temperature": 1.0,
"reward_weights": None
}


@@ -103,6 +116,12 @@ def build_parser():
default=None,
)

parser.add_argument(
"--training-mode",
type=str,
choices=["normal", "grpo"],
help="Training mode: normal or GRPO",
)
parser.add_argument(
"--num-layers",
type=int,
@@ -170,8 +189,93 @@ def build_parser():
default=None,
)
parser.add_argument("--seed", type=int, help="The PRNG seed")

# GRPO args
parser.add_argument(
"--group-size",
type=int,
help="Number of generations.",
default=4,
)
parser.add_argument(
"--max-completion-length",
type=int,
help="Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.",
default=512,
)
parser.add_argument(
"--beta",
type=float,
help="KL penalty coefficient.",
default=0.1,
)
parser.add_argument(
"--epsilon",
type=float,
help="The Epsilon for numerical stability.",
default=1e-4,
)
parser.add_argument(
"--use-chat-template",
action="store_true",
help="If the model is a Chat model, use the Chat template.",
default=None,
)
parser.add_argument(
"--use-prompt",
action="store_true",
help="Rather to use the prompt from the R1 paper.",
default=None,
)
parser.add_argument(
"--temperature",
type=float,
help="Temperature for sampling. The higher the temperature, the more random the completions.",
default=1.0,
)
parser.add_argument(
"--reward-weights",
type=str,
help="Weights for each reward function. Must match the number of reward functions and be in this format [0.1, 0.2, 0.3, 0.4, 0.5]. If not given, all rewards are weighted equally with weight `1.0`.",
default=None,
)
return parser

def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback):
training_args = GRPOTrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
max_completion_length=args.max_completion_length,
grad_checkpoint=args.grad_checkpoint,
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon,
reference_model_path=args.reference_model_path,
temperature=args.temperature,
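# Parse a "[0.1, 0.2, ...]"-style string into a list of floats; None keeps all rewards equally weighted.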
reward_weights=[float(x) for x in args.reward_weights.strip('[]').split(',')] if args.reward_weights else None
)

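# Load a separate reference model if a path was given; otherwise reload the base model to serve as the (frozen) reference.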
if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
else:
reference_model, _ = load(args.model)

train_grpo(
model=model,
ref_model=reference_model.freeze(),
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
args=training_args,
training_callback=training_callback,
)

def train_model(
args,
@@ -215,19 +319,6 @@ def train_model(
adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")

# init training args
training_args = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
)

model.train()
opt = optim.Adam(
learning_rate=(
@@ -236,32 +327,82 @@
)

# Train model
train(
model=model,
tokenizer=tokenizer,
args=training_args,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
)
if args.training_mode == "grpo":
train_model_grpo(
model,
tokenizer,
args,
opt,
train_set,
valid_set,
adapter_file,
training_callback
)
else:
training_args = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint
)

train(
model=model,
tokenizer=tokenizer,
args=training_args,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
)


def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
model.eval()

test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)
if args.training_mode == "grpo":
if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
else:
reference_model, _ = load(args.model)

test_loss, _, test_rewards = evaluate_grpo(
model=model,
ref_model=reference_model.freeze(),
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon,
temperature=args.temperature,
max_tokens=args.max_seq_length
)

test_ppl = math.exp(test_loss)

rewards_str = ", ".join([f"{k}: {v:.3f}" for k, v in test_rewards.items()])
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {rewards_str}")
else:
test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)

test_ppl = math.exp(test_loss)
test_ppl = math.exp(test_loss)

print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


def run(args, training_callback: TrainingCallback = None):
@@ -312,4 +453,4 @@ def main():


if __name__ == "__main__":
main()
main()