Train and evaluate your x-vector extractor¶
The different steps described in this tutorial are as follows:
Create a PyTorch DataSet
Create and train an X-vector extractor
Extract x-vectors
Train a PLDA model and evaluate on VoxCeleb1
Create DataSets for training¶
Training an X-vector extractor (Xtractor) requires the creation of two objects of type SideSet. SideSet inherits from the PyTorch Dataset class.
In SIDEKIT we consider that training corpora are split into two parts, for training and for validation of the network. For this reason, each corpus is associated with a single YAML file that can be used to initialize two SideSet objects (one for training and one for validation).
DataSets are initialized by using a YAML file that must include the parameters described below.
- Note
A SideSet can be used to feed the networks with acoustic features (MFCC) but can also be used to provide raw waveforms.
Miscellaneous parameters¶
seed: 1234
log_interval: 10
dataset_description: voxceleb1_dev.csv
data_root_directory: /lium/corpus/base/voxceleb1/dev/wav/
data_file_extension: .wav
sample_rate: 16000
validation_ratio: 0.1
batch_size: 64
- seed
seed of the random initialization
- log_interval
Frequency at which the training loss is displayed, given in number of batches
- dataset_description
the name of the file containing the dataset description in CSV format
- data_root_directory
Path where wavefiles are stored
- data_file_extension
Extension of the audio files
- sample_rate
Sampling rate of the audio files in Hz
- validation_ratio
Percentage of the data used for validation
- batch_size
Size of the batches
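To make the use of these parameters concrete, the following short sketch (not part of SIDEKIT; the file name dataset.yaml is a placeholder) loads such a YAML file with PyYAML and the corresponding dataset description with pandas:

import pandas
import yaml

# Load the dataset configuration described above (placeholder file name)
with open("dataset.yaml", "r") as fh:
    dataset_params = yaml.load(fh, Loader=yaml.FullLoader)

# The file referenced by dataset_description lists the speech segments of the corpus
df = pandas.read_csv(dataset_params["dataset_description"])
print(f"{len(df)} segments, batch size {dataset_params['batch_size']}")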
Training options¶
The “train” section of the YAML file defines the Training set options.
train:
    duration: 4
    chunk_per_segment: 1
    overlap: 0.0
    transformation:
        pipeline: MFCC,CMVN,FrequencyMask(12-30),TemporalMask(70)
    augmentation:
        spec_aug: 0.5
        temp_aug: 0.5
- duration
Duration of the speech chunks, given in seconds
- chunk_per_segment
Maximum number of chunks to select from every speech segment. -1 means selecting all possible segments
- overlap
Overlap, in percentage, between two possible successive chunks of audio data
- transformation
Section that describes the transformations applied to the audio chunks
- pipeline
string that gives the sequence of transformations to apply
- augmentation
Data augmentation can be applied on the fly; the chosen augmentation processes are described in this section. Some parameters refer to the transformations applied in the case of spectral and temporal augmentation (see the sketch after this list).
- spec_aug
Apply spectral augmentation (a randomly chosen band of frequency coefficients is masked). The given parameter is the percentage of chunks that are modified.
- temp_aug
Apply temporal augmentation (a randomly chosen temporal band is masked). The given parameter is the percentage of chunks that are modified.
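To make the spec_aug and temp_aug probabilities concrete, here is a minimal sketch of on-the-fly masking; it is an illustration only, not SIDEKIT's actual augmentation code, and the mask widths are arbitrary placeholders:

import random

import torch

def augment(features, spec_aug=0.5, temp_aug=0.5, max_freq_width=12, max_time_width=70):
    """Randomly mask a frequency band and/or a temporal band of a feature chunk.

    spec_aug and temp_aug are the probabilities that a given chunk is modified,
    mirroring the percentages set in the YAML file above.
    """
    n_freq, n_frames = features.shape
    if random.random() < spec_aug:
        # mask a randomly chosen band of frequency coefficients
        width = random.randint(1, max_freq_width)
        start = random.randint(0, n_freq - width)
        features[start:start + width, :] = 0.0
    if random.random() < temp_aug:
        # mask a randomly chosen temporal band
        width = random.randint(1, min(max_time_width, n_frames))
        start = random.randint(0, n_frames - width)
        features[:, start:start + width] = 0.0
    return features

chunk = torch.randn(30, 400)   # 30 cepstral coefficients, 400 frames
chunk = augment(chunk)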
Validation options¶
This section is similar to the training one. In this example we see that data augmentation is not applied during validation. The duration of the chunks can be left empty to process the entire speech segments.
eval:
    duration: 4
    transformation:
        pipeline: MFCC,CMVN
        spec_aug: 0.5
        temp_aug: 0.5
    augmentation:
        spec_aug: 0.0
        temp_aug: 0.0
Instantiate a SideSet¶
Given a CSV file describing the corpus and a YAML file similar to the one described above, one can instantiate a SideSet as follows:
training_set = SideSet(data_set_yaml=dataset_yaml,
                       set_type="train",
                       dataset_df=training_df,
                       chunk_per_segment=dataset_params['chunk_per_segment'],
                       overlap=dataset_params['overlap'])
- data_set_yaml
The YAML file that describes the DataSet to create
- set_type
train or validation to apply the chosen configuration
- dataset_df
Optional: a pandas.DataFrame that describes the corpus, or the portion of the corpus, to use (see the rest of the tutorial for examples)
- chunk_per_segment
Number of audio chunks to select for each audio segment. Integer. This value can be set to -1 to select all possible chunks
- overlap
allowed overlap between consecutive chunks
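Putting the pieces together, both SideSet objects could be created from the same YAML file as sketched below. The import path, the placeholder file name dataset.yaml and the way the training/validation split is derived from validation_ratio are assumptions to adapt to your setup:

import pandas
import yaml
from sidekit.nnet.xsets import SideSet   # assumed import path

dataset_yaml = "dataset.yaml"            # the YAML file described above

with open(dataset_yaml, "r") as fh:
    dataset_params = yaml.load(fh, Loader=yaml.FullLoader)

# Split the corpus description into a training and a validation part
df = pandas.read_csv(dataset_params["dataset_description"])
validation_df = df.sample(frac=dataset_params["validation_ratio"],
                          random_state=dataset_params["seed"])
training_df = df.drop(validation_df.index)

training_set = SideSet(data_set_yaml=dataset_yaml,
                       set_type="train",
                       dataset_df=training_df,
                       chunk_per_segment=dataset_params["train"]["chunk_per_segment"],
                       overlap=dataset_params["train"]["overlap"])

validation_set = SideSet(data_set_yaml=dataset_yaml,
                         set_type="validation",
                         dataset_df=validation_df)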
Download standard corpora descriptions¶
VoxCeleb1 development data
VoxCeleb2 development data
ALLIES development data
Create and train the Xtractor¶
X-vectors in SIDEKIT are extracted using an Xtractor object. An Xtractor inherits from the torch.nn.Module class and is a stack of 3 or 4 torch.nn.Sequential blocks.
Each part of the Xtractor is described by a section of the loaded YAML file, as follows:
Miscellaneous parameters¶
feature_size: 30
activation: LeakyReLU
- feature_size
Size of the input acoustic features; note that this size is not used when the Xtractor includes a preprocessor block.
- activation
All activations are the same in the current implementation; possible values are “LeakyReLU”, “PReLU”, “ReLU6”, and the default “ReLU”.
Process feature sequences¶
The first block of the network is made of convolutional layers that process a sequence of features. The architecture is given by a succession of layers of type “Conv1d”, “activation” or “BatchNorm”. The type of each layer is determined by the prefix of the layer’s name.
- conv*
will be a Conv1d layer
- norm*
will be a BatchNorm layer
- activation*
will add an activation function
A weight decay can be used for regularization in order to prevent overfitting.
segmental:
    weight_decay: 0.0002
    conv1:
        output_channels: 512
        kernel_size: 5
        dilation: 1
    activation1: True
    norm1: 512
    conv2:
        output_channels: 512
        kernel_size: 3
        dilation: 2
    activation2: True
    norm2: 512
    conv3:
        output_channels: 512
        kernel_size: 3
        dilation: 3
    activation3: True
    norm3: 512
    conv4:
        output_channels: 512
        kernel_size: 1
        dilation: 1
    activation4: True
    norm4: 512
    conv5:
        output_channels: 1536
        kernel_size: 1
        dilation: 1
    activation5: True
    norm5: 1536
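As a rough illustration, the first two convolutional layers of the segmental block above correspond to the following PyTorch modules (a sketch only, not the code used internally by the Xtractor; feature_size is 30 and the activation is LeakyReLU as configured earlier):

import torch

segmental = torch.nn.Sequential(
    torch.nn.Conv1d(30, 512, kernel_size=5, dilation=1),    # conv1
    torch.nn.LeakyReLU(),                                    # activation1
    torch.nn.BatchNorm1d(512),                               # norm1
    torch.nn.Conv1d(512, 512, kernel_size=3, dilation=2),    # conv2
    torch.nn.LeakyReLU(),                                    # activation2
    torch.nn.BatchNorm1d(512),                               # norm2
    # conv3, conv4 and conv5 follow the same pattern, ending with 1536 channels
)

features = torch.randn(8, 30, 400)   # a batch of 8 chunks, 30 coefficients, 400 frames
frame_level = segmental(features)    # shape: (8, 512, number of remaining frames)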
Processing before embedding¶
This section defines the block that comes after the pooling. In this version, the only available pooling computes a mean and standard deviation, thus the input size of this block is twice the size of the output of the previous one. The output of this block is the so-called x-vector.
This block can include Linear layers, activations, dropout and BatchNorm layers. Again, the type of each layer is defined by its name:
- lin*
will be a Linear layer
- norm*
will be a BatchNorm layer
- activation*
will add an activation function
- dropout
will add a dropout layer
before_embedding:
    weight_decay: 0.0002
    linear6:
        output: 512
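For reference, here is a minimal sketch of the mean and standard deviation pooling that feeds this block (an illustration, not the Xtractor’s internal code). The last convolutional layer outputs 1536 channels, so the pooled vector has 2 x 1536 = 3072 dimensions and linear6 maps it to a 512-dimensional x-vector:

import torch

frame_level = torch.randn(8, 1536, 398)    # output of the segmental block
mean = frame_level.mean(dim=2)
std = frame_level.std(dim=2)
pooled = torch.cat([mean, std], dim=1)     # shape: (8, 3072)

linear6 = torch.nn.Linear(2 * 1536, 512)
x_vectors = linear6(pooled)                # shape: (8, 512)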
Processing after embedding¶
This section defines the block that occurs after the extraction of the x-vector.
This block can include Linear layers, activations, dropout and BatchNorm layers. Again, the type of each layer is defined by its name:
- lin*
will be a Linear layer
- norm*
will be a BatchNorm layer
- activation*
will add an activation function
- dropout
will add a dropout layer
after_embedding:
    weight_decay: 0.0002
    activation6: True
    norm6: 512
    dropout6: 0.05
    linear7:
        output: 512
    activation7: True
    norm7: 512
    linear8:
        output: speaker_number
Optional: add a preprocessor¶
Instead of processing acoustic features, the Xtractor can be fed with raw waveforms by including a preprocessor block. The current version of the Xtractor only allows a standard SincNet network.
preprocessor:
    type: sincnet
    waveform_normalize: True
    sample_rate: 16000
    min_low_hz: 50
    min_band_hz: 50
    out_channels: [80, 60, 60]
    kernel_size: [251, 5, 5]
    stride: [1, 1, 1]
    max_pool: [3, 3, 3]
    instance_normalize: True
    activation: leaky_relu
    dropout: 0.0
Instantiate an Xtractor¶
Creating an Xtractor object is very simple and only requires the number of output classes ‘speaker_number’ and, optionally, a customized architecture defined by a YAML file as described above. If no YAML file is given during initialization, the default architecture described above is used.
# Create a default architecture with a custom number of classes
model = Xtractor(speaker_number)
# Create an Xtractor according the architecture described in a YAML file
model = Xtractor(speaker_number, model_yaml)
- speaker_number
is the number of classes (speakers)
- model_yaml
the name of the YAML file describing the architecture
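Since the Xtractor is a regular torch.nn.Module, it can be inspected and moved to a device like any other PyTorch model. The snippet below is a quick sanity check; the import path and the number of speakers (1211, the VoxCeleb1 development set) are assumptions:

import torch
from sidekit.nnet import Xtractor   # assumed import path

model = Xtractor(1211)
print(model)                        # display the stacked sequential blocks
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{n_params} trainable parameters")
model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))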
Training the Xtractor¶
The simplest way to create, train and save an Xtractor is to call the xtrain function from the sidekit.nnet.xvector module.
sidekit.nnet.xtrain(speaker_number=args.class_number,
                    dataset_yaml=args.dataset,
                    epochs=args.epochs,
                    lr=args.lr,
                    model_yaml=args.architecture,
                    tmp_model_name=args.outputname,
                    best_model_name=args.outputbestname,
                    multi_gpu=args.multi_gpu == "true",
                    clipping=False,
                    num_thread=args.num_processes)
- speaker_number
The number of output classes (speakers)
- dataset_yaml
the YAML file used to describe the SideSet (training and validation)
- epochs
Number of epochs to run, default is 100
- lr
learning_rate, default is 0.01
- model_yaml
YAML file to describe the model architecture (can be None for default architecture)
- model_name
optional. Name of a previous checkpoint file to start from
- tmp_model_name
Name of the checkpoint file used to save the model after each iteration. Note that Xtractors, even when trained with torch.nn.DataParallel, are saved in single-GPU mode.
- best_model_name
name of the checkpoint file to save the current best model version. Updated each time the validation loss is lower than the best one.
- multi_gpu
Boolean, if False, force the use of a single GPU, default is True and makes use of torch.nn.DataParallel
- clipping
Boolean. If True, gradient is clipped to 1
- num_thread
Number of process used by the DataLoaders, default is 1
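Training can also be resumed from a previously saved checkpoint by passing the model_name argument described above. The file names and hyper-parameters below are placeholders:

import sidekit

sidekit.nnet.xtrain(speaker_number=1211,
                    dataset_yaml="dataset.yaml",
                    epochs=20,
                    lr=0.001,
                    model_yaml="archi.yaml",
                    model_name="model_newXV_checkpoint.pt",   # checkpoint to start from
                    tmp_model_name="model_newXV",
                    best_model_name="best_model_newXV",
                    multi_gpu=True,
                    num_thread=4)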
Extract x-vectors¶
Once your Xtractor has been trained you can extract x-vectors. The process is fully managed with sidekit.bosaris.IdMap to be compliant with the SIDEKIT philosophy.
An IdMap is an object containing the following information:
- leftids
typically the name of the class the audio chunk belongs to
- rightids
typically the name of the audio file to load the signal from
- start
the start time of the audio chunk, given as a frame number (or number of samples)
- stop
the end time of the audio chunk, given as frame number (or number of samples)
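If you do not use one of the pre-computed IdMaps listed at the end of this section, you can build one yourself. The sketch below assumes the standard sidekit.bosaris.IdMap attributes described above; the speaker and file names are placeholders:

import numpy
import sidekit

idmap = sidekit.bosaris.IdMap()
idmap.leftids = numpy.array(["id10001", "id10001", "id10002"])                    # class (speaker) names
idmap.rightids = numpy.array(["id10001/utt1", "id10001/utt2", "id10002/utt1"])    # audio files
idmap.start = numpy.empty(3, dtype="object")    # None entries mean: use the whole file
idmap.stop = numpy.empty(3, dtype="object")
assert idmap.validate()
idmap.write("my_enrol_idmap.h5")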
import sidekit
import torch
from sidekit.nnet import extract_embeddings

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

idmap_name = "VoxCeleb1_enrol_idmap.h5"

# Process the data and return a StatServer
xv_stat = extract_embeddings(idmap_name=idmap_name,
                             speaker_number=1211,
                             model_filename="best_model_newXV.pt",
                             model_yaml="archi.yaml",
                             data_root_name="/lium/corpus/base/voxceleb1/test/wav/",
                             device=device,
                             transform_pipeline="MFCC,CMVN")

# Save the StatServer in HDF5 format
xv_stat.write("VoxCeleb1_enrol_xvectors.h5")
IdMaps in HDF5 format for the standard VoxCeleb1 evaluation can be downloaded from here:
training_idmap
enrolment_idmap
test_idmap