Intel ARC/XPU Improvements #2052

Merged · 10 commits · Jul 23, 2023
2 changes: 2 additions & 0 deletions fastchat/model/compression.py
@@ -165,6 +165,8 @@ def load_compress_model(model_path, device, torch_dtype, use_fast, revision="main"):
             tensor = None
             gc.collect()
             torch.cuda.empty_cache()
+            if device == "xpu":
+                torch.xpu.empty_cache()
 
     for name in model.state_dict():
         if name not in linear_weights:
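The same two-line cleanup recurs below in fastchat/model/model_codet5p.py, fastchat/model/model_falcon.py, and fastchat/serve/inference.py. As a reading aid, here is a minimal sketch of that pattern factored into one helper; `empty_device_cache` is a hypothetical name, not part of this PR, and `torch.xpu` only exists once intel_extension_for_pytorch (IPEX) has been imported:

```python
import gc

import torch


def empty_device_cache(device: str) -> None:
    # Hypothetical helper mirroring the cleanup this PR repeats inline.
    gc.collect()
    # Safe even on machines without CUDA: empty_cache() is a no-op when
    # CUDA was never initialized.
    torch.cuda.empty_cache()
    if device == "xpu":
        # Provided by IPEX; releases cached XPU allocator blocks.
        torch.xpu.empty_cache()
```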
11 changes: 6 additions & 5 deletions fastchat/model/model_adapter.py
@@ -239,13 +239,14 @@ def load_model(
     adapter = get_model_adapter(model_path)
     model, tokenizer = adapter.load_model(model_path, kwargs)
 
-    if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device == "mps":
+    if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device in (
+        "mps",
+        "xpu",
+    ):
         model.to(device)
 
-    elif device == "xpu":
-        model.eval()
-        model = model.to("xpu")
-        model = torch.xpu.optimize(model, dtype=torch.bfloat16, inplace=True)
+    if device == "xpu":
+        model = torch.xpu.optimize(model, dtype=kwargs["torch_dtype"], inplace=True)
 
     if debug:
         print(model)
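For reference, a minimal sketch of the XPU load path this change converges on, assuming intel_extension_for_pytorch is installed (importing it registers the `torch.xpu` namespace); the function name below is illustrative, not FastChat API:

```python
import torch
import intel_extension_for_pytorch  # noqa: F401 -- registers torch.xpu


def prepare_model_for_xpu(model: torch.nn.Module, torch_dtype=torch.bfloat16):
    # Move to the Intel GPU, then let IPEX fuse operators and cast weights
    # to the requested dtype. The PR now passes kwargs["torch_dtype"] here
    # instead of hard-coding torch.bfloat16.
    model = model.to("xpu")
    model = torch.xpu.optimize(model, dtype=torch_dtype, inplace=True)
    return model
```

The behavioral fix is the dtype: the old branch always optimized for bfloat16, while the new code honors whatever dtype the caller requested. Folding the transfer into the shared `model.to(device)` branch also removes the duplicated XPU-only `model.to("xpu")`.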
2 changes: 2 additions & 0 deletions fastchat/model/model_codet5p.py
@@ -104,3 +104,5 @@ def __call__(
     # clean
     gc.collect()
     torch.cuda.empty_cache()
+    if device == "xpu":
+        torch.xpu.empty_cache()
2 changes: 2 additions & 0 deletions fastchat/model/model_falcon.py
@@ -136,3 +136,5 @@ def generate_stream_falcon(
     # clean
     gc.collect()
     torch.cuda.empty_cache()
+    if device == "xpu":
+        torch.xpu.empty_cache()
1 change: 1 addition & 0 deletions fastchat/serve/cli.py
@@ -177,6 +177,7 @@ def main(args):
                 f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
             )
         os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
+        os.environ["XPU_VISIBLE_DEVICES"] = args.gpus
 
     if args.style == "simple":
         chatio = SimpleChatIO(args.multiline)
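XPU_VISIBLE_DEVICES mirrors CUDA_VISIBLE_DEVICES here, scoping which Intel GPUs the process may use when --gpus is passed. A small usage sketch follows; it is an assumption, not something this diff verifies, that the installed IPEX/oneAPI runtime honors this variable, and like its CUDA counterpart it must be set before the device runtime initializes:

```python
import os

# Assumption: the installed IPEX runtime honors XPU_VISIBLE_DEVICES.
# Set it before importing torch/IPEX so device enumeration picks it up.
os.environ["XPU_VISIBLE_DEVICES"] = "0"

import torch
import intel_extension_for_pytorch  # noqa: F401 -- registers torch.xpu

print(torch.xpu.device_count())  # should reflect the restricted device set
```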
2 changes: 2 additions & 0 deletions fastchat/serve/inference.py
@@ -253,6 +253,8 @@ def generate_stream(
     del past_key_values, out
     gc.collect()
     torch.cuda.empty_cache()
+    if device == "xpu":
+        torch.xpu.empty_cache()
 
 
 class ChatIO(abc.ABC):