From 1b36d1f72fc6c1880894f8d7fc878f976abef68c Mon Sep 17 00:00:00 2001
From: Mark Obozov
Date: Sat, 22 Feb 2025 23:05:16 +0300
Subject: [PATCH] padding

---
 recipes/dev/grpo_full_finetune_distributed.py | 49 ++++++++++++++++++-
 1 file changed, 48 insertions(+), 1 deletion(-)

diff --git a/recipes/dev/grpo_full_finetune_distributed.py b/recipes/dev/grpo_full_finetune_distributed.py
index 8d1511c362..101de94445 100644
--- a/recipes/dev/grpo_full_finetune_distributed.py
+++ b/recipes/dev/grpo_full_finetune_distributed.py
@@ -813,6 +813,18 @@ def generate_trajectory(
             seq_lens=seq_lens,
         )
 
+    def _pad_tensor(
+        self, tensor: torch.Tensor, target_dim: int, pad_value: float, dim: int = 1
+    ) -> torch.Tensor:
+        pad_size = target_dim - tensor.shape[dim]
+        if pad_size <= 0:
+            return tensor
+
+        pad = [0] * (2 * tensor.ndim)
+        pad[-2 * (dim + 1)] = 0
+        pad[-2 * (dim + 1) + 1] = pad_size
+        return torch.nn.functional.pad(tensor, pad, value=pad_value)
+
     def generate_trajectory_batched(
         self, input_ids: torch.Tensor, answers: List[str]
     ) -> GRPOTrajectory:
@@ -842,7 +854,42 @@ def generate_trajectory_batched(
                 self.generate_trajectory(batch_input_ids, batch_answers)
             )
             torch.cuda.empty_cache()
-        return GRPOTrajectory(*map(torch.cat, zip(*trajectories)))
+
+        # We need to pad so that concatenation will not fail with a shape error
+        max_p_plus_l = max(t.query_responses.shape[1] for t in trajectories)
+        max_l = max(t.logprobs.shape[1] for t in trajectories)
+
+        padded_trajectories = []
+        for traj in trajectories:
+            padded_masks = self._pad_tensor(
+                self._pad_tensor(traj.masks, max_p_plus_l, 0, dim=2),
+                max_p_plus_l,
+                0,
+                dim=1,
+            )
+
+            padded_trajectories.append(
+                GRPOTrajectory(
+                    query_responses=self._pad_tensor(
+                        traj.query_responses, max_p_plus_l, 1, dim=1
+                    ),
+                    logprobs=self._pad_tensor(traj.logprobs, max_l, 1.0, dim=1),
+                    ref_logprobs=self._pad_tensor(traj.ref_logprobs, max_l, 1.0, dim=1),
+                    rewards=traj.rewards,
+                    successes=traj.successes,
+                    advantages=traj.advantages,
+                    masks=padded_masks,
+                    position_ids=self._pad_tensor(
+                        traj.position_ids, max_p_plus_l, 0, dim=1
+                    ),
+                    response_padding_masks=self._pad_tensor(
+                        traj.response_padding_masks, max_l, False, dim=1
+                    ),
+                    seq_lens=traj.seq_lens,
+                )
+            )
+
+        return GRPOTrajectory(*map(torch.cat, zip(*padded_trajectories)))
 
     def grpo_step(
         self,
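
Note for reviewers: below is a minimal standalone sketch of the F.pad indexing
trick that _pad_tensor relies on, plus the two-axis mask padding mirrored from
the second hunk. The free-function form, shapes, and fill values here are
illustrative only, not part of the patch:

    import torch
    import torch.nn.functional as F

    def pad_tensor(tensor, target_dim, pad_value, dim=1):
        # Right-pad `tensor` along `dim` up to size `target_dim`.
        # F.pad orders pads from the LAST dimension forward, two entries
        # (left, right) per dimension, so the (left, right) pair for `dim`
        # sits at offsets -2 * (dim + 1) and -2 * (dim + 1) + 1 from the end.
        pad_size = target_dim - tensor.shape[dim]
        if pad_size <= 0:
            return tensor
        pad = [0] * (2 * tensor.ndim)
        pad[-2 * (dim + 1) + 1] = pad_size  # right-side padding only
        return F.pad(tensor, pad, value=pad_value)

    # Trajectories with different sequence lengths: torch.cat would raise
    # a size-mismatch error on dim 1 without padding.
    a, b = torch.ones(2, 5), torch.ones(2, 3)
    max_len = max(a.shape[1], b.shape[1])
    merged = torch.cat([pad_tensor(t, max_len, 0.0) for t in (a, b)])
    print(merged.shape)  # torch.Size([4, 5])

    # A square (batch, seq, seq) attention mask must be padded on both
    # dims 1 and 2, mirroring the nested _pad_tensor calls in the patch.
    mask = torch.ones(2, 3, 3, dtype=torch.bool)
    mask = pad_tensor(pad_tensor(mask, max_len, False, dim=2), max_len, False, dim=1)
    print(mask.shape)  # torch.Size([2, 5, 5])

The patch fills query_responses with 1 and logprobs/ref_logprobs with 1.0;
presumably those padded positions are excluded downstream via
response_padding_masks, so the exact fill values do not enter the loss.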