diff --git a/recipes/configs/dev/3B_full_grpo.yaml b/recipes/configs/dev/3B_full_grpo.yaml
index 3fd1d7e240..a3a687b88f 100644
--- a/recipes/configs/dev/3B_full_grpo.yaml
+++ b/recipes/configs/dev/3B_full_grpo.yaml
@@ -38,6 +38,10 @@ dataset:
 seed: null
 shuffle: False
 
+# Reward functions
+reward_fn:
+  _component_: torchtune.dev.grpo.rewards.batch_shaped_correctness_reward
+
 # Model Arguments
 model:
   _component_: torchtune.models.llama3_2.llama3_2_3b
diff --git a/recipes/dev/grpo_full_finetune_distributed.py b/recipes/dev/grpo_full_finetune_distributed.py
index 8d1511c362..ade75ccd3a 100644
--- a/recipes/dev/grpo_full_finetune_distributed.py
+++ b/recipes/dev/grpo_full_finetune_distributed.py
@@ -21,7 +21,7 @@
 from torchtune.config._utils import _get_component_from_path
 from torchtune.datasets import ConcatDataset
 from torchtune.dev.grpo.generation import generate
-from torchtune.dev.grpo.rewards import batch_shaped_correctness_reward
+from torchtune.dev.grpo import DEFAULT_REWARD_FN
 from torchtune.dev.grpo.types import GRPOStats, GRPOTrajectory
 from torchtune.modules import local_kv_cache
 from torchtune.recipe_interfaces import FTRecipeInterface
@@ -226,6 +226,8 @@ def setup(self, cfg: DictConfig) -> None:
         )
         self._tokenizer = config.instantiate(cfg.tokenizer)
 
+        self.reward_fn = self._setup_reward_fn(cfg.get("reward_fn", None))
+
         self._optimizer = self._setup_optimizer(
             cfg_optimizer=cfg.optimizer,
             opt_state_dict=(
@@ -596,6 +598,14 @@ def _setup_data(
 
         return sampler, dataloader
 
+    def _setup_reward_fn(self, cfg_reward_fn: DictConfig) -> nn.Module:
+        """
+        Setup the reward function.
+        """
+        if cfg_reward_fn is None:
+            return DEFAULT_REWARD_FN
+        return config.instantiate(cfg_reward_fn)
+
     def save_checkpoint(
         self,
         epoch: int,
@@ -681,7 +691,7 @@
         torch.distributed.barrier()
 
     def generate_trajectory(
-        self, input_ids: torch.Tensor, answers: List[str]
+        self, input_ids: torch.Tensor, answers: List[str], **kwargs
    ) -> GRPOTrajectory:
         """
         Generates a trajectory given the current policy model, the reference policy model, the reward function,
@@ -780,8 +790,11 @@ def generate_trajectory(
 
         # responses :: [B x G, L]
         responses = responses.reshape(batch_size, grpo_size, -1)  # [B, G, L]
-        rewards, successes = batch_shaped_correctness_reward(
-            self._tokenizer, responses, answers
+        rewards, successes = self.reward_fn(
+            tokenizer=self._tokenizer,
+            completions=responses,
+            answers=answers,
+            **kwargs
         )  # [B, G]
         rewards = rewards.to(self._device)
         successes = successes.to(self._device)
@@ -814,7 +827,7 @@
         )
 
     def generate_trajectory_batched(
-        self, input_ids: torch.Tensor, answers: List[str]
+        self, input_ids: torch.Tensor, answers: List[str], **kwargs
     ) -> GRPOTrajectory:
         """
         Generates a ``self.batch_size`` batch of trajectories using `self._forward_batch_size` batch sizes.
@@ -839,7 +852,7 @@ def generate_trajectory_batched(
                 ]
                 torch.cuda.empty_cache()
                 trajectories.append(
-                    self.generate_trajectory(batch_input_ids, batch_answers)
+                    self.generate_trajectory(batch_input_ids, batch_answers, **kwargs)
                 )
                 torch.cuda.empty_cache()
         return GRPOTrajectory(*map(torch.cat, zip(*trajectories)))
@@ -952,7 +965,7 @@ def train(self) -> None:
 
                 _, context_length = tokens.shape
 
-                trajectory = self.generate_trajectory_batched(tokens, answers)
+                trajectory = self.generate_trajectory_batched(**batch)
                 torch.distributed.barrier()
 
                 grpo_stats: list[GRPOStats] = []
diff --git a/torchtune/dev/grpo/__init__.py b/torchtune/dev/grpo/__init__.py
index 2e41cd717f..848719bc6d 100644
--- a/torchtune/dev/grpo/__init__.py
+++ b/torchtune/dev/grpo/__init__.py
@@ -3,3 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
+from .rewards import batch_shaped_correctness_reward
+
+DEFAULT_REWARD_FN = batch_shaped_correctness_reward
\ No newline at end of file
diff --git a/torchtune/dev/grpo/rewards.py b/torchtune/dev/grpo/rewards.py
index 2cba5ee4a4..79ccd52602 100644
--- a/torchtune/dev/grpo/rewards.py
+++ b/torchtune/dev/grpo/rewards.py
@@ -72,7 +72,7 @@ def shaped_correctness_reward(answer: str, completion: str) -> tuple[float, floa
 
 
 def batch_shaped_correctness_reward(
     tokenizer: ModelTokenizer, completions: torch.Tensor, answers: list[str]
-) -> [torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     """Utility function to apply the shaped reward function to a GRPO-style batch of completions."""
     batch_size, grpo_size, *_ = completions.shape
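
Example of a pluggable reward: the hunks above route reward computation through a configurable `reward_fn` that is called as `reward_fn(tokenizer=..., completions=..., answers=..., **kwargs)` and is expected to return `(rewards, successes)` tensors of shape [B, G]; omitting the `reward_fn` block in the config falls back to DEFAULT_REWARD_FN. The sketch below shows a minimal custom reward matching that call signature. The name `exact_match_reward` is hypothetical (not part of torchtune), and any tokenizer exposing a `decode(list[int]) -> str` method is assumed.

from typing import List, Tuple

import torch


def exact_match_reward(
    tokenizer,  # assumption: any tokenizer with a ``decode(list[int]) -> str`` method
    completions: torch.Tensor,  # [B, G, L] token ids, as passed from generate_trajectory
    answers: List[str],  # length B, one ground-truth answer per prompt
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reward 1.0 when the decoded completion contains the expected answer, else 0.0."""
    batch_size, grpo_size, *_ = completions.shape
    rewards = torch.zeros(batch_size, grpo_size, dtype=torch.float32)
    for b in range(batch_size):
        for g in range(grpo_size):
            text = tokenizer.decode(completions[b, g].tolist())
            if answers[b] in text:
                rewards[b, g] = 1.0
    # With a binary reward, success and reward coincide.
    return rewards, rewards.clone()

Such a function could then be selected in the YAML config by pointing `reward_fn._component_` at its import path (e.g. `my_project.rewards.exact_match_reward`, path hypothetical) instead of the default `torchtune.dev.grpo.rewards.batch_shaped_correctness_reward`.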