Enabled configurable auto Tensor Parallelism (TP) for the inference of diverse models #6553

Open · wants to merge 51 commits into base: master from configurable_autoTP

Commits (51)
7c90529
Support pure meta model lm_head tp (#6812)
Yejing-Lai Jan 10, 2025
8c45e8c
Remove op compilation flags due to perf issue (#6944)
NirSonnenschein Jan 13, 2025
049ed93
Pin nv-a6000 workflow (#6938)
loadams Jan 13, 2025
dee5ca0
[inf] Add config var to enable keeping module on host (#6846)
oelayan7 Jan 15, 2025
251e324
`warn` to `warning` (#6952)
qgallouedec Jan 15, 2025
66b08c7
Add extra_repr to Linear classes for debugging purpose (#6954)
Xia-Weiwen Jan 16, 2025
a7e5290
Update import for torchvision.transformers (#6958)
loadams Jan 17, 2025
2090fa2
Remove Duplicate Declaration of pandas in `Dockerfile` (#6959)
Zerohertz Jan 17, 2025
7868339
Enabled configurable autoTP to run out-of-box and remain compatible w…
gyou2021 Jan 20, 2025
02303c9
Enabled Qwen2-MoE Tensor Parallelism (TP) inference
gyou2021 Sep 18, 2024
2bfce64
Enabled configurable auto Tensor Parallelism (TP) for inference of di…
gyou2021 Sep 18, 2024
f200d1e
Added input examples and fixed bugs when input is None.
gyou2021 Jan 20, 2025
9651150
Added the explanation of DS_REMOVED_COMMON_REDUCE_LINEAR_KEYS
gyou2021 Jan 20, 2025
90758a9
Fixed error names
gyou2021 Jan 21, 2025
5c6eaa4
Add the missing view operations from sequence parallel(async). (#6750)
inkcherry Jan 21, 2025
9fead9f
Update `torch.norm` to `torch.linalg.norm` and `torch.linalg.vector_n…
loadams Jan 21, 2025
989414c
Using explicit GPU upcast for ZeRO-Offload (#6962)
xylian86 Jan 21, 2025
f16f83e
Update version.txt after 0.16.3 release (#6965)
loadams Jan 21, 2025
f2b4357
Precisely track nvme optimizer offload (#6963)
tjruwase Jan 23, 2025
f48565d
Update build_win.bat script to exclude GDS op as it lacks Windows supp…
loadams Jan 24, 2025
e1c5c4d
Add CUDA 12.8 support and comment on CUDA 12.7 (#6975)
loadams Jan 28, 2025
72a8c46
Update torch versions to support 2.6 (#6977)
loadams Jan 29, 2025
826772a
generalize deepspeed linear and implement it for non cuda systems (#6…
oelayan7 Jan 29, 2025
0d00669
Update recommended Windows whl building versions (#6983)
loadams Jan 30, 2025
d1c8d9d
Title: Fix setup_env_ranks to Properly Set Environment Variables Inst…
fabiosanger Jan 30, 2025
5b38e34
Specify torchvision in nv-ds-chat workflow (prevents errors with torc…
loadams Jan 30, 2025
be98cf7
Remove assumption that padding only occurs on last rank (#6974)
xylian86 Jan 31, 2025
c8a1664
Use ds-specific module id to avoid conflicts (#6847)
tjruwase Jan 31, 2025
5ebe271
Update A6000 workflows to use newer docker container - 24.09 vs 24.03…
loadams Jan 31, 2025
b8ba88b
Allow NVIDIA Blackwell (#6991)
fabiendupont Feb 4, 2025
7b24b47
Update GH org references (#6998)
tjruwase Feb 5, 2025
68d924e
Update CNAME
loadams Feb 5, 2025
1c08a94
Update CNAME
loadams Feb 5, 2025
a531625
[XPU] max1100 workflow update for docker and softwares (#7003)
Liangliang-Ma Feb 5, 2025
f583d30
autotp training(fix dco) (#7004)
inkcherry Feb 5, 2025
a493b22
import triton files when triton is supported and installed (#6989)
oelayan7 Feb 6, 2025
8a36b54
Update A6000 tests transformers version (#7016)
loadams Feb 8, 2025
0f8687f
Fix ds-chat CI regression (#7015)
tjruwase Feb 10, 2025
b3b7e79
[Ulysses tutorial] typos (#7024)
stas00 Feb 11, 2025
4d6b2ab
fix hostname -I for macOS #6497 (#6990)
fitzjalen Feb 12, 2025
f7e6f9b
Update workflows to cuda 12.4 (#7000)
loadams Feb 12, 2025
2ce885b
[ROCm] Enable fp_quantizer on ROCm (#7027)
rraminen Feb 13, 2025
40eca62
add gds chinese blog (#7034)
GuanhuaWang Feb 13, 2025
0a05120
Add chinese blog for deepspeed windows, and fix format (#7035)
hwchen2017 Feb 14, 2025
dbb3b09
AIO on ROCM (#7023)
jomayeri Feb 14, 2025
d32af71
Merge branch 'master' into configurable_autoTP
gyou2021 Feb 18, 2025
9bac81a
Merge branch 'master' into configurable_autoTP
delock Feb 19, 2025
9bea3f9
Merge branch 'master' into configurable_autoTP
gyou2021 Feb 20, 2025
d2214d0
Merge branch 'master' into configurable_autoTP
loadams Feb 20, 2025
3584c77
Merge branch 'master' into configurable_autoTP
gyou2021 Feb 21, 2025
b2adb33
Merge branch 'master' into configurable_autoTP
gyou2021 Feb 28, 2025
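The diff below makes AutoTP configurable through three environment variables (DS_ALL_REDUCE_LINEAR_ITEMS, DS_REMOVED_COMMON_REDUCE_LINEAR_KEYS, and DS_KEEP_LINEAR_ITEMS), so new models can be covered without editing the parser. As a minimal usage sketch (not part of this PR's diff; the Mixtral checkpoint, layer names, and tp_size below are illustrative assumptions), the variables are set before initializing inference:

import os

# Shard 'w2' as LinearAllreduce in modules whose class path contains 'mixtral'.
os.environ["DS_ALL_REDUCE_LINEAR_ITEMS"] = "{'w2':'mixtral'}"
# Drop 'o_proj' from the default common LinearAllreduce keys.
os.environ["DS_REMOVED_COMMON_REDUCE_LINEAR_KEYS"] = "['o_proj']"
# Keep the MoE gate as a plain, unsharded Linear.
os.environ["DS_KEEP_LINEAR_ITEMS"] = "{'gate':'mixtral'}"

import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.float16)
model = deepspeed.init_inference(model,
                                 tensor_parallel={"tp_size": 2},
                                 dtype=torch.float16,
                                 replace_with_kernel_inject=False)  # AutoTP path

Because the parser compares each configured keyword against str(type(module)), a lowercase keyword such as 'mixtral' matches the module class path (e.g. transformers.models.mixtral.modeling_mixtral).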
117 changes: 93 additions & 24 deletions deepspeed/module_inject/auto_tp.py
@@ -15,6 +15,9 @@
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list

import os
import ast
from deepspeed.utils import groups
from deepspeed.module_inject.layers import is_autotp_training_mode

@@ -282,6 +285,7 @@ def kernel_supported(module_list):
return True
return False

## TP parser based on the AutoTP config in environment variables
def tp_parser(model):
policy_list = []
module_list = []
@@ -292,39 +296,70 @@ def tp_parser(model):
assert AutoTP.supported(model), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
if AutoTP.kernel_supported(module_list) else "AutoTP not supported for model. Please provide policy."
norm_layer_name_list = ['LayerNorm', 'layer_norm', 'ln_1', 'ln_2']

default_ds_common_reduceLinear_keys = ['out_proj', 'o_proj', 'down_proj']
# Different model names for the same key are concatenated with commas.
predefined_ds_common_reduceLinear_items = {
'attention.dense': 'GPTNeoX',
'self_attention.dense': 'falcon,ChatGLM,Phi',
'w2': 'Mixtral',
'dense_4h_to_h': 'ChatGLM'
}
ds_reduceLinear_items = predefined_ds_common_reduceLinear_items
#'DS_ALL_REDUCE_LINEAR_ITEMS' is a dictionary whose keys are layer names of LinearAllreduce and
#whose values are keywords in the module name.
# If the same layer name is LinearAllreduce in multiple models, concatenate the keywords of the
# different module names with commas:
# import os
# os.environ["DS_ALL_REDUCE_LINEAR_ITEMS"] = "{'layer_name_1':'model_1 keyword','layer_name_2':'model_1 keyword',...}"
# os.environ["DS_ALL_REDUCE_LINEAR_ITEMS"] = "{'layer_name_1':'model-1 keyword,model-2 keyword,...',
#                                              'layer_name_2':'model-1 keyword,model-2 keyword,...',...}"
# For example: os.environ["DS_ALL_REDUCE_LINEAR_ITEMS"] = "{'w2':'mixtral'}"
ds_user_reduceLinear_items = os.environ.get('DS_ALL_REDUCE_LINEAR_ITEMS')
if ds_user_reduceLinear_items:
ds_user_reduceLinear_items = ast.literal_eval(ds_user_reduceLinear_items)
ds_reduceLinear_items.update(ds_user_reduceLinear_items)

ds_reduceLinear_keys = ds_reduceLinear_items.keys()

#'DS_REMOVED_COMMON_REDUCE_LINEAR_KEYS' is a list. Layer names in this list are removed from the
# default common LinearAllreduce keys above.
# import os
# os.environ["DS_REMOVED_COMMON_REDUCE_LINEAR_KEYS"] = "['layer_name_1', 'layer_name_2',...]"
# For example: os.environ["DS_REMOVED_COMMON_REDUCE_LINEAR_KEYS"] = "['o_proj']"
ds_user_remove_reduceLinear_keys = os.environ.get('DS_REMOVED_COMMON_REDUCE_LINEAR_KEYS')
if ds_user_remove_reduceLinear_keys:
ds_user_remove_reduceLinear_keys = ast.literal_eval(ds_user_remove_reduceLinear_keys)
ds_common_reduceLinear_keys = [
item for item in default_ds_common_reduceLinear_keys if item not in ds_user_remove_reduceLinear_keys
]
else:
ds_common_reduceLinear_keys = default_ds_common_reduceLinear_keys

# ln_1, ln_2 for Qwen
for module in module_list:
for key, submodule in module._modules.items():
if isinstance(submodule, nn.Linear):
layer_list = layer_list + ["." + key]
elif isinstance(submodule, nn.LayerNorm) or key in norm_layer_name_list:
elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
layer_list = layer_list + ["ln"]
else:
layer_list = layer_list + AutoTP.get_layers(key, submodule)
module_name = str(type(module))

for i, layer in enumerate(layer_list):
if layer == 'ln':
if layer_list[i - 1] != 'ln':
gem_list = gem_list + [layer_list[i - 1]]
elif 'out_proj' in layer:
gem_list = gem_list + [layer]
elif 'o_proj' in layer:
gem_list = gem_list + [layer]
elif 'down_proj' in layer:
gem_list = gem_list + [layer]
elif 'attention.dense' in layer and 'GPTNeoX' in str(model):
gem_list = gem_list + [layer]
elif 'self_attention.dense' in layer and 'falcon' in str(
type(module)): # this is a hack to get the right linear layer for this model!
gem_list = gem_list + [layer]
# Mixtral-7x8b used w2*act(w1*w3) linear. need to replace w2 to linearallreduce.
elif 'w2' in layer and 'Mixtral' in str(type(module)):
gem_list = gem_list + [layer]
elif 'self_attn.dense' in layer and 'Phi' in str(type(module)):
gem_list = gem_list + [layer]
elif 'self_attention.dense' in layer and 'ChatGLM' in str(model):
gem_list = gem_list + [layer]
elif 'dense_4h_to_h' in layer and 'ChatGLM' in str(model):

if any((key in layer) for key in ds_common_reduceLinear_keys):
gem_list = gem_list + [layer]
continue

for key in ds_reduceLinear_keys:
if key in layer:
values = ds_reduceLinear_items[key].split(',')
if any((v in module_name) for v in values):
gem_list = gem_list + [layer]
break

layer_list = []
if gem_list != []:
Expand Down Expand Up @@ -357,9 +392,38 @@ def _replace(self, child, name, conv_linear_layer):
weight_shape = child.weight.shape
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
# For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
if "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate" or (
('mlp.shared_expert_gate' == name or 'mlp.gate' == name) and 'qwen2_moe' in str(type(self.module))):
return child
predefined_keep_Linear_Items = {
'q_a_proj': 'DeepseekV2',
'kv_a_proj_with_mqa': 'DeepseekV2',
'block_sparse_moe.gate': 'DeepseekV2',
'mlp.shared_expert_gate': 'qwen2_moe',
'mlp.gate': 'qwen2_moe'
}
keep_Linear_Items = predefined_keep_Linear_Items
#'DS_KEEP_LINEAR_ITEMS' is a dictionary whose keys are layer names of Linear layers to keep unsharded and
#whose values are keywords in the module name.
# If the same layer name stays Linear in multiple models, concatenate the keywords of the
# different module names with commas.
# import os
# One keyword for one model:
# os.environ["DS_KEEP_LINEAR_ITEMS"] = "{'layer_name_1':'model_1 keyword','layer_name_2':'model_1 keyword',...}"
# One keyword list for multiple models:
# os.environ["DS_KEEP_LINEAR_ITEMS"] = "{'layer_name_1':'model_1 keyword,model_2 keyword,...',
#                                        'layer_name_2':'model_1 keyword,model_2 keyword,...',...}"
# For example:
# os.environ["DS_KEEP_LINEAR_ITEMS"] = "{'gate':'mixtral'}"
user_keep_Linear_Items = os.environ.get('DS_KEEP_LINEAR_ITEMS')
if user_keep_Linear_Items:
user_keep_Linear_Items = ast.literal_eval(user_keep_Linear_Items)
keep_Linear_Items.update(user_keep_Linear_Items)

keys = keep_Linear_Items.keys()
for item in keys:
if item in name:
values = keep_Linear_Items[item]
values = values.split(',')  # different model names for the same key are concatenated with commas
if any((v in str(type(self.module))) for v in values):
return child

# For Yuan model
if 'Yuan' in str(self.module):
if 'v_proj' in name:
@@ -466,7 +530,12 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
if len(child._buffers) != 0 and self.state_dict is not None:
Loading.load_buffer(child, self.state_dict, checking_key)
if child.__class__ in self.linear_policies:
setattr(r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
# Read the user keep-Linear config; default to an empty dict so replacement proceeds when unset.
keepLinearItems = os.environ.get('DS_KEEP_LINEAR_ITEMS', '{}')
keepLinearItems = ast.literal_eval(keepLinearItems)

# Replace the child only when none of the keep-Linear keys appears in its checking_key.
if all(item not in checking_key for item in keepLinearItems):
setattr(
r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
self.conv_linear_layer))
elif any(isinstance(child, lp) for lp in self.linear_policies):
# Added for falcon model support
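DS_ALL_REDUCE_LINEAR_ITEMS and DS_KEEP_LINEAR_ITEMS apply the same matching rule: a Linear layer is selected when its name contains a configured key and str(type(module)) contains one of that key's comma-separated model keywords. The following self-contained sketch mirrors that rule; matches_config is a hypothetical helper written for illustration, not code from this PR:

import ast
import os


def matches_config(env_var: str, layer_name: str, module_class_str: str) -> bool:
    """Return True if layer_name contains a configured key and one of the key's
    comma-separated model keywords is a substring of module_class_str."""
    raw = os.environ.get(env_var)
    if not raw:
        return False
    items = ast.literal_eval(raw)  # e.g. "{'w2':'mixtral,phimoe'}" -> dict (keywords are assumed examples)
    for key, keywords in items.items():
        if key in layer_name and any(kw in module_class_str for kw in keywords.split(',')):
            return True
    return False


# With DS_ALL_REDUCE_LINEAR_ITEMS="{'w2':'mixtral'}", a 'mlp.w2' layer inside a
# Mixtral decoder layer matches and would be replaced with LinearAllreduce.
os.environ["DS_ALL_REDUCE_LINEAR_ITEMS"] = "{'w2':'mixtral'}"
print(matches_config("DS_ALL_REDUCE_LINEAR_ITEMS",
                     "mlp.w2",
                     "<class 'transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer'>"))  # True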