From f4d937a2feff16d7b770ecb9296aef812a72f5ce Mon Sep 17 00:00:00 2001 From: "Lai, Yejing" Date: Thu, 28 Nov 2024 05:34:19 -0800 Subject: [PATCH 1/3] add lm_head tp when checkpoint is none --- deepspeed/module_inject/replace_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 7afe6ca903fb..eabb873c4bd4 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -386,7 +386,6 @@ def conv2d_parallel_shard_weights(model, rank, world_size): checkpoint=checkpoint_file) pbar.update(1) gc.collect() - replaced_module = set_lm_head(replaced_module) # conv2d tp module replace # Now is for yuan model. Add model list and conv policy to decide whether to replace conv. if 'Yuan' in str(replaced_module): @@ -396,6 +395,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size): orig_class=orig_layer_impl, replace_fn=replace_fn, _replace_policy=config.injection_policy_tuple) + replaced_module = set_lm_head(replaced_module) quantizer = GroupQuantizer(q_int8=quantize) world_size = dist.get_world_size() if dist.is_initialized() else 1 From 2b5e7393b65434fa02133c8c77beba9cd7dd7435 Mon Sep 17 00:00:00 2001 From: "Lai, Yejing" Date: Sun, 1 Dec 2024 19:41:03 -0800 Subject: [PATCH 2/3] add lm_head meta replace --- deepspeed/module_inject/replace_module.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index eabb873c4bd4..8b2944fdaf53 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -339,13 +339,11 @@ def set_lm_head(module): module.lm_head, "weight") and module.lm_head.weight.is_meta: module.lm_head.weight = embedding_weight # enable tensor parallel for the last linear - if hasattr(module, "lm_head") and hasattr(module.lm_head, - "weight") and not module.lm_head.weight.is_meta and isinstance( - module.lm_head, torch.nn.Linear): + if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and isinstance( + module.lm_head, torch.nn.Linear): module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head") - elif hasattr(module, "embed_out") and hasattr(module.embed_out, - "weight") and not module.embed_out.weight.is_meta and isinstance( - module.embed_out, torch.nn.Linear): + elif hasattr(module, "embed_out") and hasattr(module.embed_out, "weight") and isinstance( + module.embed_out, torch.nn.Linear): module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out") elif hasattr(module, "language_model") and hasattr(module.language_model, "lm_head"): module = replace_wo_policy(module.language_model, ("lm_head", ), 0, "lm_head") From aec4e3f89d905d9fb8027f8e04b3b8d55993f272 Mon Sep 17 00:00:00 2001 From: "Lai, Yejing" Date: Tue, 7 Jan 2025 20:55:40 -0800 Subject: [PATCH 3/3] Update AutoTP default set lm_head tp --- deepspeed/module_inject/replace_module.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 67ff769af80d..00b22aac81d8 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -396,7 +396,9 @@ def conv2d_parallel_shard_weights(model, rank, world_size): orig_class=orig_layer_impl, replace_fn=replace_fn, _replace_policy=config.injection_policy_tuple) - replaced_module = set_lm_head(replaced_module) + # AutoTP default set lm_head tp + if not config.replace_with_kernel_inject: + replaced_module = set_lm_head(replaced_module) quantizer = GroupQuantizer(q_int8=quantize) world_size = dist.get_world_size() if dist.is_initialized() else 1