Commit

Merge branch 'master' into regression_metrics_patch
Ido Amos [email protected] committed Nov 20, 2024
2 parents b0dd20e + a156da0 commit 2d79ed1
Showing 3 changed files with 223 additions and 34 deletions.
102 changes: 69 additions & 33 deletions fuse/data/tokenizers/modular_tokenizer/modular_tokenizer.py
@@ -1,39 +1,51 @@
import random
from torch import Tensor
from typing import Dict
from collections.abc import Iterable
from tokenizers import Tokenizer, Encoding
import tokenizers
from warnings import warn
from typing import Optional, List, Set, Union, Tuple, Any, Iterator
from typing import Optional, List, Set, Union, Tuple, Any, Iterator, Dict
import json
import transformers
import os
from omegaconf import OmegaConf
import collections
import omegaconf
import copy
import traceback
import re
from fuse.data.tokenizers.modular_tokenizer.special_tokens import special_wrap_input
from dataclasses import dataclass


TypedInput = collections.namedtuple(
"TypedInput", ["input_type", "input_string", "max_len"]
)
@dataclass
class ModularTokenizerInput:
input_type: str # sub tokenizer name
input_string: str # the string to tokenize
max_len: Optional[int] = None # max length used for truncation only.
truncate_mode: Optional[
str
] = None # by default truncates from the right (keeping the leading tokens); setting to "RAND" randomly crops a contiguous sub-sequence of length max_len


def list_to_tokenizer_string(lst: List[TypedInput]) -> str:
# for backward compatibility
TypedInput = ModularTokenizerInput


def list_to_tokenizer_string(lst: List[ModularTokenizerInput]) -> str:
out = ""
# prev_tokenizer = None
for in_named_tuple in lst:
curr_tokenizer = in_named_tuple.input_type
curr_len = in_named_tuple.max_len
truncate_mode = in_named_tuple.truncate_mode
# NOTE: For now we don't combine consequent strings encoded by the same tokenizer,
# since they may have different max lengths, so we create a new entry, even if curr_tokenizer == prev_tokenizer:
if curr_len is None:
out += f"<@TOKENIZER-TYPE={curr_tokenizer}>"
else:
elif truncate_mode is None:
out += f"<@TOKENIZER-TYPE={curr_tokenizer}@MAX-LEN={curr_len}>"
else:
out += f"<@TOKENIZER-TYPE={curr_tokenizer}@MAX-LEN={curr_len}@TRUNC-MODE={truncate_mode}>"
out += in_named_tuple.input_string
# prev_tokenizer = curr_tokenizer
return out
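For context (not part of the diff): a minimal usage sketch of the hint-string construction above. The sub-tokenizer names "AA" and "SMILES" are hypothetical placeholders.

# Hypothetical sub-tokenizer names, purely for illustration.
typed_inputs = [
    ModularTokenizerInput(input_type="AA", input_string="ACDEFGH", max_len=10),
    ModularTokenizerInput(
        input_type="SMILES", input_string="CC(=O)O", max_len=5, truncate_mode="RAND"
    ),
]
print(list_to_tokenizer_string(typed_inputs))
# -> <@TOKENIZER-TYPE=AA@MAX-LEN=10>ACDEFGH<@TOKENIZER-TYPE=SMILES@MAX-LEN=5@TRUNC-MODE=RAND>CC(=O)O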
@@ -47,6 +59,7 @@ def __init__(
special_tokens_list: Optional[list] = None,
max_possible_token_id: Optional[int] = None,
max_special_token_id: Optional[int] = None,
seed: int = 1234,
**kwargs: Any,
) -> None:
"""Creates a modular tokenizer that combines multiple existing tokenizers, adjusting them so that:
@@ -74,6 +87,7 @@ def __init__(
If max_special_token_id is set, when special tokens are added, they are mapped to IDs between 0 and max_special_token_id
(after which come regular token IDs). Once max_special_token_id is reached, no more special tokens may be added.
If it is not set, new special tokens may be mapped to IDs higher than regular token IDs. Defaults to None (i.e. no limit is set).
seed: random generator seed - used for random truncation when the "RAND" truncate mode is requested (not the default mode)
"""
# ModularTokenizer inherits the interface of PreTrainedTokenizerBase, but not the underlying logic, therefore super.__init__() is not called

@@ -193,6 +207,10 @@ def __init__(
f"tokenizer remapping resulted in IDs greater (max_id={self._get_max_mapped_id()}) than max_possible_id ({self._max_possible_token_id}). Reinitialize the modular tokenizer with larger max_possible_id"
)

# initialize random generator
self._rng = random.Random()
self._rng.seed(seed)

@staticmethod
def remap_vocab(
vocab: Dict,
@@ -973,7 +991,7 @@ def count_unknowns(

def encode_list(
self,
typed_input_list: List,
typed_input_list: List[ModularTokenizerInput],
max_len: Optional[int] = None,
padding_token_id: Optional[int] = None,
padding_token: Optional[str] = "<PAD>",
@@ -991,13 +1009,7 @@ def encode_list(
"""_summary_
Args:
typed_input_list (List): list of collections.namedtuple("input_type", ["input_string", "max_len"]), with
input type: the name of input type,
input_string: the string to be encoded
max_len: maximal length of the encoding (in tokens). Only relevant for truncation, as we do not need to
pad individual sub-tokenizer encodings - we only pad the final encoding of the ModularTokenizer.
The smallest value between config-defined and tuple-defined is used. If None, the max_len
that was defined for the sub-tokenizer in the config is used.
typed_input_list (List): list of ModularTokenizerInput.
max_len (Optional[int], optional): _description_. Defaults to None.
padding_token_id (Optional[str], optional): _description_. Defaults to 0. TODO: default to None and infer it
padding_token (Optional[str], optional): _description_. Defaults to "<PAD>".
@@ -1021,6 +1033,7 @@ def encode_list(
input_type = inpt.input_type
data_str = inpt.input_string
sub_max_len = inpt.max_len
sub_trunc_mode = inpt.truncate_mode
sub_encoding = self._encode_single_type(
data_str=data_str,
input_type=input_type,
@@ -1030,7 +1043,21 @@ def encode_list(
if sub_max_len is not None:
if len(sub_encoding) > sub_max_len:
overflow_info += f"{input_type}:{len(sub_encoding)}=>{sub_max_len}|"
sub_encoding.truncate(max_length=sub_max_len)
if sub_trunc_mode is None or sub_trunc_mode == "RIGHT":
sub_encoding.truncate(max_length=sub_max_len)
elif sub_trunc_mode == "RAND":
left_truncate = self._rng.randint(
sub_max_len, len(sub_encoding)
)
sub_encoding.truncate(
max_length=left_truncate, direction="left"
)
sub_encoding.truncate(max_length=sub_max_len, direction="right")
assert len(sub_encoding) == sub_max_len
else:
raise Exception(
f"Error: unsupported truncate mode: {sub_trunc_mode}"
)
encoded_list.append(sub_encoding)
sequence_ids.extend([curr_sequence_id] * len(sub_encoding))
sequence_types.extend([input_type] * len(sub_encoding))
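For context (not part of the diff): the "RAND" branch above selects a random contiguous window of exactly sub_max_len tokens by chaining a left truncation and a right truncation. A standalone sketch of the same windowing logic on a plain list of token IDs (the function name is illustrative):

import random

def rand_truncate(token_ids: list, max_len: int, rng: random.Random) -> list:
    # Same effect as the two chained Encoding.truncate() calls above:
    # keep the last `left_truncate` tokens, then the first `max_len` of those,
    # i.e. a random contiguous window of exactly `max_len` tokens.
    if len(token_ids) <= max_len:
        return token_ids
    left_truncate = rng.randint(max_len, len(token_ids))
    window = token_ids[-left_truncate:][:max_len]
    assert len(window) == max_len
    return window

rng = random.Random(1234)
print(rand_truncate(list(range(20)), max_len=5, rng=rng))  # some contiguous 5-token window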
@@ -1128,7 +1155,6 @@ def encode_list(
raise ValueError(
f"Unexpected on_unknown value {on_unknown}. Should be 'warn' or 'raise'"
)

if (not return_overflow_info) and (not also_return_split):
return merged_encoding
ans = [merged_encoding]
@@ -1214,23 +1240,33 @@ def encode(
), f"Error: expecting leading modular tokenizer hints followed by a sequence to tokenize, got {sequence}"
# arrange as a list of TypedInput - each one will include the type and the following sequence
encode_list_format = []
for tokenizer_type, subseq in zip(
for tokenizer_hints, subseq in zip(
hints_and_subseq[::2], hints_and_subseq[1::2]
):
max_len_str = "@MAX-LEN="
curr_max_len_idx = tokenizer_type.find(max_len_str)
if curr_max_len_idx > 0:
curr_max_len = tokenizer_type[curr_max_len_idx + len(max_len_str) :]
try:
curr_max_len = int(curr_max_len)
except:
raise Exception(
f"Had a problem casting curr_max_len={curr_max_len} to int! it was found inside modular tokenizer meta TOKENIZER-TYPE={tokenizer_type}"
)
tokenizer_type = tokenizer_type[:curr_max_len_idx]
else:
curr_max_len = None
encode_list_format.append(TypedInput(tokenizer_type, subseq, curr_max_len))
tokenizer_hints_parts = tokenizer_hints.split("@")
tokenizer_type = tokenizer_hints_parts[0]
curr_max_len = None
curr_truncate_mode = None
for part in tokenizer_hints_parts[1:]:
if part.startswith("MAX-LEN="):
try:
curr_max_len = int(part[len("MAX-LEN=") :])
except:
raise Exception(
f"Had a problem casting curr_max_len={part} to int! it was found inside modular tokenizer meta TOKENIZER-TYPE={tokenizer_hints}"
)
elif part.startswith("TRUNC-MODE="):
curr_truncate_mode = part[len("TRUNC-MODE=") :]
assert curr_truncate_mode in [
"RIGHT",
"RAND",
], f"Error: unsupported modular tokenizer truncate mode {curr_truncate_mode}, it was found inside modular tokenizer meta TOKENIZER-TYPE={tokenizer_hints}"
else:
raise Exception(f"Error, unsupported meta tokenizer hint: {part}")

encode_list_format.append(
TypedInput(tokenizer_type, subseq, curr_max_len, curr_truncate_mode)
)

return self.encode_list(
typed_input_list=encode_list_format,
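For context (not part of the diff): encode() now accepts per-sub-tokenizer hints of the form <@TOKENIZER-TYPE=name>, optionally extended with @MAX-LEN=n and @TRUNC-MODE=RIGHT|RAND. Below is a standalone sketch of the same parsing; the splitting regex is an assumption for illustration and may differ from the one used in the library, and the tokenizer names are hypothetical.

import re
from typing import List, Optional, Tuple

def parse_tokenizer_hints(sequence: str) -> List[Tuple[str, str, Optional[int], Optional[str]]]:
    # Split into alternating (hints, sub-sequence) pairs, then parse each hint group.
    hints_and_subseq = re.split(r"<@TOKENIZER-TYPE=([^>]*)>", sequence)[1:]
    parsed = []
    for hints, subseq in zip(hints_and_subseq[::2], hints_and_subseq[1::2]):
        fields = hints.split("@")
        name, max_len, trunc_mode = fields[0], None, None
        for field in fields[1:]:
            if field.startswith("MAX-LEN="):
                max_len = int(field[len("MAX-LEN="):])
            elif field.startswith("TRUNC-MODE="):
                trunc_mode = field[len("TRUNC-MODE="):]
        parsed.append((name, subseq, max_len, trunc_mode))
    return parsed

example = (
    "<@TOKENIZER-TYPE=AA@MAX-LEN=10>ACDEFGH"
    "<@TOKENIZER-TYPE=SMILES@MAX-LEN=5@TRUNC-MODE=RAND>CC(=O)O"
)
print(parse_tokenizer_hints(example))
# -> [('AA', 'ACDEFGH', 10, None), ('SMILES', 'CC(=O)O', 5, 'RAND')]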
4 changes: 4 additions & 0 deletions fuse/data/tokenizers/modular_tokenizer/op.py
@@ -240,6 +240,7 @@ def __call__(
verbose (Optional[int], optional): verbosity level. 0: no notification, 1: warning notification, 2: warning with partial data, 3: warning
with full data. Defaults to 1.
validate_ends_with_eos (Optional[bool], optional): if not None, overrides self._validate_ends_with_eos
additional_caller_info_text (Optional[str]): information about the caller to add to error messages
key_out_encoding_per_meta: optional key out. If set to a string will put in it the per-meta-instruction encoded parts as a list of Encoding elements
Raises:
@@ -446,6 +447,7 @@ def __call__(
verbose: Optional[int] = 1,
validate_ends_with_eos: Optional[bool] = None,
key_out_scalars: Optional[str] = None,
additional_caller_info_text: Optional[str] = "",
) -> NDict:
"""_summary_
@@ -468,6 +470,7 @@ def __call__(
if provided, will write to:
`sample_dict[f'{key_out_scalars}.values]` - a 1D torch tensor with all the scalars values
`sample_dict[f'{key_out_scalars}.valid_mask]` - a 1D torch boolean tensor representing which elements have scalar values
additional_caller_info_text (Optional[str]): information about the caller to add to error messages
Returns:
NDict: _description_
@@ -492,6 +495,7 @@ def __call__(
on_unknown=on_unknown,
verbose=verbose,
validate_ends_with_eos=validate_ends_with_eos,
additional_caller_info_text=additional_caller_info_text,
key_out_encoding_per_meta=key_in
+ ".per_meta_part_encoding", # using the key_in as base for the name because key_out_* are optional
)
151 changes: 150 additions & 1 deletion fuse/dl/models/heads/common.py
@@ -17,9 +17,10 @@
"""

from typing import Optional, Sequence
from typing import Optional, Sequence, List
import torch.nn as nn
from torch import Tensor
import torch


class ClassifierFCN(nn.Module):
@@ -150,3 +151,151 @@ def __init__(
def forward(self, x: Tensor) -> Tensor:
x = self.classifier(x)
return x


class EncoderEmbeddingOutputHead(nn.Module):
def __init__(
self,
embedding_size: int,
layers: List[int],
dropout: float,
num_classes: int,
pooling: Optional[str] = None,
):
"""
NOTE: This is work in progress. Do not use for now.
This class applies a multi-layer MLP to an input and optionally applies a pooling operation over the sequence dimension prior to the MLP.
This is useful for extracting a single representation from the embeddings of an entire sequence.
Args:
embedding_size: MLP input dimension.
layers: List[int], specifies the output dimension of the MLP in each layer.
dropout: dropout rate, applied to every layer in the MLP.
num_classes: number of output classes - the output dimension of the final MLP layer.
pooling: str (optional) type of pooling to be used, currently available options are ["mean", "last"]. Pooling operations ignore pad tokens - a padding mask should be supplied in the forward pass.
"""
super().__init__()
self.embedding_size = embedding_size
self.layers = layers
self.dropout = dropout
self.pooling_type = pooling

# this weird assignment is for backward compatibility
# when loading pretrained weights
self.classifier = ClassifierMLP(
in_ch=embedding_size,
layers_description=layers,
dropout_rate=dropout,
num_classes=num_classes,
).classifier

if pooling is not None:
self.pooling = ModularPooling1D(pooling=pooling)
else:
self.pooling = None

def forward(
self,
inputs: Tensor,
padding_mask: Tensor = None,
keep_pool_dim: bool = True,
) -> Tensor:
"""
Args:
padding_mask: a mask that indicates which positions hold valid tokens (1) and which are padding tokens (0) - typically this is similar to an attention mask.
keep_pool_dim: if True, an output of shape (B, L, D) is returned as (B, 1, D); otherwise (B, D) is returned
"""

if self.pooling is not None:
assert (
padding_mask is not None
), "OutputHead attempts to perform pooling - requires the padding_mask to detect padding tokens (usually same as the attention mask to the decoder), but padding_mask is None"

inputs = self.pooling(
inputs=inputs, padding_mask=padding_mask, keep_dim=keep_pool_dim
)

y = self.classifier(inputs)
return y


class ModularPooling1D(nn.Module):
"""
A wrapper around multiple pooling methods.
Args:
pooling: str, type of pooling to apply, available methods are: ["mean", "last"] TODO: add max?
pool_dim: dimension to apply pooling
"""

def __init__(self, pooling: str, pool_dim: int = 1, **kwargs: dict):
super().__init__()

self.pooling_type = pooling
self.pool_dim = pool_dim

if pooling in ["mean", "avg"]: # pools the mean value of none-pad elements

def _mean_pool(inputs: Tensor, last_valid_indices: Tensor) -> Tensor:
inputs = inputs.cumsum(dim=self.pool_dim)
outputs = self._extract_indices(
inputs, last_valid_indices, dim=self.pool_dim
)
outputs = outputs / (last_valid_indices + 1)
return outputs

self.pooling = lambda inputs, indices: _mean_pool(inputs, indices)

elif pooling == "last": # pools the last element that is not a PAD value

def _last_pool(inputs: Tensor, last_valid_indices: Tensor) -> Tensor:
return self._extract_indices(
inputs, last_valid_indices, dim=self.pool_dim
)

self.pooling = lambda inputs, indices: _last_pool(inputs, indices)

else:
raise NotImplementedError

def _extract_indices(self, inputs: Tensor, indices: Tensor, dim: int = 1) -> Tensor:
assert (
dim == 1
), "extract indices for pooling head not implemented for dim != 1 yet"
# extract indices along the pooling dimension using differentiable ops
indices = indices.reshape(-1)
index = indices.unsqueeze(1).unsqueeze(1)
index = index.expand(size=(index.shape[0], 1, inputs.shape[-1]))
pooled = torch.gather(inputs, dim=dim, index=index).squeeze(1)
return pooled

def forward(
self,
inputs: Tensor,
padding_mask: Tensor = None,
keep_dim: bool = True,
) -> Tensor:
"""
See EncoderEmbeddingOutputHead.forward for a detailed description.
"""
if padding_mask.dtype != torch.bool:
padding_mask = padding_mask.to(torch.bool)
# get indices of the last non-pad token positions
last_valid_indices = get_last_non_pad_token(
padding_mask=padding_mask
).unsqueeze(1)
out = self.pooling(inputs, last_valid_indices)
if keep_dim:
out = out.unsqueeze(self.pool_dim)
return out


def get_last_non_pad_token(padding_mask: Tensor) -> Tensor:
"""
Returns the position of the last non-pad token for every element in the batch.
Expected input shape is (B, L), B is the batch size, L is the sequence dimension.
Args:
padding_mask: a boolean tensor with True values for non-padded positions and False values for padded positions (usually same as the attention mask input to an encoder model)
"""
non_pad_pos = padding_mask.cumsum(dim=-1) # starts from 1
non_pad_last_pos = non_pad_pos[:, -1] - 1

return non_pad_last_pos
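For context (not part of the diff): a minimal usage sketch of the pooling utilities defined above. Tensor shapes and values are illustrative; it assumes the classes are importable from the file shown above, fuse/dl/models/heads/common.py.

import torch
from fuse.dl.models.heads.common import ModularPooling1D  # the module added in this commit

padding_mask = torch.tensor([[1, 1, 1, 0, 0],
                             [1, 1, 1, 1, 1]], dtype=torch.bool)        # (B=2, L=5)
inputs = torch.arange(2 * 5 * 4, dtype=torch.float32).reshape(2, 5, 4)  # (B, L, D=4)

last_pool = ModularPooling1D(pooling="last")
mean_pool = ModularPooling1D(pooling="mean")

# "last" picks the embedding at the last non-pad position (index 2 for the first sample, 4 for the second).
print(last_pool(inputs, padding_mask=padding_mask, keep_dim=False).shape)  # torch.Size([2, 4])
# "mean" averages the first (last_valid_index + 1) positions via a cumulative sum.
print(mean_pool(inputs, padding_mask=padding_mask, keep_dim=False).shape)  # torch.Size([2, 4])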
