Adding grpo training #1233

Open
wants to merge 77 commits into main from adding-GRPO-training
Changes from 29 commits (77 commits total)
Commits
5e0ae83
initial commit, gn
Goekdeniz-Guelmez Jan 28, 2025
b1e573d
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Jan 29, 2025
93370ff
updates ans fixing the KL div lines
Goekdeniz-Guelmez Jan 30, 2025
6c58aa9
updates
Goekdeniz-Guelmez Jan 31, 2025
80bcf68
grpo_trainer shoudl be done
Goekdeniz-Guelmez Jan 31, 2025
a57d553
update
Goekdeniz-Guelmez Jan 31, 2025
243c962
update lora.py
Goekdeniz-Guelmez Jan 31, 2025
d034ca3
adding function for R1
Goekdeniz-Guelmez Feb 3, 2025
734d6f4
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 3, 2025
a3ed632
dataset wrapper done
Goekdeniz-Guelmez Feb 3, 2025
41ff536
Merge branch 'adding-GRPO-training' of https://github.com/Goekdeniz-G…
Goekdeniz-Guelmez Feb 3, 2025
23d75cd
starting fist training test run
Goekdeniz-Guelmez Feb 3, 2025
1d9e480
first working prototype, will try training out at home
Goekdeniz-Guelmez Feb 3, 2025
05d921b
optims
Goekdeniz-Guelmez Feb 3, 2025
40bca77
fixes
Goekdeniz-Guelmez Feb 3, 2025
06f9c29
print func name
Goekdeniz-Guelmez Feb 3, 2025
54e295e
fix name funcs
Goekdeniz-Guelmez Feb 3, 2025
ca32424
updates
Goekdeniz-Guelmez Feb 3, 2025
7173840
first succesfull training run
Goekdeniz-Guelmez Feb 4, 2025
bd1a42e
adding args into dataset handling
Goekdeniz-Guelmez Feb 4, 2025
7b01414
better create_dataset
Goekdeniz-Guelmez Feb 4, 2025
0a09a93
fix cache handling
Goekdeniz-Guelmez Feb 5, 2025
2a8e6f6
udpate
Goekdeniz-Guelmez Feb 5, 2025
d84ad0c
fix testing
Goekdeniz-Guelmez Feb 5, 2025
a33cad8
udpates
Goekdeniz-Guelmez Feb 5, 2025
35a2d99
smoll fix
Goekdeniz-Guelmez Feb 5, 2025
0a19522
updates
Goekdeniz-Guelmez Feb 5, 2025
bcfa55d
updates
Goekdeniz-Guelmez Feb 5, 2025
94dcd0f
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 6, 2025
9ba6146
fix
Goekdeniz-Guelmez Feb 9, 2025
39e9469
freeze ref model
Goekdeniz-Guelmez Feb 9, 2025
5417990
fix
Goekdeniz-Guelmez Feb 9, 2025
a527cdb
fix: prevent gradients from flowing through the reference model's logits
Goekdeniz-Guelmez Feb 9, 2025
0071252
rebase loss calculation
Goekdeniz-Guelmez Feb 9, 2025
0dac286
Merge branch 'main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 10, 2025
d9da35f
nits
Goekdeniz-Guelmez Feb 10, 2025
f88e897
removing helper functions
Goekdeniz-Guelmez Feb 10, 2025
e5aa2c3
nits
Goekdeniz-Guelmez Feb 10, 2025
b7bc811
nits
Goekdeniz-Guelmez Feb 10, 2025
88ca747
nits
Goekdeniz-Guelmez Feb 10, 2025
e96afe9
updates
Goekdeniz-Guelmez Feb 11, 2025
e80bf95
fix
Goekdeniz-Guelmez Feb 11, 2025
35ecc17
fix
Goekdeniz-Guelmez Feb 11, 2025
978deab
small fix
Goekdeniz-Guelmez Feb 11, 2025
5aeefc8
update new iterade batches function + nits
Goekdeniz-Guelmez Feb 12, 2025
c42e858
Merge branch 'adding-GRPO-training' of https://github.com/Goekdeniz-G…
Goekdeniz-Guelmez Feb 12, 2025
e33d9d5
updates
Goekdeniz-Guelmez Feb 12, 2025
3823154
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 12, 2025
a7273f6
small fix
Goekdeniz-Guelmez Feb 12, 2025
8179b99
quick prompting fix
Goekdeniz-Guelmez Feb 12, 2025
65a49dd
nits
Goekdeniz-Guelmez Feb 13, 2025
baeb9f1
reduncancy fix + nits
Goekdeniz-Guelmez Feb 14, 2025
5ec4790
removing comments + adding temperature + reward weighting
Goekdeniz-Guelmez Feb 15, 2025
6a6bd53
removing print and switching some variables in the math
Goekdeniz-Guelmez Feb 15, 2025
1eea135
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 17, 2025
541f0be
fix generation cutoff in evaluation
Goekdeniz-Guelmez Feb 17, 2025
11c8991
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 19, 2025
2f20107
little faster generation + prints ot a examplke generatino in validat…
Goekdeniz-Guelmez Feb 21, 2025
6086137
Huge speed improvement in validation mode.
Goekdeniz-Guelmez Feb 21, 2025
710bc14
training mode working too got from 2 toks/sec to 30 toks/sec with raw…
Goekdeniz-Guelmez Feb 21, 2025
c51b0a2
fix
Goekdeniz-Guelmez Feb 21, 2025
79de353
nits
Goekdeniz-Guelmez Feb 22, 2025
235348c
generation speed improvement in training too from 3 t/s to 15 t/s
Goekdeniz-Guelmez Feb 22, 2025
d653371
nits
Goekdeniz-Guelmez Feb 22, 2025
d9c4c6e
clean up and readding temperature argument
Goekdeniz-Guelmez Feb 22, 2025
9705ed9
fix wrong generation in train
Goekdeniz-Guelmez Feb 22, 2025
c0bd89a
add usage in LORA.md
Goekdeniz-Guelmez Feb 22, 2025
bd5f081
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 22, 2025
e4eac9c
adding custom system message integration in dataset, more opimization…
Goekdeniz-Guelmez Feb 24, 2025
53185c7
last update, gn
Goekdeniz-Guelmez Feb 24, 2025
ef6ff92
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 25, 2025
fab2dc2
smoll fix
Goekdeniz-Guelmez Feb 26, 2025
f27ed26
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 27, 2025
a04eb02
Merge branch 'ml-explore:main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 28, 2025
15d5327
batching fix
Goekdeniz-Guelmez Feb 28, 2025
80e10a5
Merge branch 'main' into adding-GRPO-training
Goekdeniz-Guelmez Feb 28, 2025
925e114
updates
Goekdeniz-Guelmez Feb 28, 2025
165 changes: 145 additions & 20 deletions llms/mlx_lm/lora.py
@@ -15,6 +15,7 @@
from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
from .tuner.utils import (
build_schedule,
linear_to_lora_layers,
@@ -42,6 +43,7 @@
CONFIG_DEFAULTS = {
"model": "mlx_model",
"train": False,
"training_mode": "normal",
"fine_tune_type": "lora",
"data": "data/",
"seed": 0,
@@ -62,6 +64,15 @@
"grad_checkpoint": False,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},

# GRPO args
"reference_model_path": None,
"group_size": 4,
"beta": 0.1,
"epsilon": 1e-4,
"max_completion_length": 512,
"use_chat_template": False,
"use_prompt": False,
}


@@ -94,6 +105,12 @@ def build_parser():
choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
"--training-mode",
type=str,
choices=["normal", "grpo"],
help="Training mode: normal or GRPO",
)
parser.add_argument(
"--num-layers",
type=int,
@@ -161,6 +178,44 @@ def build_parser():
default=None,
)
parser.add_argument("--seed", type=int, help="The PRNG seed")

# GRPO args
parser.add_argument(
"--group-size",
type=int,
help="Number of generations.",
default=4,
)
parser.add_argument(
"--max-completion-length",
type=int,
help="Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.",
default=512,
)
parser.add_argument(
"--beta",
type=float,
help="KL penalty coefficient.",
default=0.1,
)
parser.add_argument(
"--epsilon",
type=float,
help="The Epsilon for numerical stability.",
default=1e-4,
)
parser.add_argument(
"--use-chat-template",
action="store_true",
help="If the model is a Chat model, use the Chat template.",
default=None,
)
parser.add_argument(
"--use-prompt",
action="store_true",
help="Rather to use the prompt from the R1 paper.",
default=None,
)
return parser


@@ -220,32 +275,102 @@ def train_model(
)
)
# Train model
train(
model=model,
tokenizer=tokenizer,
args=training_args,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
)
if args.training_mode == "grpo":
training_args = GRPOTrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
max_completion_length=args.max_completion_length,
grad_checkpoint=args.grad_checkpoint,
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon,
reference_model_path=args.reference_model_path
)

if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
reference_model = reference_model.freeze()
else:
reference_model, _ = load(args.model)

train_grpo(
model=model,
ref_model=reference_model,
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
args=training_args,
training_callback=training_callback,
)
else:
training_args = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint
)

train(
model=model,
tokenizer=tokenizer,
args=training_args,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
)


def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
model.eval()

test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)
if args.training_mode == "grpo":
if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
else:
reference_model = model

test_loss, _, test_rewards = evaluate_grpo(
model=model,
ref_model=reference_model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon
)

test_ppl = math.exp(test_loss)

print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
else:
test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)

test_ppl = math.exp(test_loss)
test_ppl = math.exp(test_loss)

print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


def run(args, training_callback: TrainingCallback = None):
Expand Down Expand Up @@ -296,4 +421,4 @@ def main():


if __name__ == "__main__":
main()
main()
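
Note (editorial, not part of the diff): for readers unfamiliar with GRPO, the sketch below illustrates what the new --group-size, --beta, and --epsilon options feed into. Rewards for a group of completions sampled from the same prompt are normalized into advantages, and the policy is penalized by a KL term against the frozen reference model. This is a minimal, generic MLX sketch, not the code in tuner/grpo_trainer.py; the function names and shapes are assumptions, and the PPO-style ratio clipping used in the full algorithm is omitted for brevity.

# Illustrative GRPO sketch (assumptions noted above; not this PR's implementation).
import mlx.core as mx

def grpo_advantages(rewards: mx.array, epsilon: float = 1e-4) -> mx.array:
    # rewards: (group_size,) scalar rewards for completions sampled from one prompt.
    mean = mx.mean(rewards)
    std = mx.sqrt(mx.mean((rewards - mean) ** 2))
    return (rewards - mean) / (std + epsilon)  # epsilon keeps the division stable

def grpo_loss(policy_logprobs: mx.array, ref_logprobs: mx.array,
              advantages: mx.array, beta: float = 0.1) -> mx.array:
    # policy_logprobs / ref_logprobs: (group_size, completion_len) per-token log-probs;
    # the reference model is frozen, so no gradients flow through ref_logprobs.
    diff = ref_logprobs - policy_logprobs
    kl = mx.exp(diff) - diff - 1          # per-token KL estimate used by GRPO
    adv = mx.expand_dims(advantages, 1)   # broadcast sequence-level advantage over tokens
    return -mx.mean(policy_logprobs * adv - beta * kl)

# Example: a group of 4 sampled completions with scalar rewards.
rewards = mx.array([1.0, 0.0, 0.5, 0.0])
advantages = grpo_advantages(rewards)
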
91 changes: 77 additions & 14 deletions llms/mlx_lm/tuner/datasets.py
@@ -1,10 +1,59 @@
import json
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple

from transformers import PreTrainedTokenizer


class GRPODataset:
"""
Dataset wrapper for GRPO training data.
Each example should have a 'prompt' and 'answer' field.
Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
"""
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
answer_key: str = "answer",
use_chat_template: bool = False,
use_prompt: bool = False
):
self._data = []
for item in data:
prompt_str = str(item[prompt_key])
answer_str = str(item[answer_key])
if use_chat_template:
prompt_tokens = tokenizer.apply_chat_template(
[
{'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
{'role': 'user', 'content': prompt_str}
],
)
answer_tokens = tokenizer.encode(answer_str)
else:
if use_prompt:
prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.
User: {prompt_str}. Assistant: """)
else:
prompt_tokens = tokenizer.encode(prompt_str)
answer_tokens = tokenizer.encode(answer_str)
self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))

def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
"""Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
return self._data[idx]

def __len__(self) -> int:
"""Returns the number of examples in the dataset."""
return len(self._data)


class Dataset:
"""
Light-weight wrapper to hold a dataset.
@@ -82,6 +131,7 @@ def __len__(self):


def create_dataset(
args,
data,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
@@ -90,20 +140,32 @@ def create_dataset(
prompt_feature = prompt_feature or "prompt"
completion_feature = completion_feature or "completion"
sample = data[0]
if "messages" in sample:
return ChatDataset(data, tokenizer)
elif prompt_feature in sample and completion_feature in sample:
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
elif "text" in sample:
return Dataset(data, tokenizer)

if args.training_mode == "normal":
if "messages" in sample:
return ChatDataset(data, tokenizer)
elif prompt_feature in sample and completion_feature in sample:
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
elif "text" in sample:
return Dataset(data, tokenizer)
else:
raise ValueError(
"Unsupported data format, check the supported formats here:\n"
"https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
)
else:
raise ValueError(
"Unsupported data format, check the supported formats here:\n"
"https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
return GRPODataset(
data=data,
tokenizer=tokenizer,
prompt_key="prompt",
answer_key="answer",
use_chat_template=args.use_chat_template,
use_prompt=args.use_prompt
)


def load_local_dataset(
args,
data_path: Path,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
@@ -114,14 +176,15 @@ def load_subset(path):
return []
with open(path, "r") as fid:
data = [json.loads(l) for l in fid]
return create_dataset(data, tokenizer, prompt_feature, completion_feature)
return create_dataset(args, data, tokenizer, prompt_feature, completion_feature)

names = ("train", "valid", "test")
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
return train, valid, test


def load_hf_dataset(
args,
data_id: str,
tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None,
@@ -137,7 +200,7 @@ def load_hf_dataset(
train, valid, test = [
(
create_dataset(
dataset[n], tokenizer, prompt_feature, completion_feature
args, dataset[n], tokenizer, prompt_feature, completion_feature
)
if n in dataset.keys()
else []
@@ -202,12 +265,12 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
completion_feature = getattr(args, "completion_feature", None)
if data_path.exists():
train, valid, test = load_local_dataset(
data_path, tokenizer, prompt_feature, completion_feature
args, data_path, tokenizer, prompt_feature, completion_feature
)
else:
print(f"Loading Hugging Face dataset {args.data}.")
train, valid, test = load_hf_dataset(
args.data, tokenizer, prompt_feature, completion_feature
args, args.data, tokenizer, prompt_feature, completion_feature
)

if args.train and len(train) == 0:
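
Note (editorial, not part of the diff): a hedged usage sketch of the GRPODataset added above. Each line of train.jsonl must carry a "prompt" and an "answer" field, and the wrapper returns (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples. The tokenizer checkpoint and the sample record below are assumptions, not taken from this PR.

# Hypothetical sketch; the checkpoint name and example record are assumptions.
import json
from transformers import AutoTokenizer
from mlx_lm.tuner.datasets import GRPODataset

# One JSON object per line, each with "prompt" and "answer" fields, e.g.:
# {"prompt": "What is 7 * 6?", "answer": "<think>7 * 6 = 42</think><answer>42</answer>"}
with open("data/train.jsonl") as f:
    records = [json.loads(line) for line in f]

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # assumed checkpoint
dataset = GRPODataset(records, tokenizer, use_chat_template=True)
prompt_tokens, answer_tokens, prompt_str, answer_str = dataset[0]

When training through lora.py, the same records are picked up automatically by create_dataset once --training-mode grpo is set.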