Train and evaluate your x-vector extractor

This tutorial shows how to train and run an x-vector
extractor on the VoxCeleb dataset and how to evaluate it
using the standard protocol.

We assume here that you have downloaded the VoxCeleb1 data
from the official website.

The different steps described in this tutorial are as follows:

  • Create a PyTorch DataSet

  • Create and train an X-vector extractor

  • Extract x-vectors

  • Train a PLDA model and evaluate on VoxCeleb1

Create DataSets for training

Training an x-vector extractor (Xtractor) requires the creation of two objects of type SideSet. SideSet inherits from the PyTorch Dataset class.

In SIDEKIT we consider that training corpora are split into two parts, one for training and one for validation of the network. For this reason, each corpus is associated with a single YAML file that can be used to initialize two SideSet objects (one for training and one for validation).

DataSets are initialized by using a YAML file that must include the parameters described below.

Note

A SideSet can be used to feed the networks with acoustic features (MFCC) but can also be used to provide raw waveforms.

Miscellaneous parameters

seed: 1234
log_interval: 10

dataset_description: voxceleb1_dev.csv
data_root_directory: /lium/corpus/base/voxceleb1/dev/wav/
data_file_extension: .wav
sample_rate: 16000

validation_ratio: 0.1
batch_size: 64
seed

seed of the random initialization

log_interval

Frequency, in number of batches, at which the training loss is displayed

dataset_description

the name of the file containing the dataset description in CSV format (a loading sketch is given after this list)

data_root_directory

Path where wavefiles are stored

data_file_extension

extension of the audio file

validation_ratio

Fraction (between 0 and 1) of the data used for validation

batch_size

Size of the batches
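
As an illustration of how these parameters fit together, here is a minimal sketch that loads the YAML description and splits the CSV into training and validation parts. The YAML file name is hypothetical, and the pandas and PyYAML packages are assumed; this mimics the intent of the parameters rather than SIDEKIT's internal logic.

import pandas
import yaml

# Load the dataset configuration (hypothetical file name)
with open("voxceleb1_dev.yaml", "r") as fh:
    dataset_params = yaml.safe_load(fh)

# Read the CSV dataset description
df = pandas.read_csv(dataset_params["dataset_description"])

# Reproducible split into validation and training parts
validation_df = df.sample(frac=dataset_params["validation_ratio"],
                          random_state=dataset_params["seed"])
training_df = df.drop(validation_df.index)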

Training options

The “train” section of the YAML file defines the Training set options.

train:
    duration: 4
    chunk_per_segment: 1
    overlap: 0.0

    transformation:
        pipeline: MFCC,CMVN,FrequencyMask(12-30),TemporalMask(70)

    augmentation:
        spec_aug: 0.5
        temp_aug: 0.5
duration

Duration of the speech chunks, given in seconds

chunk_per_segment

Maximum number of chunks to select from each speech segment; -1 means selecting all possible chunks

overlap

overlap, in percentage, between two successive chunks of audio data

transformation

Section that describes the transformations applied to the audio chunks

pipeline

string that gives the sequence of transformations to apply (a short parsing sketch is given after this list)

augmentation

Data augmentation can be applied on the fly; the chosen augmentation processes are described in this section. Some parameters refer to the transformations applied in the case of spectral and temporal augmentation.

spec_aug

Apply spectral augmentation (a randomly chosen band of frequency coefficients is masked). The given parameter is the percentage of chunks that are modified.

temp_aug

Apply temporal augmentation (a randomly chosen temporal band is masked). The given parameter is the percentage of chunks that are modified.
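
To make the meaning of the pipeline string concrete, the sketch below splits it into (transformation, argument) pairs. It only mimics the intent of the string; it is not SIDEKIT's actual parser.

import re

def parse_pipeline(pipeline):
    """Split a pipeline string such as
    'MFCC,CMVN,FrequencyMask(12-30),TemporalMask(70)'
    into (name, argument) pairs."""
    transforms = []
    for item in pipeline.split(","):
        match = re.match(r"(\w+)(?:\((.*)\))?", item.strip())
        transforms.append((match.group(1), match.group(2)))
    return transforms

# [('MFCC', None), ('CMVN', None), ('FrequencyMask', '12-30'), ('TemporalMask', '70')]
print(parse_pipeline("MFCC,CMVN,FrequencyMask(12-30),TemporalMask(70)"))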

Validation options

This section is similar to the training one. In this example we see that data augmentation is not applied during validation. The duration of the chunks can be left empty to process the entire speech segments.

eval:

    duration: 4

    transformation:
        pipeline: MFCC,CMVN

    augmentation:
        spec_aug: 0.0
        temp_aug: 0.0

Instantiate a SideSet

Given a CSV file describing the corpus and a YAML file similar to the one described above, one can instantiate a SideSet as follows:

training_set = SideSet(data_set_yaml=dataset_yaml,
                       set_type="train",
                       dataset_df=training_df,
                       chunk_per_segment=dataset_params['chunk_per_segment'],
                       overlap=dataset_params['overlap'])
data_set_yaml

The YAML file that describes the DataSet to create

set_type

train or validation to apply the chosen configuration

dataset_df

Optional: a pandas.DataFrame that describes the corpus, or the portion of the corpus, to use (see the rest of this tutorial and the sketch after this list for examples)

chunk_per_segment

Number of audio chunks to select for each audio segment. Integer. This value can be set to -1 to select all possible chunks

overlap

allowed overlap between consecutive chunks
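
Putting the pieces together, here is a minimal sketch that creates both SideSets and wraps them in standard PyTorch DataLoaders. It reuses the training_df / validation_df split from the earlier sketch; the YAML file name, the import path of SideSet and the hard-coded values (taken from the YAML example above) are assumptions, not SIDEKIT's exact API.

from torch.utils.data import DataLoader
from sidekit.nnet.xsets import SideSet  # import path may differ between SIDEKIT versions

dataset_yaml = "voxceleb1_dev.yaml"     # hypothetical file name

training_set = SideSet(data_set_yaml=dataset_yaml,
                       set_type="train",
                       dataset_df=training_df,
                       chunk_per_segment=1,   # values from the YAML example above
                       overlap=0.0)

validation_set = SideSet(data_set_yaml=dataset_yaml,
                         set_type="validation",
                         dataset_df=validation_df)

# Standard PyTorch DataLoaders on top of the two SideSets
training_loader = DataLoader(training_set, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_set, batch_size=64)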

Download standard corpora descriptions

  • VoxCeleb1 development data

  • VoxCeleb2 development data

  • ALLIES development data

Create and train the Xtractor

X-vectors in SIDEKIT are extracted using an Xtractor object. An Xtractor inherits from the torch.nn.Module class and is a stack of 3 or 4 torch.nn.Sequential blocks.

All parts of the Xtractor are described by a section of the loaded YAML file, as follows:

Miscellaneous parameters

feature_size: 30
activation: LeakyReLU
feature_size

size of the input acoustic features. Note that this size is not used when the Xtractor includes a preprocessor block.

activation

all activations are the same in the current implementation; possible values are “LeakyReLU”, “PReLU”, “ReLU6”, and by default “ReLU”

Process feature sequences

The first block of the network is made of convolutional layers that process a sequence of features. The architecture is given by a succession of layers of type “Conv1d”, “activation” or “BatchNorm”. The type of each layer is determined by the prefix of the layer’s name; an illustrative sketch of this convention is given after the configuration example below.

conv*

will be a Conv1d layer

norm*

will be a BatchNorm layer

activation*

will add an activation function

A weight decay can be used for regularization in order to prevent overfitting.

segmental:

    weight_decay: 0.0002

    conv1:
        output_channels: 512
        kernel_size: 5
        dilation: 1
    activation1: True
    norm1: 512

    conv2:
        output_channels: 512
        kernel_size: 3
        dilation: 2
    activation2: True
    norm2: 512

    conv3:
        output_channels: 512
        kernel_size: 3
        dilation: 3
    activation3: True
    norm3: 512

    conv4:
        output_channels: 512
        kernel_size: 1
        dilation: 1
    activation4: True
    norm4: 512

    conv5:
        output_channels: 1536
        kernel_size: 1
        dilation: 1
    activation5: True
    norm5: 1536
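
The sketch below is not SIDEKIT code; it only illustrates how a section such as the one above could be mapped onto a torch.nn.Sequential by following the naming convention (conv*, activation*, norm*).

import torch

def build_segmental(cfg, input_size, activation=torch.nn.LeakyReLU):
    """Illustrative only: map a 'segmental' configuration dictionary
    (as loaded from the YAML file) onto a stack of Conv1d / activation /
    BatchNorm1d layers, following the conv* / activation* / norm* prefixes."""
    layers = []
    in_channels = input_size
    for name, params in cfg.items():
        if name.startswith("conv"):
            layers.append(torch.nn.Conv1d(in_channels,
                                          params["output_channels"],
                                          kernel_size=params["kernel_size"],
                                          dilation=params["dilation"]))
            in_channels = params["output_channels"]
        elif name.startswith("activation"):
            layers.append(activation())
        elif name.startswith("norm"):
            layers.append(torch.nn.BatchNorm1d(params))
        # other keys (e.g. weight_decay) are optimizer options, not layers
    return torch.nn.Sequential(*layers)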

Processing before embedding

This section defines the block that comes after the pooling. In this version, the only available pooling is mean and standard deviation pooling, thus the input size of this block is twice the size of the output of the previous one. The output of this block is the so-called x-vector.
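
For clarity, mean and standard deviation pooling can be written in a few lines; this is a generic sketch, not the exact SIDEKIT implementation.

import torch

def stat_pooling(x):
    """x: (batch, channels, frames) output of the convolutional block.
    Returns (batch, 2 * channels): per-channel mean and standard deviation
    over the time axis, concatenated."""
    mean = torch.mean(x, dim=2)
    std = torch.std(x, dim=2)
    return torch.cat([mean, std], dim=1)

# With conv5 producing 1536 channels, the pooled vector has 3072 dimensions,
# which becomes the input size of the 'before_embedding' block.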

This block can include Linear layers, activations, dropout and BatchNorm layers. Again, the type of each layer is defined by its name:

lin*

will be a Linear layer

norm*

will be a BatchNorm layer

activation*

will add an activation function

dropout

will add a dropout layer

before_embedding:

    weight_decay: 0.0002

    linear6:
        output: 512

Processing after embedding

This section defines the block that occurs after the extraction of the x-vector.

This block can include Linear layers, activations, dropout and BatchNorm layers. Again, the type of each layer is defined by its name:

lin*

will be a Linear layer

norm*

will be a BatchNorm layer

activation*

will add an activation function

dropout

will add a dropout layer

after_embedding:

    weight_decay: 0.0002

    activation6: True
    norm6: 512
    dropout6: 0.05

    linear7:
        output: 512
    activation7: True
    norm7: 512

    linear8:
        output: speaker_number

Optional: add a preprocessor

Instead of processing acoustic features, the Xtractor can be fed with raw waveforms when it includes a preprocessor block. The current version of the Xtractor only allows a standard SincNet network.

preprocessor:
    type: sincnet
    waveform_normalize: True
    sample_rate: 16000
    min_low_hz: 50
    min_band_hz: 50
    out_channels: [80, 60, 60]
    kernel_size: [251, 5, 5]
    stride: [1, 1, 1]
    max_pool: [3, 3, 3]
    instance_normalize: True
    activation: leaky_relu
    dropout: 0.0

Instantiate an Xtractor

Creating an Xtractor object is very simple and only requires the number of output classes ‘speaker_number’ and, optionally, a customized architecture defined by a YAML file as described above. If no YAML file is given during initialization, the Xtractor uses the default architecture described above. A quick sanity check is sketched after the parameter list below.

# Create a default architecture with a custom number of classes
model = Xtractor(speaker_number)

# Create an Xtractor according to the architecture described in a YAML file
model = Xtractor(speaker_number, model_yaml)
speaker_number

is the number of classes (speakers)

model_yaml

the name of the YAML file describing the architecture
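
As a quick sanity check, one can create the model and count its trainable parameters with standard PyTorch calls. This is a sketch: the import path of Xtractor may differ between SIDEKIT versions.

from sidekit.nnet.xvector import Xtractor  # import path may vary with the SIDEKIT version

speaker_number = 1211              # number of speakers in the VoxCeleb1 development set
model = Xtractor(speaker_number)

# Xtractor inherits from torch.nn.Module, so the usual PyTorch tools apply
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{n_params} trainable parameters")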

Training the Xtractor

The simplest way to create, train and save an Xtractor is to call the xtrain function from the sidekit.nnet.xvector module; a minimal command-line wrapper is sketched after the parameter list below.

sidekit.nnet.xtrain(speaker_number=args.class_number,
                dataset_yaml=args.dataset,
                epochs=args.epochs,
                lr=args.lr,
                model_yaml=args.architecture,
                tmp_model_name=args.outputname,
                best_model_name=args.outputbestname,
                multi_gpu=args.multi_gpu=="true",
                clipping=False,
                num_thread=args.num_processes
               )
speaker_number

The number of output classes (speakers)

dataset_yaml

the YAML file used to describe the SideSets (training and validation)

epochs

Number of epochs to run, default is 100

lr

learning_rate, default is 0.01

model_yaml

YAML file to describe the model architecture (can be None for default architecture)

model_name

optional. Name of a previous checkpoint file to start from

tmp_model_name

name of the checkpoint file used to save the model after each iteration. Note that Xtractors, even when trained with torch.nn.DataParallel, are saved in single-GPU mode.

best_model_name

name of the checkpoint file used to save the current best model version, updated each time the validation loss is lower than the best one so far.

multi_gpu

Boolean; if False, forces the use of a single GPU. Default is True, which makes use of torch.nn.DataParallel.

clipping

Boolean; if True, the gradient is clipped to 1.

num_thread

Number of processes used by the DataLoaders; default is 1.
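
The call above relies on an args namespace; a minimal command-line wrapper could provide it with argparse. This is a sketch with illustrative option names and defaults, not part of SIDEKIT itself.

import argparse
import sidekit

parser = argparse.ArgumentParser(description="Train an x-vector extractor")
parser.add_argument("--class_number", type=int, default=1211)
parser.add_argument("--dataset", default="voxceleb1_dev.yaml")
parser.add_argument("--architecture", default=None)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument("--outputname", default="tmp_model_xv")
parser.add_argument("--outputbestname", default="best_model_xv")
parser.add_argument("--multi_gpu", default="true")
parser.add_argument("--num_processes", type=int, default=1)
args = parser.parse_args()

sidekit.nnet.xtrain(speaker_number=args.class_number,
                    dataset_yaml=args.dataset,
                    epochs=args.epochs,
                    lr=args.lr,
                    model_yaml=args.architecture,
                    tmp_model_name=args.outputname,
                    best_model_name=args.outputbestname,
                    multi_gpu=args.multi_gpu == "true",
                    clipping=False,
                    num_thread=args.num_processes)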

Extract x-vectors

Once your Xtractor has been trained, you can extract x-vectors. The process is fully managed with sidekit.bosaris.IdMap objects, in keeping with the SIDEKIT philosophy.

An IdMap is an object containing the following information (a short sketch of building one by hand is given after this list):

leftids

typically the name of the class the audio chunk belongs to

rightids

typically the name of the audio file to load the signal from

start

the start time of the audio chunk, given as a frame number (or number of samples)

stop

the end time of the audio chunk, given as a frame number (or number of samples)
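
If you prefer building your own IdMap rather than downloading the files listed below, a minimal sketch looks like this. The speaker and file identifiers are illustrative.

import numpy
from sidekit.bosaris import IdMap

idmap = IdMap()
idmap.leftids = numpy.array(["id10001", "id10002"])            # class (speaker) names
idmap.rightids = numpy.array(["id10001/1zcIwhmdeo4/00001",     # audio files, without extension
                              "id10002/6WO410QOeuo/00001"])
idmap.start = numpy.empty(2, dtype="|O")                        # use the whole files:
idmap.stop = numpy.empty(2, dtype="|O")                         # no start/stop boundaries

assert idmap.validate()
idmap.write("my_enrol_idmap.h5")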

import sidekit
import torch

from sidekit.nnet import extract_embeddings


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



idmap_name = "VoxCeleb1_enrol_idmap.h5"

# Process the data and return a StatServer
xv_stat = extract_embeddings(idmap_name=idmap_name,
                             speaker_number=1211,
                             model_filename="best_model_newXV.pt",
                             model_yaml="archi.yaml",
                             data_root_name="/lium/corpus/base/voxceleb1/test/wav/",
                             device=device,
                             transform_pipeline="MFCC,CMVN")

# Save the StatServer in HDF5 format
xv_stat.write("VoxCeleb1_enrol_xvectors.h5")
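
The saved file can later be reloaded as a sidekit.StatServer; the x-vectors are stored in its stat1 attribute, with the corresponding model and segment names in modelset and segset. A short sketch, assuming the extraction above succeeded:

import sidekit

xv_stat = sidekit.StatServer("VoxCeleb1_enrol_xvectors.h5")
print(xv_stat.modelset.shape, xv_stat.stat1.shape)   # (n_segments,), (n_segments, xvector_dim)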

IdMaps in HDF5 format for the standard VoxCeleb1 evaluation can be downloaded here:

  • training_idmap

  • enrolment_idmap

  • test_idmap

Evaluate using a PLDA model