[model] MQMHA+arc_margin_intertopk_subcenter (#115)
Hunterhuan authored Nov 30, 2022
1 parent 7cf9def commit b7c2e32
Showing 10 changed files with 364 additions and 30 deletions.

examples/voxceleb/v2/conf/resnet.yaml (4 changes: 2 additions & 2 deletions)
@@ -43,10 +43,10 @@ model_init: null
 model_args:
   feat_dim: 80
   embed_dim: 256
-  pooling_func: "TSTP"
+  pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
   two_emb_layer: False
 projection_args:
-  project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
+  project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax, arc_margin_intertopk_subcenter
   scale: 32.0
   easy_margin: False

examples/voxceleb/v2/conf/resnet_lm.yaml (4 changes: 2 additions & 2 deletions)
@@ -49,10 +49,10 @@ model_init: null
 model_args:
   feat_dim: 80
   embed_dim: 256
-  pooling_func: "TSTP"
+  pooling_func: "TSTP" # TSTP, ASTP, MQMHASTP
   two_emb_layer: False
 projection_args:
-  project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax
+  project_type: "arc_margin" # add_margin, arc_margin, sphere, softmax, arc_margin_intertopk_subcenter
   scale: 32.0
   easy_margin: False
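
For context, both options are consumed by name at model-construction time: pooling_func is looked up in wespeaker/models/pooling_layers.py via getattr (visible in the ecapa_tdnn.py hunk below), and project_type is dispatched inside get_projection. A minimal sketch of that flow; the config path and the in_dim value here are chosen for illustration:

import yaml

import wespeaker.models.pooling_layers as pooling_layers

with open('examples/voxceleb/v2/conf/resnet.yaml') as f:
    configs = yaml.safe_load(f)

# "MQMHASTP" in the config would select the pooling class added by this commit.
pool_cls = getattr(pooling_layers, configs['model_args']['pooling_func'])
pool = pool_cls(in_dim=1536)  # in_dim must match the encoder's output width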

wespeaker/bin/train.py (1 change: 1 addition & 0 deletions)
@@ -118,6 +118,7 @@ def train(config='conf/config.yaml', **kwargs):
     if configs.get('do_lm', False):
         logger.info('No speed perturb while doing large margin fine-tuning')
         configs['dataset_args']['speed_perturb'] = False
+    configs['projection_args']['do_lm'] = configs.get('do_lm', False)
     projection = get_projection(configs['projection_args'])
     model.add_module("projection", projection)
     if rank == 0:
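
The added line forwards the top-level do_lm flag into projection_args, so the projection factory knows whether large-margin fine-tuning is active. A hedged sketch of the consuming side; the real dispatch lives in wespeaker's get_projection, and this stand-in only shows the plumbing:

def get_projection_sketch(conf):
    # Illustrative only: mimics how a factory could branch on the new flag.
    if conf['project_type'] == 'arc_margin_intertopk_subcenter':
        return ('fine-tune margin schedule' if conf.get('do_lm', False)
                else 'training margin schedule')
    return 'some other projection'

conf = {'project_type': 'arc_margin_intertopk_subcenter', 'do_lm': True}
print(get_projection_sketch(conf))  # fine-tune margin schedule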

wespeaker/bin/train_deprecated.py (1 change: 1 addition & 0 deletions)
@@ -123,6 +123,7 @@ def train(config='conf/config.yaml', **kwargs):
     if configs['feature_args']['raw_wav'] and configs['dataset_args']['speed_perturb']:
         # different speeds are treated as different speakers
         configs['projection_args']['num_class'] *= 3
+    configs['projection_args']['do_lm'] = configs.get('do_lm', False)
     projection = get_projection(configs['projection_args'])
     model.add_module("projection", projection)
     if rank == 0:

wespeaker/models/ecapa_tdnn.py (8 changes: 4 additions & 4 deletions)
@@ -191,11 +191,11 @@ def __init__(self,
         cat_channels = channels * 3
         out_channels = 512 * 3
         self.conv = nn.Conv1d(cat_channels, out_channels, kernel_size=1)
-        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
         self.pool = getattr(pooling_layers, pooling_func)(
             in_dim=out_channels, global_context_att=global_context_att)
-        self.bn = nn.BatchNorm1d(out_channels * self.n_stats)
-        self.linear = nn.Linear(out_channels * self.n_stats, embed_dim)
+        self.pool_out_dim = self.pool.get_out_dim()
+        self.bn = nn.BatchNorm1d(self.pool_out_dim)
+        self.linear = nn.Linear(self.pool_out_dim, embed_dim)

     def forward(self, x):
         x = x.permute(0, 2, 1)  # (B,T,F) -> (B,F,T)
@@ -247,7 +247,7 @@ def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func='ASTP'):
     x = torch.zeros(10, 200, 80)
     model = ECAPA_TDNN_GLOB_c512(feat_dim=80,
                                  embed_dim=192,
-                                 pooling_func='ASTP')
+                                 pooling_func='MQMHASTP')
     model.eval()
     out = model(x)
     print(out.shape)
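
With get_out_dim(), the encoder no longer hard-codes the 1x/2x statistics multiplier that the removed n_stats line assumed; each pooling layer reports its own output width, which MQMHASTP needs because it concatenates statistics from query_num sub-poolings. A small check, assuming wespeaker is importable:

import wespeaker.models.pooling_layers as pooling_layers

in_dim = 1536  # ECAPA_TDNN_GLOB_c512: out_channels = 512 * 3
for name in ['TAP', 'TSDP', 'TSTP', 'ASTP', 'MHASTP', 'MQMHASTP']:
    pool = getattr(pooling_layers, name)(in_dim=in_dim)
    print(name, pool.get_out_dim())
# TAP 1536, TSDP 1536, TSTP 3072, ASTP 3072, MHASTP 3072,
# MQMHASTP 6144 (in_dim * 2 * query_num, with the default query_num=2)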

wespeaker/models/pooling_layers.py (185 changes: 179 additions & 6 deletions)
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """
 Pooling functions to aggregate frame-level deep features
 into segment-level speaker embeddings
@@ -22,62 +21,82 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F


 class TAP(nn.Module):
     """
     Temporal average pooling, only first-order mean is considered
     """
-    def __init__(self, **kwargs):
+
+    def __init__(self, in_dim=0, **kwargs):
         super(TAP, self).__init__()
+        self.in_dim = in_dim

     def forward(self, x):
         pooling_mean = x.mean(dim=-1)
         # To be compatible with 2D input
         pooling_mean = pooling_mean.flatten(start_dim=1)
         return pooling_mean
+
+    def get_out_dim(self):
+        self.out_dim = self.in_dim
+        return self.out_dim


 class TSDP(nn.Module):
     """
     Temporal standard deviation pooling, only second-order std is considered
     """
-    def __init__(self, **kwargs):
+
+    def __init__(self, in_dim=0, **kwargs):
         super(TSDP, self).__init__()
+        self.in_dim = in_dim

     def forward(self, x):
         # The last dimension is the temporal axis
         pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
         pooling_std = pooling_std.flatten(start_dim=1)
         return pooling_std
+
+    def get_out_dim(self):
+        self.out_dim = self.in_dim
+        return self.out_dim


 class TSTP(nn.Module):
     """
     Temporal statistics pooling, concatenate mean and std, which is used in
     x-vector
     Comment: simple concatenation cannot make full use of both statistics
     """
-    def __init__(self, **kwargs):
+
+    def __init__(self, in_dim=0, **kwargs):
         super(TSTP, self).__init__()
+        self.in_dim = in_dim

     def forward(self, x):
         # The last dimension is the temporal axis
         pooling_mean = x.mean(dim=-1)
         pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
         pooling_mean = pooling_mean.flatten(start_dim=1)
         pooling_std = pooling_std.flatten(start_dim=1)

         stats = torch.cat((pooling_mean, pooling_std), 1)
         return stats
+
+    def get_out_dim(self):
+        self.out_dim = self.in_dim * 2
+        return self.out_dim

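A quick shape check for the class above (a sketch; torch is already imported in this module):

x = torch.randn(4, 256, 200)  # (batch, feature, time)
tstp = TSTP(in_dim=256)
print(tstp(x).shape)       # torch.Size([4, 512]): mean and std concatenated
print(tstp.get_out_dim())  # 512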

 class ASTP(nn.Module):
     """ Attentive statistics pooling: Channel- and context-dependent
     statistics pooling, first used in ECAPA_TDNN.
     """
-    def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
+
+    def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False,
+                 **kwargs):
         super(ASTP, self).__init__()
         self.in_dim = in_dim
         self.global_context_att = global_context_att

         # Use Conv1d with stride == 1 rather than Linear, then we don't
@@ -119,3 +138,157 @@ def forward(self, x):
         var = torch.sum(alpha * (x**2), dim=2) - mean**2
         std = torch.sqrt(var.clamp(min=1e-10))
         return torch.cat([mean, std], dim=1)
+
+    def get_out_dim(self):
+        self.out_dim = 2 * self.in_dim
+        return self.out_dim

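ASTP keeps the same 2 * in_dim output contract; global_context_att only changes how the attention scores are computed (its internals sit in the collapsed part of this hunk). A usage sketch:

x = torch.randn(4, 256, 200)
astp = ASTP(in_dim=256, global_context_att=True)
print(astp(x).shape)       # expected: torch.Size([4, 512])
print(astp.get_out_dim())  # 512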

+class MHASTP(torch.nn.Module):
+    """ Multi-head attentive statistics pooling
+    Reference:
+        Self Multi-Head Attention for Speaker Recognition
+        https://arxiv.org/pdf/1906.09890.pdf
+    """
+
+    def __init__(self,
+                 in_dim,
+                 layer_num=2,
+                 head_num=2,
+                 d_s=1,
+                 bottleneck_dim=64,
+                 **kwargs):
+        super(MHASTP, self).__init__()
+        assert (in_dim %
+                head_num) == 0  # in_dim must be divisible by head_num
+        self.in_dim = in_dim
+        self.head_num = head_num
+        d_model = int(in_dim / head_num)
+        channel_dims = [bottleneck_dim for i in range(layer_num + 1)]
+        if d_s > 1:
+            d_s = d_model
+        else:
+            d_s = 1
+        self.d_s = d_s
+        channel_dims[0], channel_dims[-1] = d_model, d_s
+        heads_att_trans = []
+        for i in range(self.head_num):
+            att_trans = nn.Sequential()
+            for i in range(layer_num - 1):
+                att_trans.add_module(
+                    'att_' + str(i),
+                    nn.Conv1d(channel_dims[i], channel_dims[i + 1], 1, 1))
+                att_trans.add_module('tanh' + str(i), nn.Tanh())
+            att_trans.add_module(
+                'att_' + str(layer_num - 1),
+                nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num],
+                          1, 1))
+            heads_att_trans.append(att_trans)
+        self.heads_att_trans = nn.ModuleList(heads_att_trans)
+
+    def forward(self, input):
+        """
+        input: a 3-dimensional tensor in xvector architecture
+            or a 4-dimensional tensor in resnet architecture
+        0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
+        """
+        if len(input.shape) == 4:  # B x C x F x T -> B x (C*F) x T
+            input = input.reshape(input.shape[0],
+                                  input.shape[1] * input.shape[2],
+                                  input.shape[3])
+        assert len(input.shape) == 3
+        bs, f_dim, t_dim = input.shape
+        chunks = torch.chunk(input, self.head_num, 1)  # split features across heads
+        chunks_out = []
+        for i, layer in enumerate(self.heads_att_trans):
+            att_score = layer(chunks[i])
+            alpha = F.softmax(att_score, dim=-1)
+            mean = torch.sum(alpha * chunks[i], dim=2)
+            var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2
+            std = torch.sqrt(var.clamp(min=1e-10))
+            chunks_out.append(torch.cat((mean, std), dim=1))
+        out = torch.cat(chunks_out, dim=1)
+        return out
+
+    def get_out_dim(self):
+        self.out_dim = 2 * self.in_dim
+        return self.out_dim

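Each head attends over an in_dim / head_num slice of the features and emits its own mean and std, so the concatenated output is 2 * in_dim regardless of head_num. A usage sketch:

x = torch.randn(4, 256, 200)             # (batch, feature, time)
mhastp = MHASTP(in_dim=256, head_num=4)  # four heads, each 64-dim
print(mhastp(x).shape)                   # torch.Size([4, 512])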

+class MQMHASTP(torch.nn.Module):
+    """ Multi-query multi-head attentive statistics pooling
+    Reference:
+        https://arxiv.org/pdf/2110.05042.pdf
+    Args:
+        in_dim: the feature dimension of input
+        layer_num: the number of layers in the pooling layer
+        query_num: the number of queries
+        head_num: the number of heads
+        d_s: vector attention of width d_model when > 1, else scalar attention
+        bottleneck_dim: the bottleneck dimension
+
+    SA (H = 1, Q = 1, n = 2, d_s = 1) ref:
+        https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf
+    MHA (H > 1, Q = 1, n = 1, d_s = 1) ref:
+        https://arxiv.org/pdf/1906.09890.pdf
+    AS (H = 1, Q > 1, n = 2, d_s = 1) ref:
+        https://arxiv.org/pdf/1803.10963.pdf
+    VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref:
+        http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf
+    """
+
+    def __init__(self,
+                 in_dim,
+                 layer_num=2,
+                 query_num=2,
+                 head_num=8,
+                 d_s=2,
+                 bottleneck_dim=64,
+                 **kwargs):
+        super(MQMHASTP, self).__init__()
+        self.n_query = nn.ModuleList([
+            MHASTP(in_dim,
+                   layer_num=layer_num,
+                   head_num=head_num,
+                   d_s=d_s,
+                   bottleneck_dim=bottleneck_dim) for i in range(query_num)
+        ])
+        self.query_num = query_num
+        self.in_dim = in_dim
+
+    def forward(self, input):
+        """
+        input: a 3-dimensional tensor in xvector architecture
+            or a 4-dimensional tensor in resnet architecture
+        0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
+        """
+        if len(input.shape) == 4:  # B x C x F x T -> B x (C*F) x T
+            input = input.reshape(input.shape[0],
+                                  input.shape[1] * input.shape[2],
+                                  input.shape[3])
+        assert len(input.shape) == 3
+        res = []
+        for i, layer in enumerate(self.n_query):
+            res.append(layer(input))
+        out = torch.cat(res, dim=-1)
+        return out
+
+    def get_out_dim(self):
+        self.out_dim = self.in_dim * 2 * self.query_num
+        return self.out_dim

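Following the docstring's taxonomy, the named variants are parameterizations of one constructor (H = head_num, Q = query_num, n = layer_num, d_s > 1 selects vector rather than scalar attention). A sketch instantiating three of them:

x = torch.randn(4, 256, 200)
sa = MQMHASTP(256, head_num=1, query_num=1, layer_num=2, d_s=1)   # SA
mha = MQMHASTP(256, head_num=2, query_num=1, layer_num=1, d_s=1)  # MHA-style
vsa = MQMHASTP(256, head_num=1, query_num=2, layer_num=2, d_s=2)  # VSA-style
for m in (sa, mha, vsa):
    print(m(x).shape, m.get_out_dim())
# -> [4, 512] 512, [4, 512] 512, [4, 1024] 1024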

+if __name__ == '__main__':
+    data = torch.randn(16, 512, 10, 35)  # B x C x F x T, e.g. a ResNet feature map
+    # Only the last assignment is used below; 'context' is an unknown kwarg
+    # silently absorbed by **kwargs.
+    model = MQMHASTP(512 * 10)
+    model = MHASTP(512 * 10)
+    model = MQMHASTP(512 * 10, context=False)
+    print(model)
+
+    out = model(data)
+    print(out.shape)  # torch.Size([16, 20480])
+    print(model.get_out_dim())  # 20480 = in_dim * 2 * query_num