Avoid naming collision on partition()
tjruwase committed Feb 7, 2025
1 parent 4a1dd0f commit 0ac4457
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions deepspeed/module_inject/layers.py
@@ -177,7 +177,7 @@ def gather_params(self, params_list):
pass

@abstractmethod
-def partition(self, params_list: List[torch.Tensor]):
+def _tp_partition(self, params_list: List[torch.Tensor]):
"""
Partitions the parameters for tensor parallelism.
It is necessary to ensure that this function only involves the logic of params partitioning.
@@ -205,7 +205,7 @@ def config_tp_params(self, weight):
setattr(weight, DS_TENSOR_MODEL_PARALLEL, True)
setattr(weight, DS_IS_REPLACED_MODULE, True)
weight.gather_params = self.gather_params
-weight.partition = self.partition
+weight._tp_partition = self._tp_partition

def is_training_mode(self):
global DEEPSPEED_AUTOTP_MODE
@@ -294,7 +294,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
"""
#TODO : Check whether there are any missing attributes.
if self.enabled:
-self.params[0].partition(self.params)
+self.params[0]._tp_partition(self.params)


class LinearAllreduce(TensorParallel_Layer):
@@ -304,7 +304,7 @@ def __init__(self, module, mp_group, **kwargs):
self.weight = module.weight
self.bias = module.bias

-self.partition([self.weight, self.bias])
+self._tp_partition([self.weight, self.bias])
self.support_training = True
self.config_tp_params(self.weight)
if self.bias is not None:
@@ -335,7 +335,7 @@ def gather_params(self, params_list):
return

@torch.no_grad()
-def partition(self, params_list):
+def _tp_partition(self, params_list):

if not self.is_training_mode():
self.uneven_partition(params_list)
@@ -374,14 +374,14 @@ def __init__(self, module, mp_group=None, skip_partition=False, **kwargs):
self.weight = module.weight
self.bias = module.bias
if not skip_partition:
-self.partition([self.weight, self.bias])
+self._tp_partition([self.weight, self.bias])
self.support_training = True
self.config_tp_params(self.weight)
if self.bias is not None:
self.config_tp_params(self.bias)

def forward(self, input):
-if self.mp_group is not None:
+if getattr(self, 'mp_group', None) is not None:
input = ColumnParallel.apply(self.mp_group, input)
output = torch.matmul(input, self.weight.transpose(-1, -2))
if self.bias is not None:
@@ -402,7 +402,7 @@ def gather_params(self, params_list):
params_list[idx].data = output_param.contiguous()

@torch.no_grad()
-def partition(self, params_list):
+def _tp_partition(self, params_list):

if not self.is_training_mode():
self.uneven_partition(params_list)
@@ -467,7 +467,7 @@ def __init__(self, module, mp_group, skip_partition=False, **kwargs):
super().__init__(module, mp_group, skip_partition, **kwargs)

@torch.no_grad()
-def partition(self, params_list):
+def _tp_partition(self, params_list):
for idx, param in enumerate(params_list):
if param is None:
return
@@ -482,7 +482,7 @@ def partition(self, params_list):
class conv_LinearLayer(LinearLayer):

@torch.no_grad()
-def partition(self, params_list):
+def _tp_partition(self, params_list):
weight = None
bias = None
if len(params_list) == 1:
@@ -507,7 +507,7 @@ class Yuan_LinearAllreduce(LinearAllreduce):

#Yuan2
@torch.no_grad()
-def partition(self, params_list):
+def _tp_partition(self, params_list):
weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index,
self.tp_world_size, False)
params_list[0].data = weight
@@ -518,7 +518,7 @@ def partition(self, params_list):
class Yuan_LinearLayer(LinearLayer):
#Yuan2
@torch.no_grad()
-def partition(self, params_list):
+def _tp_partition(self, params_list):
weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index,
self.tp_world_size, True)
params_list[0].data = move(weight, get_accelerator().current_device_name()).detach()
@@ -529,7 +529,7 @@ def partition(self, params_list):
class GateUpPack_LinearLayer(LinearLayer):
# chatGLM2, chatGLM2
@torch.no_grad()
-def partition(self, params_list):
+def _tp_partition(self, params_list):
weight, bias = shard_chunk_mlp(params_list[0].data, params_list[1], self.tp_index, self.tp_world_size)
params_list[0].data = move(weight, device=get_accelerator().current_device_name()).detach()
if bias is not None:
@@ -539,7 +539,7 @@ def partition(self, params_list):
class Conv_LinearALlreduce(LinearAllreduce):

@torch.no_grad()
-def partition(self, params_list):
+def _tp_partition(self, params_list):
for idx, param in enumerate(params_list):
if param is None:
return
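For orientation, below is a minimal sketch of the pattern this diff renames: the layer attaches its tensor-parallel partition hook onto the parameter itself under the private name _tp_partition, so the attribute no longer collides with any existing partition name. The _SketchTPLayer class and the row-wise torch.chunk split are illustrative assumptions, not DeepSpeed's actual classes or sharding logic.

    # Hypothetical sketch, not DeepSpeed's implementation.
    import torch

    class _SketchTPLayer:

        def __init__(self, weight, tp_index=0, tp_world_size=1):
            self.weight = weight
            self.tp_index = tp_index
            self.tp_world_size = tp_world_size
            # Counterpart of config_tp_params in the diff: expose the hook on the
            # tensor itself (previously `weight.partition = self.partition`).
            weight._tp_partition = self._tp_partition

        @torch.no_grad()
        def _tp_partition(self, params_list):
            # Illustrative row-wise split; each layer type shards differently in practice.
            for param in params_list:
                if param is None:
                    continue
                shard = torch.chunk(param.data, self.tp_world_size, dim=0)[self.tp_index]
                param.data = shard.contiguous()

    # The hook is then reachable through the parameter, as the diff now does with
    # `self.params[0]._tp_partition(self.params)` in the __exit__ hunk above.
    w = torch.nn.Parameter(torch.randn(8, 4))
    _SketchTPLayer(w, tp_index=0, tp_world_size=2)
    w._tp_partition([w])
    print(w.shape)  # torch.Size([4, 4]) on this rank

The one change in the commit that is not a rename is in LinearLayer.forward, where `self.mp_group` becomes `getattr(self, 'mp_group', None)`: the attribute lookup now tolerates instances on which mp_group was never assigned instead of raising AttributeError.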
