From e569b1075148c96d581864fcea7bb3241c114b02 Mon Sep 17 00:00:00 2001
From: ETZhangSX
Date: Sun, 29 Sep 2024 11:41:20 +0800
Subject: [PATCH 1/2] Add support for embedding models: Text2Vec, M3E, GTE

---
 fastchat/model/model_adapter.py | 101 +++++++++++++++++++-------------
 1 file changed, 59 insertions(+), 42 deletions(-)

diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index 92e19dbb7..b2066a5e3 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -138,6 +138,34 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("one_shot")
 
 
+class BaseEmbeddingModelAdapter(BaseModelAdapter):
+    """The base embedding model adapter"""
+
+    use_fast_tokenizer = False
+
+    def match(self, model_path: str):
+        return "embedding" in model_path.lower()
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        revision = from_pretrained_kwargs.get("revision", "main")
+        model = AutoModel.from_pretrained(
+            model_path,
+            **from_pretrained_kwargs,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path, trust_remote_code=True, revision=revision
+        )
+        if hasattr(model.config, "max_position_embeddings") and hasattr(
+            tokenizer, "model_max_length"
+        ):
+            model.config.max_sequence_length = min(
+                model.config.max_position_embeddings, tokenizer.model_max_length
+            )
+        model.use_cls_pooling = True
+        model.eval()
+        return model, tokenizer
+
+
 # A global registry for all model adapters
 # TODO (lmzheng): make it a priority queue.
 model_adapters: List[BaseModelAdapter] = []
@@ -1794,7 +1822,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("qwen-7b-chat")
 
 
-class BGEAdapter(BaseModelAdapter):
+class BGEAdapter(BaseEmbeddingModelAdapter):
     """The model adapter for BGE (e.g., BAAI/bge-large-en-v1.5)"""
 
     use_fast_tokenizer = False
@@ -1802,30 +1830,8 @@ class BGEAdapter(BaseEmbeddingModelAdapter):
     def match(self, model_path: str):
         return "bge" in model_path.lower()
 
-    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
-        revision = from_pretrained_kwargs.get("revision", "main")
-        model = AutoModel.from_pretrained(
-            model_path,
-            **from_pretrained_kwargs,
-        )
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_path, trust_remote_code=True, revision=revision
-        )
-        if hasattr(model.config, "max_position_embeddings") and hasattr(
-            tokenizer, "model_max_length"
-        ):
-            model.config.max_sequence_length = min(
-                model.config.max_position_embeddings, tokenizer.model_max_length
-            )
-        model.use_cls_pooling = True
-        model.eval()
-        return model, tokenizer
-
-    def get_default_conv_template(self, model_path: str) -> Conversation:
-        return get_conv_template("one_shot")
-
 
-class E5Adapter(BaseModelAdapter):
+class E5Adapter(BaseEmbeddingModelAdapter):
     """The model adapter for E5 (e.g., intfloat/e5-large-v2)"""
 
     use_fast_tokenizer = False
@@ -1833,25 +1839,32 @@ class E5Adapter(BaseEmbeddingModelAdapter):
     def match(self, model_path: str):
         return "e5-" in model_path.lower()
 
-    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
-        revision = from_pretrained_kwargs.get("revision", "main")
-        model = AutoModel.from_pretrained(
-            model_path,
-            **from_pretrained_kwargs,
-        )
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_path, trust_remote_code=True, revision=revision
-        )
-        if hasattr(model.config, "max_position_embeddings") and hasattr(
-            tokenizer, "model_max_length"
-        ):
-            model.config.max_sequence_length = min(
-                model.config.max_position_embeddings, tokenizer.model_max_length
-            )
-        return model, tokenizer
 
-    def get_default_conv_template(self, model_path: str) -> Conversation:
-        return get_conv_template("one_shot")
+
+class Text2VecAdapter(BaseEmbeddingModelAdapter):
+    """The model adapter for text2vec (e.g., shibing624/text2vec-base-chinese)"""
+
+    use_fast_tokenizer = False
+
+    def match(self, model_path: str):
+        return "text2vec" in model_path.lower()
+
+
+class M3EAdapter(BaseEmbeddingModelAdapter):
+    """The model adapter for m3e (e.g., moka-ai/m3e-large)"""
+
+    use_fast_tokenizer = False
+
+    def match(self, model_path: str):
+        return "m3e-" in model_path.lower()
+
+
+class GTEAdapter(BaseEmbeddingModelAdapter):
+    """The model adapter for gte (e.g., Alibaba-NLP/gte-large-en-v1.5)"""
+
+    use_fast_tokenizer = False
+
+    def match(self, model_path: str):
+        return "gte-" in model_path.lower()
 
 
 class AquilaChatAdapter(BaseModelAdapter):
@@ -2508,6 +2521,9 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 register_model_adapter(AquilaChatAdapter)
 register_model_adapter(BGEAdapter)
 register_model_adapter(E5Adapter)
+register_model_adapter(Text2VecAdapter)
+register_model_adapter(M3EAdapter)
+register_model_adapter(GTEAdapter)
 register_model_adapter(Lamma2ChineseAdapter)
 register_model_adapter(Lamma2ChineseAlpacaAdapter)
 register_model_adapter(VigogneAdapter)
@@ -2546,5 +2562,6 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 register_model_adapter(SmaugChatAdapter)
 register_model_adapter(Llama3Adapter)
 
+register_model_adapter(BaseEmbeddingModelAdapter)
 # After all adapters, try the default base adapter.
 register_model_adapter(BaseModelAdapter)
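[Note] The three new adapters override only `match()`, so loading and pooling behavior is inherited entirely from `BaseEmbeddingModelAdapter` above. A minimal sketch of exercising one adapter directly, assuming this patch is applied and the checkpoint is reachable on the Hugging Face Hub; the sample sentence and the CLS-pooling step are illustrative, not part of the patch:

```python
# Minimal sketch (assumes this patch is applied): load an embedding model
# through its adapter and pool the [CLS] token, mirroring the
# use_cls_pooling=True flag set in BaseEmbeddingModelAdapter.load_model().
import torch
from fastchat.model.model_adapter import Text2VecAdapter

adapter = Text2VecAdapter()
assert adapter.match("shibing624/text2vec-base-chinese")

# Empty from_pretrained_kwargs -> revision defaults to "main".
model, tokenizer = adapter.load_model("shibing624/text2vec-base-chinese", {})
inputs = tokenizer("How cold is it in Harbin in winter?", return_tensors="pt")
with torch.no_grad():
    output = model(**inputs)
# CLS pooling: take the first token's hidden state as the sentence embedding.
embedding = output.last_hidden_state[:, 0]
print(embedding.shape)  # e.g., torch.Size([1, 768]) for a base-size encoder
```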
From d4a3376c13abace6998d4750d2f1dc491a87b071 Mon Sep 17 00:00:00 2001
From: ETZhangSX
Date: Sun, 29 Sep 2024 12:07:04 +0800
Subject: [PATCH 2/2] update model_support.md

---
 docs/model_support.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/docs/model_support.md b/docs/model_support.md
index ba9acf5b1..6da5bc964 100644
--- a/docs/model_support.md
+++ b/docs/model_support.md
@@ -37,6 +37,7 @@ After these steps, the new model should be compatible with most FastChat feature
   - example: `python3 -m fastchat.serve.cli --model-path meta-llama/Llama-2-7b-chat-hf`
 - Vicuna, Alpaca, LLaMA, Koala
   - example: `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5`
+- [Alibaba-NLP/gte-large-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-large-en-v1.5)
 - [allenai/tulu-2-dpo-7b](https://huggingface.co/allenai/tulu-2-dpo-7b)
 - [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B)
 - [BAAI/AquilaChat2-7B](https://huggingface.co/BAAI/AquilaChat2-7B)
@@ -66,6 +67,7 @@ After these steps, the new model should be compatible with most FastChat feature
 - [lmsys/fastchat-t5-3b-v1.0](https://huggingface.co/lmsys/fastchat-t5)
 - [meta-math/MetaMath-7B-V1.0](https://huggingface.co/meta-math/MetaMath-7B-V1.0)
 - [Microsoft/Orca-2-7b](https://huggingface.co/microsoft/Orca-2-7b)
+- [moka-ai/m3e-large](https://huggingface.co/moka-ai/m3e-large)
 - [mosaicml/mpt-7b-chat](https://huggingface.co/mosaicml/mpt-7b-chat)
   - example: `python3 -m fastchat.serve.cli --model-path mosaicml/mpt-7b-chat`
 - [Neutralzz/BiLLa-7B-SFT](https://huggingface.co/Neutralzz/BiLLa-7B-SFT)
@@ -81,6 +83,7 @@ After these steps, the new model should be compatible with most FastChat feature
 - [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat)
 - [rishiraj/CatPPT](https://huggingface.co/rishiraj/CatPPT)
 - [Salesforce/codet5p-6b](https://huggingface.co/Salesforce/codet5p-6b)
+- [shibing624/text2vec-base-multilingual](https://huggingface.co/shibing624/text2vec-base-multilingual)
 - [StabilityAI/stablelm-tuned-alpha-7b](https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b)
 - [tenyx/TenyxChat-7B-v1](https://huggingface.co/tenyx/TenyxChat-7B-v1)
 - [TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0)
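[Note] With a worker serving one of these models behind FastChat's OpenAI-compatible API server, embeddings should be retrievable over HTTP. A rough usage sketch, assuming the API server listens on the default `http://localhost:8000` from the FastChat docs and the worker was launched with `--model-path moka-ai/m3e-large`:

```python
# Sketch: request an embedding from FastChat's OpenAI-compatible API server.
# Assumes the server is reachable at localhost:8000 and serves moka-ai/m3e-large
# under the model name "m3e-large".
import requests

resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={"model": "m3e-large", "input": "Hello, world"},
)
resp.raise_for_status()
embedding = resp.json()["data"][0]["embedding"]
print(len(embedding))  # vector dimensionality reported by the model
```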