[wespeaker] add automatic mixed precision training #103

Merged: 1 commit, Nov 22, 2022
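For context: this PR wires PyTorch's native automatic mixed precision (`torch.cuda.amp`) into the wespeaker training loop. Under `autocast`, matmul-heavy ops run in float16 for speed and memory savings, while `GradScaler` multiplies the loss by a dynamic factor so small fp16 gradients do not underflow, then unscales before the optimizer step. A minimal self-contained sketch of the pattern being adopted (plain PyTorch API, not wespeaker code):

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(80, 192).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()
enable_amp = True  # mirrors the new `enable_amp` config flag below

features = torch.randn(8, 80, device='cuda')
targets = torch.randint(0, 192, (8,), device='cuda')

with torch.cuda.amp.autocast(enabled=enable_amp):
    logits = model(features)                 # matmuls run in fp16 when enabled
    loss = F.cross_entropy(logits, targets)

optimizer.zero_grad()
scaler.scale(loss).backward()  # scale the loss so fp16 grads don't underflow
scaler.step(optimizer)         # unscales grads and skips the step on inf/nan
scaler.update()                # adapts the scale factor for the next step
```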
1 change: 1 addition & 0 deletions examples/cnceleb/v2/conf/ecapa_tdnn.yaml
@@ -3,6 +3,7 @@
exp_dir: exp/ECAPA_TDNN_GLOB_c512-ASTP-emb192-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
gpus: "[0,1]"
num_avg: 10
+enable_amp: False # whether to enable automatic mixed precision training

seed: 42
num_epochs: 150
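Each recipe config below gains the same top-level flag, defaulting to `False` so existing recipes train exactly as before. A hypothetical illustration of how the flag travels from YAML into the code (the real plumbing is the `train.py` and `executor.py` changes further down):

```python
# Hypothetical snippet: wespeaker's train.py parses the recipe YAML
# into a `configs` dict and forwards configs['enable_amp'] to autocast.
import yaml

with open('examples/cnceleb/v2/conf/ecapa_tdnn.yaml') as f:
    configs = yaml.safe_load(f)

print(configs['enable_amp'])  # False by default, so AMP is opt-in per recipe
```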
1 change: 1 addition & 0 deletions examples/cnceleb/v2/conf/ecapa_tdnn_lm.yaml
@@ -8,6 +8,7 @@
exp_dir: exp/ECAPA_TDNN_GLOB_c512-ASTP-emb192-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150-LM
gpus: "[0,1]"
num_avg: 1
+enable_amp: False # whether to enable automatic mixed precision training
do_lm: True

seed: 42
7 changes: 4 additions & 3 deletions examples/cnceleb/v2/conf/repvgg.yaml
@@ -3,6 +3,7 @@
exp_dir: exp/RepVGG-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
gpus: "[0,1]"
num_avg: 100
+enable_amp: False # whether to enable automatic mixed precision training

seed: 42
num_epochs: 150
@@ -41,11 +42,11 @@ model_init: null
model_args:
feat_dim: 80
embed_dim: 256 # 512
-pooling_func: 'TSTP'
+pooling_func: "TSTP"
deploy: False
use_se: False
projection_args:
-project_type: 'arc_margin' # add_margin, arc_margin, sphere, softmax
+project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
scale: 32.0
easy_margin: False

@@ -56,7 +57,7 @@ margin_update:
increase_start_epoch: 20
fix_start_epoch: 40
update_margin: True
-increase_type: 'exp' # exp, linear
+increase_type: "exp" # exp, linear

loss: CrossEntropyLoss
loss_args: {}
1 change: 1 addition & 0 deletions examples/cnceleb/v2/conf/resnet.yaml
@@ -3,6 +3,7 @@
exp_dir: exp/ResNet34-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
gpus: "[0,1]"
num_avg: 10
+enable_amp: False # whether to enable automatic mixed precision training

seed: 42
num_epochs: 150
1 change: 1 addition & 0 deletions examples/cnceleb/v2/conf/resnet_lm.yaml
@@ -8,6 +8,7 @@
exp_dir: exp/ResNet34-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150-LM
gpus: "[0,1]"
num_avg: 1
+enable_amp: False # whether to enable automatic mixed precision training
do_lm: True

seed: 42
7 changes: 4 additions & 3 deletions examples/voxceleb/v2/conf/ecapa_tdnn.yaml
@@ -3,6 +3,7 @@
exp_dir: exp/ECAPA_TDNN_GLOB_c512-ASTP-emb192-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
gpus: "[0,1]"
num_avg: 10
+enable_amp: False # whether to enable automatic mixed precision training

seed: 42
num_epochs: 150
@@ -42,9 +43,9 @@ model_init: null
model_args:
feat_dim: 80
embed_dim: 192
-pooling_func: 'ASTP'
+pooling_func: "ASTP"
projection_args:
-project_type: 'arc_margin' # add_margin, arc_margin, sphere, softmax
+project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
scale: 32.0
easy_margin: False

@@ -55,7 +56,7 @@ margin_update:
increase_start_epoch: 20
fix_start_epoch: 40
update_margin: True
-increase_type: 'exp' # exp, linear
+increase_type: "exp" # exp, linear

loss: CrossEntropyLoss
loss_args: {}
7 changes: 4 additions & 3 deletions examples/voxceleb/v2/conf/ecapa_tdnn_lm.yaml
@@ -8,6 +8,7 @@
exp_dir: exp/ECAPA_TDNN_GLOB_c512-ASTP-emb192-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150-LM
gpus: "[0,1]"
num_avg: 1
+enable_amp: False # whether to enable automatic mixed precision training
do_lm: True

seed: 42
@@ -48,9 +49,9 @@ model_init: null
model_args:
feat_dim: 80
embed_dim: 192
-pooling_func: 'ASTP'
+pooling_func: "ASTP"
projection_args:
-project_type: 'arc_margin' # add_margin, arc_margin, sphere, softmax
+project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
scale: 32.0
easy_margin: False

@@ -61,7 +62,7 @@ margin_update:
increase_start_epoch: 1
fix_start_epoch: 1
update_margin: True
-increase_type: 'exp' # exp, linear
+increase_type: "exp" # exp, linear

loss: CrossEntropyLoss
loss_args: {}
9 changes: 5 additions & 4 deletions examples/voxceleb/v2/conf/repvgg.yaml
@@ -3,6 +3,7 @@
exp_dir: exp/RepVGG-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
gpus: "[0,1]"
num_avg: 100
+enable_amp: False # whether to enable automatic mixed precision training

seed: 42
num_epochs: 150
@@ -36,16 +37,16 @@ dataset_args:
max_f: 8
prob: 0.6

-model: REPVGG_TINY_A0 # REPVGG_A0 REPVGG_A1 REPVGG_A2 REPVGG_RSBB_A0
+model: REPVGG_TINY_A0  # REPVGG_A0 REPVGG_A1 REPVGG_A2 REPVGG_RSBB_A0
model_init: null
model_args:
feat_dim: 80
embed_dim: 256 # 512
-pooling_func: 'TSTP'
+pooling_func: "TSTP"
deploy: False
use_se: False
projection_args:
-project_type: 'arc_margin' # add_margin, arc_margin, sphere, softmax
+project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
scale: 32.0
easy_margin: False

@@ -56,7 +57,7 @@ margin_update:
increase_start_epoch: 20
fix_start_epoch: 40
update_margin: True
-increase_type: 'exp' # exp, linear
+increase_type: "exp" # exp, linear

loss: CrossEntropyLoss
loss_args: {}
7 changes: 4 additions & 3 deletions examples/voxceleb/v2/conf/resnet.yaml
@@ -3,6 +3,7 @@
exp_dir: exp/ResNet34-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
gpus: "[0,1]"
num_avg: 10
+enable_amp: False # whether to enable automatic mixed precision training

seed: 42
num_epochs: 150
@@ -42,10 +43,10 @@ model_init: null
model_args:
feat_dim: 80
embed_dim: 256
-pooling_func: 'TSTP'
+pooling_func: "TSTP"
two_emb_layer: False
projection_args:
-project_type: 'arc_margin' # add_margin, arc_margin, sphere, softmax
+project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
scale: 32.0
easy_margin: False

@@ -56,7 +57,7 @@ margin_update:
increase_start_epoch: 20
fix_start_epoch: 40
update_margin: True
-increase_type: 'exp' # exp, linear
+increase_type: "exp" # exp, linear

loss: CrossEntropyLoss
loss_args: {}
7 changes: 4 additions & 3 deletions examples/voxceleb/v2/conf/resnet_lm.yaml
@@ -8,6 +8,7 @@
exp_dir: exp/ResNet34-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150-LM
gpus: "[0,1]"
num_avg: 1
+enable_amp: False # whether to enable automatic mixed precision training
do_lm: True

seed: 42
@@ -48,10 +49,10 @@ model_init: null
model_args:
feat_dim: 80
embed_dim: 256
-pooling_func: 'TSTP'
+pooling_func: "TSTP"
two_emb_layer: False
projection_args:
-project_type: 'arc_margin' # add_margin, arc_margin, sphere, softmax
+project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
scale: 32.0
easy_margin: False

@@ -62,7 +63,7 @@ margin_update:
increase_start_epoch: 1
fix_start_epoch: 1
update_margin: True
-increase_type: 'exp' # exp, linear
+increase_type: "exp" # exp, linear

loss: CrossEntropyLoss
loss_args: {}
7 changes: 4 additions & 3 deletions examples/voxceleb/v2/conf/xvec.yaml
@@ -3,6 +3,7 @@
exp_dir: exp/XVEC-TSTP-emb512-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
gpus: "[0,1]"
num_avg: 10
+enable_amp: False # whether to enable automatic mixed precision training

seed: 42
num_epochs: 150
@@ -42,9 +43,9 @@ model_init: null
model_args:
feat_dim: 80
embed_dim: 512
-pooling_func: 'TSTP'
+pooling_func: "TSTP"
projection_args:
-project_type: 'arc_margin' # add_margin, arc_margin, sphere, softmax
+project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
scale: 32.0
easy_margin: False

@@ -55,7 +56,7 @@ margin_update:
increase_start_epoch: 20
fix_start_epoch: 40
update_margin: True
-increase_type: 'exp' # exp, linear
+increase_type: "exp" # exp, linear

loss: CrossEntropyLoss
loss_args: {}
7 changes: 4 additions & 3 deletions examples/voxceleb/v2/conf/xvec_lm.yaml
@@ -8,6 +8,7 @@
exp_dir: exp/XVEC-TSTP-emb512-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150-LM
gpus: "[0,1]"
num_avg: 1
+enable_amp: False # whether to enable automatic mixed precision training
do_lm: True

seed: 42
@@ -48,9 +49,9 @@ model_init: null
model_args:
feat_dim: 80
embed_dim: 512
-pooling_func: 'TSTP'
+pooling_func: "TSTP"
projection_args:
-project_type: 'arc_margin' # add_margin, arc_margin, sphere, softmax
+project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
scale: 32.0
easy_margin: False

@@ -61,7 +62,7 @@ margin_update:
increase_start_epoch: 1
fix_start_epoch: 1
update_margin: True
-increase_type: 'exp' # exp, linear
+increase_type: "exp" # exp, linear

loss: CrossEntropyLoss
loss_args: {}
3 changes: 3 additions & 0 deletions wespeaker/bin/train.py
@@ -195,6 +195,7 @@ def train(config='conf/config.yaml', **kwargs):
logger.info(line)
dist.barrier() # synchronize here

+scaler = torch.cuda.amp.GradScaler()
for epoch in range(start_epoch, configs['num_epochs'] + 1):
# train_sampler.set_epoch(epoch)
train_dataset.set_epoch(epoch)
@@ -208,6 +209,8 @@ def train(config='conf/config.yaml', **kwargs):
margin_scheduler,
epoch,
logger,
+scaler,
+enable_amp=configs['enable_amp'],
log_batch_interval=configs['log_batch_interval'],
device=device)

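A note on the design: the `GradScaler` is constructed unconditionally, which is harmless when AMP is off (scaling and unscaling cancel out) but does a little extra bookkeeping. `GradScaler` also accepts an `enabled` argument, so a hypothetical variant could tie it to the same flag:

```python
# Hypothetical alternative (not what this PR does): disable the scaler
# together with autocast, making scale()/step()/update() cheap no-ops.
scaler = torch.cuda.amp.GradScaler(enabled=configs['enable_amp'])
```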
4 changes: 2 additions & 2 deletions wespeaker/models/pooling_layers.py
@@ -47,7 +47,7 @@ def __init__(self, **kwargs):

def forward(self, x):
# The last dimension is the temporal axis
-pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
+pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
pooling_std = pooling_std.flatten(start_dim=1)
return pooling_std

@@ -64,7 +64,7 @@ def __init__(self, **kwargs):
def forward(self, x):
# The last dimension is the temporal axis
pooling_mean = x.mean(dim=-1)
-pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
+pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
pooling_mean = pooling_mean.flatten(start_dim=1)
pooling_std = pooling_std.flatten(start_dim=1)

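The epsilon bump from 1e-8 to 1e-7 reads as an AMP safeguard: float16's smallest positive subnormal is roughly 5.96e-8, so 1e-8 rounds to zero if the variance is ever computed in half precision, and the backward of `sqrt` can then divide by zero. A quick check (my illustration, not code from this PR):

```python
import torch

# 1e-8 is below float16's smallest subnormal (~5.96e-8) and rounds to zero;
# 1e-7 survives, keeping the variance strictly positive before the sqrt.
print(torch.tensor(1e-8, dtype=torch.float16))  # tensor(0., dtype=torch.float16)
print(torch.tensor(1e-7, dtype=torch.float16))  # tensor(1.1921e-07, dtype=torch.float16)
```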
3 changes: 3 additions & 0 deletions wespeaker/models/projections.py
@@ -134,6 +134,9 @@ def __init__(self, in_features, out_features, scale=32.0, margin=0.20):
in_features))
nn.init.xavier_uniform_(self.weight)

+def update(self, margin):
+    self.margin = margin

def forward(self, input, label):
# ---------------- cos(theta) & phi(theta) ---------------
cosine = F.linear(F.normalize(input), F.normalize(self.weight))
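The new `update()` method gives the projection head a uniform hook for the `margin_update` schedule configured in the recipes above. A sketch of how a scheduler might drive it (hypothetical scheduler logic, not from this PR):

```python
# Hypothetical margin schedule matching the config fields above: flat at 0
# until increase_start_epoch, ramp to the final margin by fix_start_epoch,
# then hold it fixed; each epoch calls the new update() hook.
def step_margin(projection, epoch, increase_start_epoch=20,
                fix_start_epoch=40, final_margin=0.2):
    if epoch < increase_start_epoch:
        margin = 0.0
    elif epoch < fix_start_epoch:
        ramp = (epoch - increase_start_epoch) / (fix_start_epoch - increase_start_epoch)
        margin = final_margin * ramp  # linear ramp; the configs also offer 'exp'
    else:
        margin = final_margin
    projection.update(margin)
```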
18 changes: 12 additions & 6 deletions wespeaker/utils/executor.py
@@ -29,6 +29,8 @@ def run_epoch(dataloader,
margin_scheduler,
epoch,
logger,
+scaler,
+enable_amp,
log_batch_interval=100,
device=torch.device('cuda')):
model.train()
@@ -54,20 +56,24 @@

features = features.float().to(device) # (B,T,F)
targets = targets.long().to(device)
-outputs = model(features) # (embed_a,embed_b) in most cases
-embeds = outputs[-1] if isinstance(outputs, tuple) else outputs
-outputs = model.module.projection(embeds, targets)
-
-loss = criterion(outputs, targets)
+with torch.cuda.amp.autocast(enabled=enable_amp):
+    outputs = model(features) # (embed_a,embed_b) in most cases
+    embeds = outputs[-1] if isinstance(outputs, tuple) else outputs
+    outputs = model.module.projection(embeds, targets)
+
+    loss = criterion(outputs, targets)

# loss, acc
loss_meter.add(loss.item())
acc_meter.add(outputs.cpu().detach().numpy(),
targets.cpu().numpy())

# update the model
optimizer.zero_grad()
-loss.backward()
-optimizer.step()
+scaler.scale(loss).backward()
+scaler.step(optimizer)
+scaler.update()

# log
if (i + 1) % log_batch_interval == 0:
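One AMP subtlety for anyone extending this loop later: gradients must be unscaled before any direct manipulation such as clipping. The canonical PyTorch pattern, as a self-contained sketch (gradient clipping is not part of this PR):

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 10, device='cuda')
y = torch.randint(0, 2, (4,), device='cuda')

with torch.cuda.amp.autocast():
    loss = F.cross_entropy(model(x), y)

optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # grads become unscaled fp32, so the norm threshold is meaningful
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
scaler.step(optimizer)      # detects the prior unscale_; still skips the step on inf/nan
scaler.update()
```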