From ab666131d39383845994e1a1bc9cdfcddfe602dd Mon Sep 17 00:00:00 2001 From: Ben Shapira <66549915+bensha6757@users.noreply.github.com> Date: Wed, 22 Jan 2025 12:06:12 +0200 Subject: [PATCH] Add EmbeddingInfo NamedTuple (#392) Co-authored-by: Ben Shapira --- .../tokenizers/modular_tokenizer/inject_utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py index b595e8b1..b84f3902 100644 --- a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py +++ b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py @@ -1,6 +1,6 @@ import re from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, NamedTuple, Optional, Tuple, Union from warnings import warn import torch @@ -13,6 +13,11 @@ from fuse.utils import NDict +class EmbeddingInfo(NamedTuple): + location: int + embedding_input: str + + class InjectorToModularTokenizerLib: """ InjectorTokenizer builds on top of ModularTokenizer. @@ -227,7 +232,10 @@ def build_scalars_and_embeddings( embedding_input = sample_dict[embeddings_key] external_embeddings_info[embedding_model_name].append( - (num_tokens_token_so_far, embedding_input) + EmbeddingInfo( + location=num_tokens_token_so_far, + embedding_input=embedding_input, + ) ) elif tokenizer_name.startswith("VECTORS_"): raise NotImplementedError