Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attempt to make the reward function customizable in GRPO #2433

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions recipes/configs/dev/3B_full_grpo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ dataset:
seed: null
shuffle: False

# Reward functions
reward_fn:
_component_: torchtune.dev.grpo.rewards.batch_shaped_correctness_reward

# Model Arguments
model:
_component_: torchtune.models.llama3_2.llama3_2_3b
Expand Down
27 changes: 20 additions & 7 deletions recipes/dev/grpo_full_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torchtune.config._utils import _get_component_from_path
from torchtune.datasets import ConcatDataset
from torchtune.dev.grpo.generation import generate
from torchtune.dev.grpo.rewards import batch_shaped_correctness_reward
from torchtune.dev.grpo import DEFAULT_REWARD_FN
from torchtune.dev.grpo.types import GRPOStats, GRPOTrajectory
from torchtune.modules import local_kv_cache
from torchtune.recipe_interfaces import FTRecipeInterface
Expand Down Expand Up @@ -226,6 +226,8 @@ def setup(self, cfg: DictConfig) -> None:
)
self._tokenizer = config.instantiate(cfg.tokenizer)

self.reward_fn = self._setup_reward_fn(cfg.get("reward_fn", None))

self._optimizer = self._setup_optimizer(
cfg_optimizer=cfg.optimizer,
opt_state_dict=(
Expand Down Expand Up @@ -596,6 +598,14 @@ def _setup_data(

return sampler, dataloader

def _setup_reward_fn(self, cfg_reward_fn: DictConfig) -> nn.Module:
"""
Setup the reward function.
"""
if cfg_reward_fn is None:
return DEFAULT_REWARD_FN
return config.instantiate(cfg_reward_fn)

def save_checkpoint(
self,
epoch: int,
Expand Down Expand Up @@ -681,7 +691,7 @@ def save_checkpoint(
torch.distributed.barrier()

def generate_trajectory(
self, input_ids: torch.Tensor, answers: List[str]
self, input_ids: torch.Tensor, answers: List[str], **kwargs
) -> GRPOTrajectory:
"""
Generates a trajectory given the current policy model, the reference policy model, the reward function,
Expand Down Expand Up @@ -780,8 +790,11 @@ def generate_trajectory(
# responses :: [B x G, L]
responses = responses.reshape(batch_size, grpo_size, -1) # [B, G, L]

rewards, successes = batch_shaped_correctness_reward(
self._tokenizer, responses, answers
rewards, successes = self.reward_fn(
tokenizer=self._tokenizer,
completions=responses,
answers=answers,
**kwargs
) # [B, G]
rewards = rewards.to(self._device)
successes = successes.to(self._device)
Expand Down Expand Up @@ -814,7 +827,7 @@ def generate_trajectory(
)

def generate_trajectory_batched(
self, input_ids: torch.Tensor, answers: List[str]
self, input_ids: torch.Tensor, answers: List[str], **kwargs
) -> GRPOTrajectory:
"""
Generates a ``self.batch_size`` batch of trajectories using `self._forward_batch_size` batch sizes.
Expand All @@ -839,7 +852,7 @@ def generate_trajectory_batched(
]
torch.cuda.empty_cache()
trajectories.append(
self.generate_trajectory(batch_input_ids, batch_answers)
self.generate_trajectory(batch_input_ids, batch_answers, **kwargs)
)
torch.cuda.empty_cache()
return GRPOTrajectory(*map(torch.cat, zip(*trajectories)))
Expand Down Expand Up @@ -952,7 +965,7 @@ def train(self) -> None:

_, context_length = tokens.shape

trajectory = self.generate_trajectory_batched(tokens, answers)
trajectory = self.generate_trajectory_batched(**batch)
torch.distributed.barrier()

grpo_stats: list[GRPOStats] = []
Expand Down
4 changes: 4 additions & 0 deletions torchtune/dev/grpo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .rewards import batch_shaped_correctness_reward

DEFAULT_REWARD_FN = batch_shaped_correctness_reward
2 changes: 1 addition & 1 deletion torchtune/dev/grpo/rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def shaped_correctness_reward(answer: str, completion: str) -> tuple[float, floa

def batch_shaped_correctness_reward(
tokenizer: ModelTokenizer, completions: torch.Tensor, answers: list[str]
) -> [torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""Utility function to apply the shaped reward function to a GRPO-style batch of completions."""

batch_size, grpo_size, *_ = completions.shape
Expand Down
Loading