From d42f319bc34579e6a04c7eeb4d01dfe310bb02ab Mon Sep 17 00:00:00 2001 From: Rylan Conway Date: Thu, 7 Nov 2024 20:11:34 -0800 Subject: [PATCH] Update torchtune generation to be more flexible Summary: The existing softmax sampling trick implementation in the torchtune generator is not flexible enough to deal with vocab pruned models (when the number of logits produced does not match the size of the embedding layer). This is an unnecessary limitation and is easy to fix if we simply create the `q` tensor to match the size of the logits tensor instead of the embedding layer. NOTE: this is just a draft diff to get feedback on possible changes to the OSS torchtune package before submitting a proper pull request Differential Revision: D65480353 --- torchtune/generation/_generation.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/torchtune/generation/_generation.py b/torchtune/generation/_generation.py index c2d60a7373..207d40aa8e 100644 --- a/torchtune/generation/_generation.py +++ b/torchtune/generation/_generation.py @@ -67,7 +67,7 @@ def generate_next_token( model: TransformerDecoder, input_pos: torch.Tensor, x: torch.Tensor, - q: torch.Tensor, + q: Optional[torch.Tensor] = None, *, mask: Optional[torch.Tensor] = None, temperature: float = 1.0, @@ -302,9 +302,11 @@ def generate( # tensors are of identical shape to the prompt curr_masks = masks[:, :prompt_length, :prompt_length] - q = torch.empty( - (bsz, model.tok_embeddings.num_embeddings), device=prompt.device - ).exponential_(1, generator=rng) + q = None + if rng is not None: + q = torch.empty( + (bsz, model.tok_embeddings.num_embeddings), device=prompt.device + ).exponential_(1, generator=rng) tokens, generated_logits = generate_next_token( model, input_pos=input_pos[:, :prompt_length].squeeze(), @@ -360,9 +362,11 @@ def generate( curr_input_pos = input_pos[:, : curr_pos + 1] curr_masks = masks[:, : curr_pos + 1, : curr_pos + 1] - q = torch.empty( - (bsz, model.tok_embeddings.num_embeddings), device=prompt.device - ).exponential_(1, generator=rng) + q = None + if rng is not None: + q = torch.empty( + (bsz, model.tok_embeddings.num_embeddings), device=prompt.device + ).exponential_(1, generator=rng) tokens, logits = custom_generate_next_token( model, input_pos=curr_input_pos,