From fcbdb6bb69d0b66cf92b1590bcb6ebfe8402dba2 Mon Sep 17 00:00:00 2001
From: czy97
Date: Tue, 22 Nov 2022 16:35:45 +0800
Subject: [PATCH] [wespeaker] add automatic mixed precision training

---
 examples/cnceleb/v2/conf/ecapa_tdnn.yaml     |  1 +
 examples/cnceleb/v2/conf/ecapa_tdnn_lm.yaml  |  1 +
 examples/cnceleb/v2/conf/repvgg.yaml         |  7 ++++---
 examples/cnceleb/v2/conf/resnet.yaml         |  1 +
 examples/cnceleb/v2/conf/resnet_lm.yaml      |  1 +
 examples/voxceleb/v2/conf/ecapa_tdnn.yaml    |  7 ++++---
 examples/voxceleb/v2/conf/ecapa_tdnn_lm.yaml |  7 ++++---
 examples/voxceleb/v2/conf/repvgg.yaml        |  9 +++++----
 examples/voxceleb/v2/conf/resnet.yaml        |  7 ++++---
 examples/voxceleb/v2/conf/resnet_lm.yaml     |  7 ++++---
 examples/voxceleb/v2/conf/xvec.yaml          |  7 ++++---
 examples/voxceleb/v2/conf/xvec_lm.yaml       |  7 ++++---
 wespeaker/bin/train.py                       |  3 +++
 wespeaker/models/pooling_layers.py           |  4 ++--
 wespeaker/models/projections.py              |  3 +++
 wespeaker/utils/executor.py                  | 18 ++++++++++++------
 16 files changed, 57 insertions(+), 33 deletions(-)

diff --git a/examples/cnceleb/v2/conf/ecapa_tdnn.yaml b/examples/cnceleb/v2/conf/ecapa_tdnn.yaml
index c4a10ed3..bbdff5e5 100644
--- a/examples/cnceleb/v2/conf/ecapa_tdnn.yaml
+++ b/examples/cnceleb/v2/conf/ecapa_tdnn.yaml
@@ -3,6 +3,7 @@
 exp_dir: exp/ECAPA_TDNN_GLOB_c512-ASTP-emb192-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
 gpus: "[0,1]"
 num_avg: 10
+enable_amp: False # whether to enable automatic mixed precision training
 
 seed: 42
 num_epochs: 150
diff --git a/examples/cnceleb/v2/conf/ecapa_tdnn_lm.yaml b/examples/cnceleb/v2/conf/ecapa_tdnn_lm.yaml
index f3bffdf1..00396224 100644
--- a/examples/cnceleb/v2/conf/ecapa_tdnn_lm.yaml
+++ b/examples/cnceleb/v2/conf/ecapa_tdnn_lm.yaml
@@ -8,6 +8,7 @@
 exp_dir: exp/ECAPA_TDNN_GLOB_c512-ASTP-emb192-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150-LM
 gpus: "[0,1]"
 num_avg: 1
+enable_amp: False # whether to enable automatic mixed precision training
 do_lm: True
 
 seed: 42
diff --git a/examples/cnceleb/v2/conf/repvgg.yaml b/examples/cnceleb/v2/conf/repvgg.yaml
index bcdf2163..0318bb1b 100644
--- a/examples/cnceleb/v2/conf/repvgg.yaml
+++ b/examples/cnceleb/v2/conf/repvgg.yaml
@@ -3,6 +3,7 @@
 exp_dir: exp/RepVGG-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
 gpus: "[0,1]"
 num_avg: 100
+enable_amp: False # whether to enable automatic mixed precision training
 
 seed: 42
 num_epochs: 150
@@ -41,11 +42,11 @@ model_init: null
 model_args:
   feat_dim: 80
   embed_dim: 256 # 512
-  pooling_func: 'TSTP'
+  pooling_func: "TSTP"
   deploy: False
   use_se: False
 projection_args:
-  project_type: 'arc_margin' # add_margin, arc_margin, sphere, softmax
+  project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
   scale: 32.0
   easy_margin: False
 
@@ -56,7 +57,7 @@ margin_update:
   increase_start_epoch: 20
   fix_start_epoch: 40
   update_margin: True
-  increase_type: 'exp' # exp, linear
+  increase_type: "exp" # exp, linear
 
 loss: CrossEntropyLoss
 loss_args: {}
diff --git a/examples/cnceleb/v2/conf/resnet.yaml b/examples/cnceleb/v2/conf/resnet.yaml
index ace7b79a..2d877301 100644
--- a/examples/cnceleb/v2/conf/resnet.yaml
+++ b/examples/cnceleb/v2/conf/resnet.yaml
@@ -3,6 +3,7 @@
 exp_dir: exp/ResNet34-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
 gpus: "[0,1]"
 num_avg: 10
+enable_amp: False # whether to enable automatic mixed precision training
 
 seed: 42
 num_epochs: 150
diff --git a/examples/cnceleb/v2/conf/resnet_lm.yaml b/examples/cnceleb/v2/conf/resnet_lm.yaml
index 6d75679c..48a251ff 100644
--- a/examples/cnceleb/v2/conf/resnet_lm.yaml
+++ b/examples/cnceleb/v2/conf/resnet_lm.yaml
@@ -8,6 +8,7 @@
 exp_dir: exp/ResNet34-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150-LM
 gpus: "[0,1]"
 num_avg: 1
+enable_amp: False # whether to enable automatic mixed precision training
 do_lm: True
 
 seed: 42
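
Note: the new enable_amp flag is added uniformly to every recipe config and
defaults to False, so existing recipes keep their exact fp32 behavior. A
minimal sketch of opting in from Python (illustrative only; assumes PyYAML
and a checkout of the repository as the working directory):

    import yaml

    # Load one of the recipe configs touched by this patch.
    with open("examples/cnceleb/v2/conf/resnet.yaml") as f:
        configs = yaml.safe_load(f)

    print(configs["enable_amp"])  # False -> plain fp32 training, the old behavior
    configs["enable_amp"] = True  # opt in to automatic mixed precision
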
diff --git a/examples/voxceleb/v2/conf/ecapa_tdnn.yaml b/examples/voxceleb/v2/conf/ecapa_tdnn.yaml
index 55e56b76..bbdff5e5 100644
--- a/examples/voxceleb/v2/conf/ecapa_tdnn.yaml
+++ b/examples/voxceleb/v2/conf/ecapa_tdnn.yaml
@@ -3,6 +3,7 @@
 exp_dir: exp/ECAPA_TDNN_GLOB_c512-ASTP-emb192-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
 gpus: "[0,1]"
 num_avg: 10
+enable_amp: False # whether to enable automatic mixed precision training
 
 seed: 42
 num_epochs: 150
@@ -42,9 +43,9 @@ model_init: null
 model_args:
   feat_dim: 80
   embed_dim: 192
-  pooling_func: 'ASTP'
+  pooling_func: "ASTP"
 projection_args:
-  project_type: 'arc_margin' # add_margin, arc_margin, sphere, softmax
+  project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
   scale: 32.0
   easy_margin: False
 
@@ -55,7 +56,7 @@ margin_update:
   increase_start_epoch: 20
   fix_start_epoch: 40
   update_margin: True
-  increase_type: 'exp' # exp, linear
+  increase_type: "exp" # exp, linear
 
 loss: CrossEntropyLoss
 loss_args: {}
diff --git a/examples/voxceleb/v2/conf/ecapa_tdnn_lm.yaml b/examples/voxceleb/v2/conf/ecapa_tdnn_lm.yaml
index 0700d603..00396224 100644
--- a/examples/voxceleb/v2/conf/ecapa_tdnn_lm.yaml
+++ b/examples/voxceleb/v2/conf/ecapa_tdnn_lm.yaml
@@ -8,6 +8,7 @@
 exp_dir: exp/ECAPA_TDNN_GLOB_c512-ASTP-emb192-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150-LM
 gpus: "[0,1]"
 num_avg: 1
+enable_amp: False # whether to enable automatic mixed precision training
 do_lm: True
 
 seed: 42
@@ -48,9 +49,9 @@ model_init: null
 model_args:
   feat_dim: 80
   embed_dim: 192
-  pooling_func: 'ASTP'
+  pooling_func: "ASTP"
 projection_args:
-  project_type: 'arc_margin' # add_margin, arc_margin, sphere, softmax
+  project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
   scale: 32.0
   easy_margin: False
 
@@ -61,7 +62,7 @@ margin_update:
   increase_start_epoch: 1
   fix_start_epoch: 1
   update_margin: True
-  increase_type: 'exp' # exp, linear
+  increase_type: "exp" # exp, linear
 
 loss: CrossEntropyLoss
 loss_args: {}
diff --git a/examples/voxceleb/v2/conf/repvgg.yaml b/examples/voxceleb/v2/conf/repvgg.yaml
index e4018dda..e89be21d 100644
--- a/examples/voxceleb/v2/conf/repvgg.yaml
+++ b/examples/voxceleb/v2/conf/repvgg.yaml
@@ -3,6 +3,7 @@
 exp_dir: exp/RepVGG-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
 gpus: "[0,1]"
 num_avg: 100
+enable_amp: False # whether to enable automatic mixed precision training
 
 seed: 42
 num_epochs: 150
@@ -36,16 +37,16 @@ dataset_args:
     max_f: 8
     prob: 0.6
 
-model: REPVGG_TINY_A0  # REPVGG_A0 REPVGG_A1 REPVGG_A2 REPVGG_RSBB_A0
+model: REPVGG_TINY_A0 # REPVGG_A0 REPVGG_A1 REPVGG_A2 REPVGG_RSBB_A0
 model_init: null
 model_args:
   feat_dim: 80
   embed_dim: 256 # 512
-  pooling_func: 'TSTP'
+  pooling_func: "TSTP"
   deploy: False
   use_se: False
 projection_args:
-  project_type: 'arc_margin' # add_margin, arc_margin, sphere, softmax
+  project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
   scale: 32.0
   easy_margin: False
 
@@ -56,7 +57,7 @@ margin_update:
   increase_start_epoch: 20
   fix_start_epoch: 40
   update_margin: True
-  increase_type: 'exp' # exp, linear
+  increase_type: "exp" # exp, linear
 
 loss: CrossEntropyLoss
 loss_args: {}
diff --git a/examples/voxceleb/v2/conf/resnet.yaml b/examples/voxceleb/v2/conf/resnet.yaml
index 83bfcac3..2d877301 100644
--- a/examples/voxceleb/v2/conf/resnet.yaml
+++ b/examples/voxceleb/v2/conf/resnet.yaml
@@ -3,6 +3,7 @@
 exp_dir: exp/ResNet34-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
 gpus: "[0,1]"
 num_avg: 10
+enable_amp: False # whether to enable automatic mixed precision training
 
 seed: 42
 num_epochs: 150
@@ -42,10 +43,10 @@ model_init: null
 model_args:
   feat_dim: 80
   embed_dim: 256
-  pooling_func: 'TSTP'
+  pooling_func: "TSTP"
   two_emb_layer: False
 projection_args:
-  project_type: 'arc_margin' # add_margin, arc_margin, sphere, softmax
+  project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
   scale: 32.0
   easy_margin: False
 
@@ -56,7 +57,7 @@ margin_update:
   increase_start_epoch: 20
   fix_start_epoch: 40
   update_margin: True
-  increase_type: 'exp' # exp, linear
+  increase_type: "exp" # exp, linear
 
 loss: CrossEntropyLoss
 loss_args: {}
diff --git a/examples/voxceleb/v2/conf/resnet_lm.yaml b/examples/voxceleb/v2/conf/resnet_lm.yaml
index 4d4a220b..48a251ff 100644
--- a/examples/voxceleb/v2/conf/resnet_lm.yaml
+++ b/examples/voxceleb/v2/conf/resnet_lm.yaml
@@ -8,6 +8,7 @@
 exp_dir: exp/ResNet34-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150-LM
 gpus: "[0,1]"
 num_avg: 1
+enable_amp: False # whether to enable automatic mixed precision training
 do_lm: True
 
 seed: 42
@@ -48,10 +49,10 @@ model_init: null
 model_args:
   feat_dim: 80
   embed_dim: 256
-  pooling_func: 'TSTP'
+  pooling_func: "TSTP"
   two_emb_layer: False
 projection_args:
-  project_type: 'arc_margin' # add_margin, arc_margin, sphere, softmax
+  project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
   scale: 32.0
   easy_margin: False
 
@@ -62,7 +63,7 @@ margin_update:
   increase_start_epoch: 1
   fix_start_epoch: 1
   update_margin: True
-  increase_type: 'exp' # exp, linear
+  increase_type: "exp" # exp, linear
 
 loss: CrossEntropyLoss
 loss_args: {}
diff --git a/examples/voxceleb/v2/conf/xvec.yaml b/examples/voxceleb/v2/conf/xvec.yaml
index d9bf1a1c..0eb949b4 100644
--- a/examples/voxceleb/v2/conf/xvec.yaml
+++ b/examples/voxceleb/v2/conf/xvec.yaml
@@ -3,6 +3,7 @@
 exp_dir: exp/XVEC-TSTP-emb512-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
 gpus: "[0,1]"
 num_avg: 10
+enable_amp: False # whether to enable automatic mixed precision training
 
 seed: 42
 num_epochs: 150
@@ -42,9 +43,9 @@ model_init: null
 model_args:
   feat_dim: 80
   embed_dim: 512
-  pooling_func: 'TSTP'
+  pooling_func: "TSTP"
 projection_args:
-  project_type: 'arc_margin' # add_margin, arc_margin, sphere, softmax
+  project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
   scale: 32.0
   easy_margin: False
 
@@ -55,7 +56,7 @@ margin_update:
   increase_start_epoch: 20
   fix_start_epoch: 40
   update_margin: True
-  increase_type: 'exp' # exp, linear
+  increase_type: "exp" # exp, linear
 
 loss: CrossEntropyLoss
 loss_args: {}
diff --git a/examples/voxceleb/v2/conf/xvec_lm.yaml b/examples/voxceleb/v2/conf/xvec_lm.yaml
index 7d118519..0dde5c37 100644
--- a/examples/voxceleb/v2/conf/xvec_lm.yaml
+++ b/examples/voxceleb/v2/conf/xvec_lm.yaml
@@ -8,6 +8,7 @@
 exp_dir: exp/XVEC-TSTP-emb512-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150-LM
 gpus: "[0,1]"
 num_avg: 1
+enable_amp: False # whether to enable automatic mixed precision training
 do_lm: True
 
 seed: 42
@@ -48,9 +49,9 @@ model_init: null
 model_args:
   feat_dim: 80
   embed_dim: 512
-  pooling_func: 'TSTP'
+  pooling_func: "TSTP"
 projection_args:
-  project_type: 'arc_margin' # add_margin, arc_margin, sphere, softmax
+  project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
   scale: 32.0
   easy_margin: False
 
@@ -61,7 +62,7 @@ margin_update:
   increase_start_epoch: 1
   fix_start_epoch: 1
   update_margin: True
-  increase_type: 'exp' # exp, linear
+  increase_type: "exp" # exp, linear
 
 loss: CrossEntropyLoss
 loss_args: {}
project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax scale: 32.0 easy_margin: False @@ -61,7 +62,7 @@ margin_update: increase_start_epoch: 1 fix_start_epoch: 1 update_margin: True - increase_type: 'exp' # exp, linear + increase_type: "exp" # exp, linear loss: CrossEntropyLoss loss_args: {} diff --git a/wespeaker/bin/train.py b/wespeaker/bin/train.py index cd2cf2ad..cb1dde29 100644 --- a/wespeaker/bin/train.py +++ b/wespeaker/bin/train.py @@ -195,6 +195,7 @@ def train(config='conf/config.yaml', **kwargs): logger.info(line) dist.barrier() # synchronize here + scaler = torch.cuda.amp.GradScaler() for epoch in range(start_epoch, configs['num_epochs'] + 1): # train_sampler.set_epoch(epoch) train_dataset.set_epoch(epoch) @@ -208,6 +209,8 @@ def train(config='conf/config.yaml', **kwargs): margin_scheduler, epoch, logger, + scaler, + enable_amp=configs['enable_amp'], log_batch_interval=configs['log_batch_interval'], device=device) diff --git a/wespeaker/models/pooling_layers.py b/wespeaker/models/pooling_layers.py index b8defa1f..91215107 100644 --- a/wespeaker/models/pooling_layers.py +++ b/wespeaker/models/pooling_layers.py @@ -47,7 +47,7 @@ def __init__(self, **kwargs): def forward(self, x): # The last dimension is the temporal axis - pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8) + pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7) pooling_std = pooling_std.flatten(start_dim=1) return pooling_std @@ -64,7 +64,7 @@ def __init__(self, **kwargs): def forward(self, x): # The last dimension is the temporal axis pooling_mean = x.mean(dim=-1) - pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8) + pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7) pooling_mean = pooling_mean.flatten(start_dim=1) pooling_std = pooling_std.flatten(start_dim=1) diff --git a/wespeaker/models/projections.py b/wespeaker/models/projections.py index a3e31a0e..8d8efaec 100644 --- a/wespeaker/models/projections.py +++ b/wespeaker/models/projections.py @@ -134,6 +134,9 @@ def __init__(self, in_features, out_features, scale=32.0, margin=0.20): in_features)) nn.init.xavier_uniform_(self.weight) + def update(self, margin): + self.margin = margin + def forward(self, input, label): # ---------------- cos(theta) & phi(theta) --------------- cosine = F.linear(F.normalize(input), F.normalize(self.weight)) diff --git a/wespeaker/utils/executor.py b/wespeaker/utils/executor.py index cf0678dc..fd16d067 100644 --- a/wespeaker/utils/executor.py +++ b/wespeaker/utils/executor.py @@ -29,6 +29,8 @@ def run_epoch(dataloader, margin_scheduler, epoch, logger, + scaler, + enable_amp, log_batch_interval=100, device=torch.device('cuda')): model.train() @@ -54,11 +56,14 @@ def run_epoch(dataloader, features = features.float().to(device) # (B,T,F) targets = targets.long().to(device) - outputs = model(features) # (embed_a,embed_b) in most cases - embeds = outputs[-1] if isinstance(outputs, tuple) else outputs - outputs = model.module.projection(embeds, targets) - loss = criterion(outputs, targets) + with torch.cuda.amp.autocast(enabled=enable_amp): + outputs = model(features) # (embed_a,embed_b) in most cases + embeds = outputs[-1] if isinstance(outputs, tuple) else outputs + outputs = model.module.projection(embeds, targets) + + loss = criterion(outputs, targets) + # loss, acc loss_meter.add(loss.item()) acc_meter.add(outputs.cpu().detach().numpy(), @@ -66,8 +71,9 @@ def run_epoch(dataloader, # updata the model optimizer.zero_grad() - loss.backward() - optimizer.step() + scaler.scale(loss).backward() + 
diff --git a/wespeaker/models/projections.py b/wespeaker/models/projections.py
index a3e31a0e..8d8efaec 100644
--- a/wespeaker/models/projections.py
+++ b/wespeaker/models/projections.py
@@ -134,6 +134,9 @@ def __init__(self, in_features, out_features, scale=32.0, margin=0.20):
                                                      in_features))
         nn.init.xavier_uniform_(self.weight)
 
+    def update(self, margin):
+        self.margin = margin
+
     def forward(self, input, label):
         # ---------------- cos(theta) & phi(theta) ---------------
         cosine = F.linear(F.normalize(input), F.normalize(self.weight))
diff --git a/wespeaker/utils/executor.py b/wespeaker/utils/executor.py
index cf0678dc..fd16d067 100644
--- a/wespeaker/utils/executor.py
+++ b/wespeaker/utils/executor.py
@@ -29,6 +29,8 @@ def run_epoch(dataloader,
               margin_scheduler,
               epoch,
               logger,
+              scaler,
+              enable_amp,
               log_batch_interval=100,
               device=torch.device('cuda')):
     model.train()
@@ -54,11 +56,14 @@ def run_epoch(dataloader,
 
         features = features.float().to(device)  # (B,T,F)
         targets = targets.long().to(device)
 
-        outputs = model(features)  # (embed_a,embed_b) in most cases
-        embeds = outputs[-1] if isinstance(outputs, tuple) else outputs
-        outputs = model.module.projection(embeds, targets)
-        loss = criterion(outputs, targets)
+        with torch.cuda.amp.autocast(enabled=enable_amp):
+            outputs = model(features)  # (embed_a,embed_b) in most cases
+            embeds = outputs[-1] if isinstance(outputs, tuple) else outputs
+            outputs = model.module.projection(embeds, targets)
+
+        loss = criterion(outputs, targets)
+
         # loss, acc
         loss_meter.add(loss.item())
         acc_meter.add(outputs.cpu().detach().numpy(),
@@ -66,8 +71,9 @@
                       targets.long().cpu().numpy())
         # updata the model
         optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
 
         # log
         if (i + 1) % log_batch_interval == 0:
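
For reference, the executor change follows the standard torch.cuda.amp
recipe: run the forward pass under autocast so eligible ops execute in
float16, keep the loss computation in full precision outside the autocast
block (as the hunk above does), and scale the loss before backward so small
fp16 gradients do not flush to zero; scaler.step() then unscales the
gradients and skips the optimizer step if they contain inf/nan, and
scaler.update() adjusts the scale factor. A minimal self-contained sketch of
the same pattern (illustrative stand-in model and data, not wespeaker code;
requires a CUDA device):

    import torch
    import torch.nn as nn

    device = torch.device('cuda')
    model = nn.Linear(80, 10).to(device)      # stand-in for the speaker model
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = torch.cuda.amp.GradScaler()      # created once, as in train.py
    enable_amp = True

    features = torch.randn(4, 80, device=device)         # fake feature batch
    targets = torch.randint(0, 10, (4,), device=device)  # fake labels

    with torch.cuda.amp.autocast(enabled=enable_amp):
        outputs = model(features)             # forward runs in mixed precision
    loss = criterion(outputs, targets)        # loss kept outside autocast

    optimizer.zero_grad()
    scaler.scale(loss).backward()             # scale up to avoid fp16 underflow
    scaler.step(optimizer)                    # unscales; skips step on inf/nan
    scaler.update()                           # grow/shrink the scale factor

GradScaler also accepts an enabled= argument; the patch leaves the scaler
always on, which remains correct when enable_amp is False (the scale factor
is a power of two, so scaling fp32 gradients is effectively exact) at the
cost of a little redundant work.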