Fix-dolly-8-bit (#1656)
andy-yang-1 authored Jun 12, 2023
1 parent 1fdea26 commit 640ec62
Showing 2 changed files with 32 additions and 33 deletions.
10 changes: 6 additions & 4 deletions fastchat/model/compression.py
@@ -44,7 +44,9 @@ def __init__(self, weight=None, bias=None, device=None):
 
     def forward(self, input: Tensor) -> Tensor:
         weight = decompress(self.weight, default_compression_config)
-        return F.linear(input.to(weight.dtype), weight, self.bias)
+        if self.bias is None:
+            return F.linear(input.to(weight.dtype), weight)
+        return F.linear(input.to(weight.dtype), weight, self.bias.to(weight.dtype))
 
 
 def compress_module(module, target_device):
@@ -97,10 +99,10 @@ def apply_compressed_weight(module, compressed_state_dict, target_device, prefix
         )
 
 
-def load_compress_model(model_path, device, torch_dtype):
+def load_compress_model(model_path, device, torch_dtype, use_fast=False):
     # partially load model
-    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
-    base_pattern = os.path.join(model_path, "pytorch_model-*.bin")
+    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast)
+    base_pattern = os.path.join(model_path, "pytorch_model*.bin")
     files = glob.glob(base_pattern)
 
     with init_empty_weights():
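Note on the compression.py change above: with 8-bit compression enabled, the decompressed weight can end up in a different dtype than a stored bias, and a linear layer may have no bias at all (bias is None), which is the case for some of Dolly's layers. A minimal, self-contained sketch of that failure mode and of the guard/cast the patch introduces (the dtypes below are illustrative assumptions, not values taken from the repository):

import torch
import torch.nn.functional as F

x = torch.randn(2, 4)
weight = torch.randn(3, 4)                  # stands in for the decompressed weight
bias = torch.randn(3, dtype=torch.float16)  # stands in for a bias stored in another dtype

# F.linear(x, weight, bias) would typically raise a dtype-mismatch error here;
# mirroring the patch: skip the bias when absent, otherwise cast it to the weight dtype.
if bias is None:
    out = F.linear(x.to(weight.dtype), weight)
else:
    out = F.linear(x.to(weight.dtype), weight, bias.to(weight.dtype))
print(out.dtype)  # matches weight.dtype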
55 changes: 26 additions & 29 deletions fastchat/model/model_adapter.py
@@ -36,16 +36,25 @@
 class BaseAdapter:
     """The base and the default model adapter."""
 
+    use_fast_tokenizer = False
+
     def match(self, model_path: str):
         return True
 
     def load_model(self, model_path: str, from_pretrained_kwargs: dict):
-        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path, use_fast=self.use_fast_tokenizer
+        )
         model = AutoModelForCausalLM.from_pretrained(
             model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
         )
         return model, tokenizer
 
+    def load_compress_model(self, model_path, device, torch_dtype):
+        return load_compress_model(
+            model_path, device, torch_dtype, use_fast=self.use_fast_tokenizer
+        )
+
     def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("one_shot")
 
@@ -107,6 +116,9 @@ def load_model(
 ):
     """Load a model from Hugging Face."""
 
+    # get model adapter
+    adapter = get_model_adapter(model_path)
+
     # Handle device mapping
     cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration(
         device, load_8bit, cpu_offloading
@@ -153,7 +165,7 @@ def load_model(
                 "8-bit quantization is not supported for multi-gpu inference."
             )
         else:
-            return load_compress_model(
+            return adapter.load_compress_model(
                 model_path=model_path, device=device, torch_dtype=kwargs["torch_dtype"]
             )
     elif gptq_config and gptq_config.wbits < 16:
@@ -332,6 +344,8 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
 class DollyV2Adapter(BaseAdapter):
     """The model adapter for databricks/dolly-v2-12b"""
 
+    use_fast_tokenizer = True
+
     def match(self, model_path: str):
         return "dolly-v2" in model_path
 
@@ -353,44 +367,32 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 class OasstPythiaAdapter(BaseAdapter):
     """The model adapter for OpenAssistant/oasst-sft-1-pythia-12b"""
 
+    use_fast_tokenizer = True
+
     def match(self, model_path: str):
         return "oasst" in model_path and "pythia" in model_path
 
-    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
-        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
-        model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            low_cpu_mem_usage=True,
-            **from_pretrained_kwargs,
-        )
-        return model, tokenizer
-
     def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("oasst_pythia")
 
 
 class StableLMAdapter(BaseAdapter):
     """The model adapter for StabilityAI/stablelm-tuned-alpha-7b"""
 
+    use_fast_tokenizer = True
+
     def match(self, model_path: str):
         return "stablelm" in model_path
 
-    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
-        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
-        model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            low_cpu_mem_usage=True,
-            **from_pretrained_kwargs,
-        )
-        return model, tokenizer
-
     def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("stablelm")
 
 
 class MPTAdapter(BaseAdapter):
     """The model adapter for mosaicml/mpt-7b-chat"""
 
+    use_fast_tokenizer = True
+
     def match(self, model_path: str):
         return "mpt" in model_path
 
@@ -424,6 +426,8 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 class RwkvAdapter(BaseAdapter):
     """The model adapter for BlinkDL/RWKV-4-Raven"""
 
+    use_fast_tokenizer = True
+
     def match(self, model_path: str):
         return "RWKV-4" in model_path
 
@@ -465,18 +469,11 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 class PhoenixAdapter(BaseAdapter):
     """The model adapter for FreedomIntelligence/phoenix-inst-chat-7b"""
 
+    use_fast_tokenizer = True
+
     def match(self, model_path: str):
         return "phoenix" in model_path
 
-    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
-        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
-        model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            low_cpu_mem_usage=True,
-            **from_pretrained_kwargs,
-        )
-        return model, tokenizer
-
     def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("phoenix")
 
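Net effect of the model_adapter.py changes: an adapter that only needs a fast tokenizer no longer overrides load_model; it sets the new use_fast_tokenizer class attribute, and both the regular loading path and the 8-bit path (via BaseAdapter.load_compress_model) honor it. A minimal sketch of a hypothetical adapter written against this pattern (the class name and match string are invented for illustration; import paths assume the repository layout shown in this diff):

from fastchat.conversation import Conversation
from fastchat.model.model_adapter import BaseAdapter, get_conv_template


class ExampleAdapter(BaseAdapter):
    """Hypothetical adapter for a model family that ships only a fast tokenizer."""

    # Read by the inherited load_model() and load_compress_model(), so no override is needed.
    use_fast_tokenizer = True

    def match(self, model_path: str):
        return "example-model" in model_path

    def get_default_conv_template(self, model_path: str) -> Conversation:
        return get_conv_template("one_shot")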
