Merge pull request #801 from zhaoyinglia/AutoParallel/add_fuse_qkv
[AutoParallel] add fuse_attn_qkv for gpt
haohongxiang authored Sep 26, 2022
2 parents 8abca3e + 5593277 commit cbd50b6
Showing 8 changed files with 19 additions and 6 deletions.
1 change: 1 addition & 0 deletions ppfleetx/configs/nlp/gpt/auto/pretrain_gpt_1.3B_dp8.yaml
@@ -18,6 +18,7 @@ Model:
 type_vocab_size: 16
 initializer_range: 0.02
 use_recompute: True
+fuse_attn_qkv: True


 Distributed:
@@ -18,6 +18,7 @@ Model:
 type_vocab_size: 16
 initializer_range: 0.02
 use_recompute: True
+fuse_attn_qkv: True


 Distributed:
@@ -18,6 +18,7 @@ Model:
 type_vocab_size: 16
 initializer_range: 0.02
 use_recompute: False
+fuse_attn_qkv: True


 Distributed:
@@ -18,6 +18,7 @@ Model:
 type_vocab_size: 16
 initializer_range: 0.02
 use_recompute: True
+fuse_attn_qkv: True


 Distributed:
1 change: 1 addition & 0 deletions ppfleetx/configs/nlp/gpt/auto/pretrain_gpt_base.yaml
@@ -27,6 +27,7 @@ Engine:
 Model:
 module: "GPTModuleAuto"
 name: "GPT"
+fuse_attn_qkv: False


 Data:
17 changes: 12 additions & 5 deletions ppfleetx/models/language_model/gpt/auto/auto_model.py
@@ -47,7 +47,7 @@ def __init__(self,
 need_weights=False,
 weight_attr=None,
 bias_attr=None,
-fuse=False,
+fuse_attn_qkv=False,
 mesh=None,
 mesh_idx=None):
 super(MultiHeadAttention, self).__init__()
@@ -57,14 +57,14 @@ def __init__(self,
 self.num_heads = num_heads
 self.dropout = dropout
 self.need_weights = need_weights
-self.fuse = fuse
+self.fuse_attn_qkv = fuse_attn_qkv
 self.mesh = mesh
 self.mesh_idx = mesh_idx

 self.head_dim = embed_dim // num_heads
 assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

-if self.fuse:
+if self.fuse_attn_qkv:
 assert self.kdim == embed_dim
 assert self.vdim == embed_dim
 self.qkv_proj = nn.Linear(
@@ -80,6 +80,9 @@ def __init__(self,
 embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)

 def _fuse_prepare_qkv(self, query, use_cache=False, cache=None):
+auto.shard_tensor(self.qkv_proj.weight, self.mesh[self.mesh_idx],
+                  [None, self.mesh.mp])
+
 mix_layer = self.qkv_proj(query)
 mix_layer = paddle.reshape_(mix_layer,
 [0, 0, self.num_heads, 3 * self.head_dim])
@@ -220,13 +223,13 @@ def forward(self,
 value = query if value is None else value
 # compute q ,k ,v
 if use_cache is False:
-if self.fuse:
+if self.fuse_attn_qkv:
 q, k, v = self._fuse_prepare_qkv(query, use_cache, cache)
 else:
 q, k, v = self._prepare_qkv(query, key, value, use_cache,
 cache)
 else:
-if self.fuse:
+if self.fuse_attn_qkv:
 q, k, v, cache = self._fuse_prepare_qkv(query, use_cache,
 cache)
 else:
@@ -342,6 +345,7 @@ def __init__(self,
 normalize_before=True,
 weight_attr=None,
 bias_attr=None,
+fuse_attn_qkv=False,
 mesh=None,
 mesh_idx=None):
 self._config = locals()
@@ -364,6 +368,7 @@ def __init__(self,
 dropout=attn_dropout,
 weight_attr=weight_attrs[0],
 bias_attr=bias_attrs[0],
+fuse_attn_qkv=fuse_attn_qkv,
 mesh=mesh,
 mesh_idx=mesh_idx)

@@ -483,6 +488,7 @@ def __init__(self,
 max_position_embeddings=512,
 type_vocab_size=16,
 initializer_range=0.02,
+fuse_attn_qkv=False,
 mesh=None):

 super(GPTModelAuto, self).__init__()
@@ -518,6 +524,7 @@ def __init__(self,
 initializer=nn.initializer.Normal(
 mean=0.0, std=self.initializer_range)),
 bias_attr=None,
+fuse_attn_qkv=fuse_attn_qkv,
 mesh=self.mesh,
 mesh_idx=stages[i]))

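For readers new to this optimization: the fused path replaces the three separate q/k/v projections with a single `qkv_proj` of output width `3 * embed_dim`, so one larger matmul runs in place of three smaller ones; the result is then reshaped per head and split back into q, k, and v. A minimal standalone sketch of the idea, assuming Paddle 2.x (names and shapes are illustrative, not the PR's exact code):

```python
# Sketch of a fused q/k/v projection (illustrative, assumes Paddle 2.x).
import paddle
import paddle.nn as nn

batch, seq_len, embed_dim, num_heads = 2, 8, 64, 4
head_dim = embed_dim // num_heads

# One Linear of width 3 * embed_dim replaces three embed_dim -> embed_dim
# projections, i.e. a single larger GEMM instead of three kernel launches.
qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)

x = paddle.randn([batch, seq_len, embed_dim])
mix = qkv_proj(x)                                # [batch, seq, 3 * embed_dim]
# A 0 in the target shape keeps the corresponding input dimension.
mix = paddle.reshape(mix, [0, 0, num_heads, 3 * head_dim])
mix = paddle.transpose(mix, [0, 2, 1, 3])        # [batch, heads, seq, 3 * head_dim]
q, k, v = paddle.split(mix, 3, axis=-1)          # each [batch, heads, seq, head_dim]
print(q.shape, k.shape, v.shape)
```

The `auto.shard_tensor` call added in `_fuse_prepare_qkv` above annotates the fused weight so its output columns are partitioned along the model-parallel mesh axis, which keeps the single fused matmul compatible with tensor parallelism.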
1 change: 0 additions & 1 deletion projects/gpt/auto_gpt_345M_single_card.sh
@@ -1,4 +1,3 @@
-
 #! /bin/bash

 # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2 changes: 2 additions & 0 deletions projects/gpt/docs/auto_parallel.md
@@ -88,6 +88,7 @@ The Engine settings configure the parameters used during model training/validation/inference,
 type_vocab_size: 16
 initializer_range: 0.02
 use_recompute: True
+fuse_attn_qkv: True
 ```

 The parameters are explained as follows:
@@ -105,6 +106,7 @@ The Engine settings configure the parameters used during model training/validation/inference,
 | type_vocab_size | size of the token-type vocabulary |
 | initializer_range | range for parameter initialization |
 | use_recompute | whether to train with recompute (recomputing all transformer layers) |
+| fuse_attn_qkv | whether to fuse the attention layer's q/k/v computation into one operation, replacing the traditional separate Linear layers to speed up training |


 ### Dataset
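Since the fused weight is just the three projection matrices packed side by side, fusion changes the kernel count, not the math. A quick sanity sketch under the same Paddle 2.x assumption (hypothetical helper code; relies on `nn.Linear` storing `weight` as `[in_features, out_features]`):

```python
# Check that a fused projection built from concatenated q/k/v weights
# reproduces the three separate projections (illustrative, assumes Paddle 2.x).
import paddle
import paddle.nn as nn

embed_dim = 16
q_proj = nn.Linear(embed_dim, embed_dim)
k_proj = nn.Linear(embed_dim, embed_dim)
v_proj = nn.Linear(embed_dim, embed_dim)

fused = nn.Linear(embed_dim, 3 * embed_dim)
# Paddle's Linear weight is [in_features, out_features], so concatenating
# along axis 1 stacks the q, k and v output blocks side by side.
fused.weight.set_value(
    paddle.concat([q_proj.weight, k_proj.weight, v_proj.weight], axis=1))
fused.bias.set_value(
    paddle.concat([q_proj.bias, k_proj.bias, v_proj.bias], axis=0))

x = paddle.randn([4, embed_dim])
q_out, k_out, v_out = paddle.split(fused(x), 3, axis=-1)
print(bool(paddle.allclose(q_out, q_proj(x))),
      bool(paddle.allclose(k_out, k_proj(x))),
      bool(paddle.allclose(v_out, v_proj(x))))  # expect: True True True
```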
