# -*- coding: utf-8 -*-
#
# This file is part of SIDEKIT.
#
# SIDEKIT is a python package for speaker verification.
# Home page: http://www-lium.univ-lemans.fr/sidekit/
#
# SIDEKIT is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# SIDEKIT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with SIDEKIT. If not, see <http://www.gnu.org/licenses/>.
"""
Copyright 2014-2020 Yevhenii Prokopalo, Anthony Larcher
"""
import logging
import numpy
import pandas
import pickle
import shutil
import torch
import torch.optim as optim
import torch.multiprocessing as mp
import yaml
from torchvision import transforms
from collections import OrderedDict
from sidekit.nnet.xsets import XvectorMultiDataset, StatDataset, VoxDataset, SideSet
from sidekit.nnet.xsets import IdMapSet
from sidekit.nnet.xsets import FrequencyMask, CMVN, TemporalMask, MFCC
from sidekit.bosaris import IdMap
from sidekit.statserver import StatServer
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sidekit.nnet.sincnet import SincNet
from tqdm import tqdm
__license__ = "LGPL"
__author__ = "Anthony Larcher"
__copyright__ = "Copyright 2015-2020 Anthony Larcher"
__maintainer__ = "Anthony Larcher"
__email__ = "anthony.larcher@univ-lemans.fr"
__status__ = "Production"
__docformat__ = 'reStructuredText'
def get_lr(optimizer):
    """
    Return the current learning rate of an optimizer.

    :param optimizer: a torch.optim optimizer
    :return: the learning rate of the first parameter group
    """
    for param_group in optimizer.param_groups:
        return param_group['lr']
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    """
    Save a training checkpoint and, if it is the best model so far, copy it to a separate file.

    :param state: dictionary holding the state to serialize (model, optimizer, epoch, ...)
    :param is_best: boolean, True if the current model achieves the best accuracy so far
    :param filename: name of the checkpoint file
    :param best_filename: name of the file the best model is copied to
    :return: None
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)
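
# A minimal sketch (hypothetical file name) of reloading a checkpoint produced by
# save_checkpoint, using the keys that xtrain stores below:
#
#   checkpoint = torch.load('model_best.pth.tar')
#   model = Xtractor(speaker_number=1000)
#   model.load_state_dict(checkpoint['model_state_dict'])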
class Xtractor(torch.nn.Module):
    """
    Class that defines an x-vector extractor based on 5 convolutional layers
    followed by a mean and standard deviation pooling.
    """

    def __init__(self, speaker_number, model_archi=None):
        """
        If model_archi is None, the default architecture is created.

        :param speaker_number: number of speakers in the training set (size of the output layer)
        :param model_archi: path to a YAML file describing the architecture
        """
        super(Xtractor, self).__init__()
        self.speaker_number = speaker_number
        self.feature_size = None
        if model_archi is None:
            self.feature_size = 30
            self.activation = torch.nn.LeakyReLU(0.2)
            self.preprocessor = None
self.sequence_network = torch.nn.Sequential(OrderedDict([
("conv1", torch.nn.Conv1d(self.feature_size, 512, 5, dilation=1)),
("activation1", torch.nn.LeakyReLU(0.2)),
("norm1", torch.nn.BatchNorm1d(512)),
("conv2", torch.nn.Conv1d(512, 512, 3, dilation=2)),
("activation2", torch.nn.LeakyReLU(0.2)),
("norm2", torch.nn.BatchNorm1d(512)),
("conv3", torch.nn.Conv1d(512, 512, 3, dilation=3)),
("activation3", torch.nn.LeakyReLU(0.2)),
("norm3", torch.nn.BatchNorm1d(512)),
("conv4", torch.nn.Conv1d(512, 512, 1)),
("activation4", torch.nn.LeakyReLU(0.2)),
("norm4", torch.nn.BatchNorm1d(512)),
("conv5", torch.nn.Conv1d(512, 1536, 1)),
("activation5", torch.nn.LeakyReLU(0.2)),
("norm5", torch.nn.BatchNorm1d(1536))
]))
self.before_speaker_embedding = torch.nn.Sequential(OrderedDict([
("linear6", torch.nn.Linear(3072, 512))
]))
self.after_speaker_embedding = torch.nn.Sequential(OrderedDict([
("activation6", torch.nn.LeakyReLU(0.2)),
("norm6", torch.nn.BatchNorm1d(512)),
("dropout6", torch.nn.Dropout(p=0.05)),
("linear7", torch.nn.Linear(512, 512)),
("activation7", torch.nn.LeakyReLU(0.2)),
("norm7", torch.nn.BatchNorm1d(512)),
("linear8", torch.nn.Linear(512, int(self.speaker_number)))
]))
self.sequence_network_weight_decay = 0.0002
self.before_speaker_embedding_weight_decay = 0.002
self.after_speaker_embedding_weight_decay = 0.002
else:
# Load Yaml configuration
with open(model_archi, 'r') as fh:
cfg = yaml.load(fh, Loader=yaml.FullLoader)
"""
Prepare Preprocessor
"""
self.preprocessor = None
if "preprocessor" in cfg:
if cfg['preprocessor']["type"] == "sincnet":
self.preprocessor = SincNet(
waveform_normalize=cfg['preprocessor']["waveform_normalize"],
sample_rate=cfg['preprocessor']["sample_rate"],
min_low_hz=cfg['preprocessor']["min_low_hz"],
min_band_hz=cfg['preprocessor']["min_band_hz"],
out_channels=cfg['preprocessor']["out_channels"],
kernel_size=cfg['preprocessor']["kernel_size"],
stride=cfg['preprocessor']["stride"],
max_pool=cfg['preprocessor']["max_pool"],
instance_normalize=cfg['preprocessor']["instance_normalize"],
activation=cfg['preprocessor']["activation"],
dropout=cfg['preprocessor']["dropout"]
)
self.feature_size = self.preprocessor.dimension
"""
Prepare sequence network
"""
# Get Feature size
if self.feature_size is None:
self.feature_size = cfg["feature_size"]
input_size = self.feature_size
# Get activation function
if cfg["activation"] == 'LeakyReLU':
self.activation = torch.nn.LeakyReLU(0.2)
elif cfg["activation"] == 'PReLU':
self.activation = torch.nn.PReLU()
elif cfg["activation"] == 'ReLU6':
self.activation = torch.nn.ReLU6()
else:
self.activation = torch.nn.ReLU()
# Create sequential object for the first part of the network
segmental_layers = []
for k in cfg["segmental"].keys():
if k.startswith("conv"):
segmental_layers.append((k, torch.nn.Conv1d(input_size,
cfg["segmental"][k]["output_channels"],
kernel_size=cfg["segmental"][k]["kernel_size"],
dilation=cfg["segmental"][k]["dilation"])))
input_size = cfg["segmental"][k]["output_channels"]
elif k.startswith("activation"):
segmental_layers.append((k, self.activation))
elif k.startswith('norm'):
segmental_layers.append((k, torch.nn.BatchNorm1d(input_size)))
self.sequence_network = torch.nn.Sequential(OrderedDict(segmental_layers))
self.sequence_network_weight_decay = cfg["segmental"]["weight_decay"]
"""
            Prepare the last part of the network (after pooling)
"""
# Create sequential object for the second part of the network
input_size = input_size * 2
before_embedding_layers = []
for k in cfg["before_embedding"].keys():
if k.startswith("lin"):
if cfg["before_embedding"][k]["output"] == "speaker_number":
before_embedding_layers.append((k, torch.nn.Linear(input_size, self.speaker_number)))
else:
before_embedding_layers.append((k, torch.nn.Linear(input_size,
cfg["before_embedding"][k]["output"])))
input_size = cfg["before_embedding"][k]["output"]
elif k.startswith("activation"):
before_embedding_layers.append((k, self.activation))
elif k.startswith('norm'):
before_embedding_layers.append((k, torch.nn.BatchNorm1d(input_size)))
elif k.startswith('dropout'):
before_embedding_layers.append((k, torch.nn.Dropout(p=cfg["before_embedding"][k])))
self.before_speaker_embedding = torch.nn.Sequential(OrderedDict(before_embedding_layers))
self.before_speaker_embedding_weight_decay = cfg["before_embedding"]["weight_decay"]
# Create sequential object for the second part of the network
after_embedding_layers = []
for k in cfg["after_embedding"].keys():
if k.startswith("lin"):
if cfg["after_embedding"][k]["output"] == "speaker_number":
after_embedding_layers.append((k, torch.nn.Linear(input_size, self.speaker_number)))
else:
after_embedding_layers.append((k, torch.nn.Linear(input_size,
cfg["after_embedding"][k]["output"])))
input_size = cfg["after_embedding"][k]["output"]
elif k.startswith("activation"):
after_embedding_layers.append((k, self.activation))
elif k.startswith('norm'):
after_embedding_layers.append((k, torch.nn.BatchNorm1d(input_size)))
elif k.startswith('dropout'):
after_embedding_layers.append((k, torch.nn.Dropout(p=cfg["after_embedding"][k])))
self.after_speaker_embedding = torch.nn.Sequential(OrderedDict(after_embedding_layers))
self.after_speaker_embedding_weight_decay = cfg["after_embedding"]["weight_decay"]
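
    # A minimal sketch of the YAML layout parsed above; layer names (conv1, norm1, ...)
    # and all values are illustrative, only the key structure is taken from the code:
    #
    #   feature_size: 30             # ignored when a preprocessor provides the features
    #   activation: LeakyReLU        # or PReLU, ReLU6; anything else falls back to ReLU
    #   segmental:
    #       weight_decay: 0.0002
    #       conv1: {output_channels: 512, kernel_size: 5, dilation: 1}
    #       activation1: ~
    #       norm1: ~
    #   before_embedding:
    #       weight_decay: 0.002
    #       linear6: {output: 512}
    #   after_embedding:
    #       weight_decay: 0.002
    #       dropout6: 0.05
    #       linear7: {output: speaker_number}
    #
    # An optional "preprocessor" section (type: sincnet) replaces feature_size with the
    # SincNet output dimension; it must define every key passed to SincNet above.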
    def forward(self, x, is_eval=False):
        """
        Forward pass through the network.

        :param x: input batch, either raw waveforms (if a preprocessor is used) or acoustic features
        :param is_eval: if True, return the speaker embedding instead of the classification output
        :return: speaker embeddings (is_eval=True) or the output of the classification layer
        """
if self.preprocessor is not None:
x = self.preprocessor(x)
x = self.sequence_network(x)
# Mean and Standard deviation pooling
mean = torch.mean(x, dim=2)
std = torch.std(x, dim=2)
x = torch.cat([mean, std], dim=1)
x = self.before_speaker_embedding(x)
if is_eval:
return x
x = self.after_speaker_embedding(x)
return x
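
# A minimal usage sketch of the default architecture (illustrative sizes): the pooling
# concatenates mean and std over time (1536 + 1536 = 3072), so before_speaker_embedding
# maps 3072 to 512 and is_eval=True returns 512-dimensional x-vectors.
#
#   model = Xtractor(speaker_number=1000)
#   features = torch.randn(8, 30, 200)         # (batch, feature_size, frames)
#   scores = model(features)                   # (8, 1000) classification output
#   xvectors = model(features, is_eval=True)   # (8, 512) speaker embeddings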
def xtrain(speaker_number,
dataset_yaml,
epochs=100,
lr=0.01,
model_yaml=None,
model_name=None,
tmp_model_name=None,
best_model_name=None,
multi_gpu=True,
clipping=False,
num_thread=1):
"""
:param speaker_number:
:param dataset_yaml:
:param epochs:
:param lr:
:param model_yaml:
:param model_name:
:param tmp_model_name:
:param best_model_name:
:param multi_gpu:
:param clipping:
:param num_thread:
:return:
"""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # If we start from an existing model, load it; otherwise initialize a new one
    if model_name is not None:
        logging.critical(f"*** Load model from = {model_name}")
        checkpoint = torch.load(model_name)
        model = Xtractor(speaker_number, model_yaml)
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        if model_yaml is None:
            model = Xtractor(speaker_number)
        else:
            model = Xtractor(speaker_number, model_yaml)
if torch.cuda.device_count() > 1 and multi_gpu:
print("Let's use", torch.cuda.device_count(), "GPUs!")
model = torch.nn.DataParallel(model)
else:
print("Train on a single GPU")
model.to(device)
"""
Set the dataloaders according to the dataset_yaml
First we load the dataframe from CSV file in order to split it for training and validation purpose
Then we provide those two
"""
with open(dataset_yaml, "r") as fh:
dataset_params = yaml.load(fh, Loader=yaml.FullLoader)
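    # A sketch of the keys read from dataset_yaml in this function (values are
    # illustrative, the rest of the file is consumed by SideSet):
    #
    #   dataset_description: training_set.csv
    #   validation_ratio: 0.1
    #   seed: 42
    #   chunk_per_segment: 1
    #   overlap: 0.
    #   batch_size: 64
    #   log_interval: 10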
df = pandas.read_csv(dataset_params["dataset_description"])
training_df, validation_df = train_test_split(df, test_size=dataset_params["validation_ratio"])
torch.manual_seed(dataset_params['seed'])
training_set = SideSet(dataset_yaml,
set_type="train",
dataset_df=training_df,
chunk_per_segment=dataset_params['chunk_per_segment'],
overlap=dataset_params['overlap'])
training_loader = DataLoader(training_set,
batch_size=dataset_params["batch_size"],
shuffle=True,
drop_last=True,
num_workers=num_thread)
validation_set = SideSet(dataset_yaml, set_type="validation", dataset_df=validation_df)
validation_loader = DataLoader(validation_set,
batch_size=dataset_params["batch_size"],
drop_last=True,
num_workers=num_thread)
"""
Set the training options
"""
if type(model) is Xtractor:
optimizer = torch.optim.SGD([
{'params': model.sequence_network.parameters(),
'weight_decay': model.sequence_network_weight_decay},
{'params': model.before_speaker_embedding.parameters(),
'weight_decay': model.before_speaker_embedding_weight_decay},
{'params': model.after_speaker_embedding.parameters(),
'weight_decay': model.after_speaker_embedding_weight_decay}],
lr=lr, momentum=0.9
)
else:
optimizer = torch.optim.SGD([
{'params': model.module.sequence_network.parameters(),
'weight_decay': model.module.sequence_network_weight_decay},
{'params': model.module.before_speaker_embedding.parameters(),
'weight_decay': model.module.before_speaker_embedding_weight_decay},
{'params': model.module.after_speaker_embedding.parameters(),
'weight_decay': model.module.after_speaker_embedding_weight_decay}],
lr=lr, momentum=0.9
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True)
best_accuracy = 0.0
best_accuracy_epoch = 1
for epoch in range(1, epochs + 1):
# Process one epoch and return the current model
model = train_epoch(model,
epoch,
training_loader,
optimizer,
dataset_params["log_interval"],
device=device,
clipping=clipping)
# Add the cross validation here
accuracy, val_loss = cross_validation(model, validation_loader, device=device)
logging.critical("*** Cross validation accuracy = {} %".format(accuracy))
# Decrease learning rate according to the scheduler policy
scheduler.step(val_loss)
print(f"Learning rate is {optimizer.param_groups[0]['lr']}")
# remember best accuracy and save checkpoint
is_best = accuracy > best_accuracy
best_accuracy = max(accuracy, best_accuracy)
if type(model) is Xtractor:
save_checkpoint({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'accuracy': best_accuracy,
'scheduler': scheduler
}, is_best, filename=tmp_model_name+".pt", best_filename=best_model_name+'.pt')
else:
save_checkpoint({
'epoch': epoch,
'model_state_dict': model.module.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'accuracy': best_accuracy,
'scheduler': scheduler
}, is_best, filename=tmp_model_name+".pt", best_filename=best_model_name+'.pt')
if is_best:
best_accuracy_epoch = epoch
logging.critical(f"Best accuracy {best_accuracy * 100.} obtained at epoch {best_accuracy_epoch}")
def train_epoch(model, epoch, training_loader, optimizer, log_interval, device, clipping=False):
    """
    Train the model for one epoch.

    :param model: the Xtractor to train
    :param epoch: index of the current epoch
    :param training_loader: DataLoader yielding (data, target) batches
    :param optimizer: the optimizer to use
    :param log_interval: number of batches between two logging events
    :param device: device on which the computation is run
    :param clipping: if True, clip the gradient norm to 1
    :return: the trained model
    """
model.train()
criterion = torch.nn.CrossEntropyLoss(reduction='mean')
accuracy = 0.0
for batch_idx, (data, target) in enumerate(training_loader):
target = target.squeeze()
optimizer.zero_grad()
output = model(data.to(device))
loss = criterion(output, target.to(device))
loss.backward()
if clipping:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
optimizer.step()
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
if batch_idx % log_interval == 0:
batch_size = target.shape[0]
            logging.critical('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.3f}'.format(
                epoch, batch_idx + 1, len(training_loader),
                100. * batch_idx / len(training_loader), loss.item(),
                100.0 * accuracy.item() / ((batch_idx + 1) * batch_size)))
return model
def cross_validation(model, validation_loader, device):
    """
    Compute the accuracy and the loss of the model on the validation set.

    :param model: the Xtractor to evaluate
    :param validation_loader: DataLoader yielding (data, target) batches from the validation set
    :param device: device on which the computation is run
    :return: a tuple (accuracy in percent, average loss per batch)
    """
model.eval()
accuracy = 0.0
loss = 0.0
criterion = torch.nn.CrossEntropyLoss()
with torch.no_grad():
for batch_idx, (data, target) in enumerate(validation_loader):
batch_size = target.shape[0]
target = target.squeeze()
output = model(data.to(device))
accuracy += (torch.argmax(output.data, 1) == target.to(device)).sum()
loss += criterion(output, target.to(device))
    # The criterion already averages the loss over each batch,
    # so the accumulated loss is only divided by the number of batches
    return 100. * accuracy.cpu().numpy() / ((batch_idx + 1) * batch_size), \
           loss.cpu().numpy() / (batch_idx + 1)
def extract_embeddings(idmap, speaker_number, model_filename, model_yaml, data_root_name, device):
    """
    Extract x-vectors for all segments of an IdMap using a trained Xtractor.

    :param idmap: IdMap of the segments to process
    :param speaker_number: number of speakers the model was trained on
    :param model_filename: name of the checkpoint file to load
    :param model_yaml: path to the YAML file describing the model architecture
    :param data_root_name: root directory where the data are stored
    :param device: torch.device to run the extraction on
    :return: a StatServer whose stat1 contains the extracted embeddings
    """
    # Create dataset to load the data
    dataset = IdMapSet(data_root_name, idmap)
    # Load the model
    checkpoint = torch.load(model_filename)
    model = Xtractor(speaker_number, model_archi=model_yaml)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    model.to(device)
    # Get the size of the embeddings to extract from the last linear layer
    # of before_speaker_embedding
    name = list(model.before_speaker_embedding.state_dict().keys())[-1].split('.')[0] + '.weight'
    emb_size = model.before_speaker_embedding.state_dict()[name].shape[0]
    # Create the StatServer
    embeddings = StatServer()
    embeddings.modelset = idmap.leftids
    embeddings.segset = idmap.rightids
    embeddings.start = idmap.start
    embeddings.stop = idmap.stop
    embeddings.stat0 = numpy.ones((embeddings.modelset.shape[0], 1))
    embeddings.stat1 = numpy.ones((embeddings.modelset.shape[0], emb_size))
    # Process the data
    with torch.no_grad():
        for idx, (data, mod, seg) in tqdm(enumerate(dataset)):
            vec = model(data.to(device), is_eval=True)
            current_idx = numpy.argwhere(numpy.logical_and(idmap.leftids == mod, idmap.rightids == seg))[0][0]
            embeddings.stat1[current_idx, :] = vec.detach().cpu()
    return embeddings
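
# A minimal usage sketch (hypothetical paths; IdMap and StatServer come from sidekit):
#
#   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#   idmap = IdMap('test_idmap.h5')
#   embeddings = extract_embeddings(idmap, 1000, 'best_xtractor.pt', 'model.yaml',
#                                   './data', device)
#   embeddings.write('xvectors.h5')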