[fix] lm-sys/FastChat/issues/2295 #2328
Conversation
It seems your modification is not related to the screenshot you posted in your issue.
macOS (usually device `mps`) does not support bfloat16. My explanation of the cause may not be correct, but it works with this change.
In the two `if` branches you added, you call `.half()` in one `if` branch and one `else` branch. Why is this?
Oh, that's my mistake. I had added `.half()` to both of those branches for the mps device in my local code base. Sorry about that. It's fixed now.
@@ -167,12 +167,18 @@ def load_compress_model(model_path, device, torch_dtype, use_fast, revision="mai
    tmp_state_dict = torch.load(filename, map_location=lambda storage, loc: storage)
    for name in tmp_state_dict:
        if name in linear_weights:
            tensor = tmp_state_dict[name].to(device).data.to(torch_dtype)
Could you try something like this instead of if/else?
tensor = tmp_state_dict[name].to(torch_dtype).to(device).data
If it works, apply the same change to L178.
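The suggested one-liner reorders the conversion so the dtype cast happens before the tensor is moved to the device. A minimal sketch of that pattern (the helper name and sample state dict are illustrative, not FastChat's actual code; on Apple Silicon you would pass `"mps"` as the device):

```python
import torch

def load_tensor(state_dict, name, torch_dtype, device):
    # Cast on CPU first, then move to the target device. This avoids
    # materializing an unsupported dtype (e.g. bfloat16) on "mps".
    return state_dict[name].to(torch_dtype).to(device)

cpu_state = {"w": torch.randn(4, 4, dtype=torch.float32)}
t = load_tensor(cpu_state, "w", torch.float16, "cpu")  # "mps" on macOS
print(t.dtype)  # torch.float16
```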
No need to check the device type (mps); just use the `torch_dtype` argument. It works for the raw chatglm2-6b model from Hugging Face.
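One way to avoid device-type branches like the original `.half()` calls is to resolve the dtype once up front and pass it through. A hedged sketch of that idea (the helper is hypothetical, not part of FastChat; bfloat16 is the dtype that is problematic on `mps`):

```python
import torch

def pick_dtype(device: str, requested: torch.dtype) -> torch.dtype:
    # Fall back to float16 only when the requested dtype is bfloat16
    # on the mps backend; every other combination passes through.
    if device == "mps" and requested == torch.bfloat16:
        return torch.float16
    return requested

print(pick_dtype("mps", torch.bfloat16))   # torch.float16
print(pick_dtype("cuda", torch.bfloat16))  # torch.bfloat16
```

The rest of the loading code can then use the resolved dtype unconditionally, with no per-device `if`/`else`.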
Why are these changes needed?
Related issue number (if applicable)
#2295
Checks
I've run `format.sh` to lint the changes in this PR.