Add Intel AMX/AVX512 support to accelerate inference (#2247)
Signed-off-by: LeiZhou-97 <[email protected]>
LeiZhou-97 authored Aug 21, 2023
1 parent d294434 commit 50c5f0f
Showing 3 changed files with 26 additions and 0 deletions.
README.md: 5 additions & 0 deletions
@@ -132,6 +132,11 @@ This runs on the CPU only and does not require GPU. It requires around 30GB of C
 python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.3 --device cpu
 ```
 
+Use Intel's AI-accelerator instruction sets (AVX512_BF16/AMX) to speed up CPU inference.
+```
+CPU_ISA=amx python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.3 --device cpu
+```
+
 #### Metal Backend (Mac Computers with Apple Silicon or AMD GPUs)
 Use `--device mps` to enable GPU acceleration on Mac computers (requires torch >= 2.0).
 Use `--load-8bit` to turn on 8-bit compression.
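The adapter change below also recognizes `avx512_bf16`, so on CPUs that have AVX512-BF16 but not AMX the same switch applies; an assumed-equivalent invocation:
```
CPU_ISA=avx512_bf16 python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.3 --device cpu
```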
fastchat/constants.py: 2 additions & 0 deletions
@@ -22,6 +22,8 @@
 SESSION_EXPIRATION_TIME = 3600
 # The output dir of log files
 LOGDIR = os.getenv("LOGDIR", ".")
+# CPU Instruction Set Architecture
+CPU_ISA = os.getenv("CPU_ISA")
 
 
 ##### For the controller and workers (could be overwritten through ENV variables.)
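`CPU_ISA` is read straight from the environment with no validation. A hypothetical helper, not part of this commit, could derive the value from the Linux CPU flags before launching the server:
```
from typing import Optional

def detect_cpu_isa() -> Optional[str]:
    """Hypothetical helper: map Linux /proc/cpuinfo flags to a CPU_ISA value."""
    try:
        with open("/proc/cpuinfo") as f:
            flags = f.read()
    except OSError:
        return None  # non-Linux or unreadable; leave CPU_ISA unset
    if "amx_bf16" in flags:  # Advanced Matrix Extensions (4th-gen Xeon and later)
        return "amx"
    if "avx512_bf16" in flags:  # AVX-512 BF16 (3rd-gen Xeon and later)
        return "avx512_bf16"
    return None
```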
fastchat/model/model_adapter.py: 19 additions & 0 deletions
@@ -6,6 +6,9 @@
 from typing import Dict, List, Optional
 import warnings
 
+from fastchat.constants import CPU_ISA
+
+
 if sys.version_info >= (3, 9):
     from functools import cache
 else:
@@ -167,6 +170,15 @@ def load_model(
     )
     if device == "cpu":
         kwargs = {"torch_dtype": torch.float32}
+        if CPU_ISA in ["avx512_bf16", "amx"]:
+            try:
+                import intel_extension_for_pytorch as ipex
+
+                kwargs = {"torch_dtype": torch.bfloat16}
+            except ImportError:
+                warnings.warn(
+                    "Intel Extension for PyTorch is not installed; install it to accelerate CPU inference."
+                )
     elif device == "cuda":
         kwargs = {"torch_dtype": torch.float16}
         if num_gpus != 1:
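The only change on the CPU path is the load dtype: bfloat16 halves the memory of every weight tensor relative to float32 and is the format the AMX/AVX512_BF16 units consume. A quick illustration (generic PyTorch, not code from this diff):
```
import torch

x = torch.randn(1024, 1024, dtype=torch.float32)
print(x.element_size())                     # 4 bytes per element
print(x.to(torch.bfloat16).element_size())  # 2 bytes per element: half the memory
```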
@@ -267,6 +279,13 @@ def load_model(
     # Load model
     model, tokenizer = adapter.load_model(model_path, kwargs)
 
+    if (
+        device == "cpu"
+        and kwargs["torch_dtype"] is torch.bfloat16
+        and CPU_ISA is not None
+    ):
+        model = ipex.optimize(model, dtype=kwargs["torch_dtype"])
+
     if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device in (
         "mps",
         "xpu",
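The `ipex` name here is the one bound by the guarded import earlier in `load_model`; the dtype check guarantees that branch succeeded before `ipex.optimize` is reached. A self-contained sketch of the pattern this commit wires up (the model name and generation call are illustrative, not code from the diff):
```
import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.3")
model = AutoModelForCausalLM.from_pretrained(
    "lmsys/vicuna-7b-v1.3", torch_dtype=torch.bfloat16
).eval()
model = ipex.optimize(model, dtype=torch.bfloat16)  # swap in AMX/AVX512-aware kernels

with torch.inference_mode():
    input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
    output_ids = model.generate(input_ids, max_new_tokens=16)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```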
