Skip to content

Commit

Permalink
Update qwen and add pygmalion (#2607)
Browse files Browse the repository at this point in the history
  • Loading branch information
Trangle authored Oct 28, 2023
1 parent cbf2853 commit 09e4357
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,20 @@ class QwenChatAdapter(BaseModelAdapter):
def match(self, model_path: str):
    """Report whether ``model_path`` contains "qwen", ignoring letter case."""
    lowered_path = model_path.lower()
    return "qwen" in lowered_path

def float_set(self, config, option):
    """Make ``option`` the only enabled precision flag on ``config``.

    Assigns ``config.bf16``, ``config.fp16`` and ``config.fp32`` so that the
    flag whose name equals ``option`` is True and the other two are False.
    If ``option`` is none of those names, all three flags end up False and a
    notice is printed; no exception is raised.
    """
    precision_flags = ("bf16", "fp16", "fp32")
    # Each flag is True only when it is the requested one, which clears the
    # other two in the same pass.
    for flag_name in precision_flags:
        setattr(config, flag_name, flag_name == option)
    if option not in precision_flags:
        print("Invalid option. Please choose one from 'bf16', 'fp16' and 'fp32'.")

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
from transformers.generation import GenerationConfig

Expand All @@ -1430,7 +1444,7 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
)
# NOTE: if you use the old version of model file, please remove the comments below
# config.use_flash_attn = False
config.fp16 = True
self.float_set(config, "fp16")
generation_config = GenerationConfig.from_pretrained(
model_path, trust_remote_code=True
)
Expand Down Expand Up @@ -1698,6 +1712,20 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("lemur-70b-chat")


class PygmalionAdapter(BaseModelAdapter):
    """Model adapter for the Pygmalion/Metharme family (e.g., PygmalionAI/mythalion-13b)."""

    # use_fast_tokenizer = False

    def match(self, model_path: str):
        """Return True when the path mentions pygmalion, mythalion or metharme."""
        hit = re.search(r"pygmalion|mythalion|metharme", model_path.lower(), re.I)
        # re.search yields a Match (always truthy) or None, so this equals bool(hit).
        return hit is not None

    def get_default_conv_template(self, model_path: str) -> Conversation:
        """Return the "metharme" conversation template, regardless of ``model_path``."""
        template_name = "metharme"
        return get_conv_template(template_name)


# NOTE: the registration order matters.
# An adapter registered earlier has a higher matching priority than one
# registered later, so do not reorder these calls casually.
register_model_adapter(PeftModelAdapter)
Expand Down Expand Up @@ -1760,6 +1788,8 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(ZephyrAdapter)
register_model_adapter(XwinLMAdapter)
register_model_adapter(LemurAdapter)
register_model_adapter(PygmalionAdapter)


# After all adapters, try the default base adapter.
# It is registered last on purpose: every adapter registered above is
# consulted before this one.
register_model_adapter(BaseModelAdapter)

0 comments on commit 09e4357

Please sign in to comment.