# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
#
# This file is part of S4D.
#
# SD4 is a python package for speaker diarization based on SIDEKIT.
# S4D home page: http://www-lium.univ-lemans.fr/s4d/
# SIDEKIT home page: http://www-lium.univ-lemans.fr/sidekit/
#
# S4D is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# S4D is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with SIDEKIT. If not, see <http://www.gnu.org/licenses/>.
"""
Diar is a class describing an audio/video segmentation file. A diarization
contains a list of segments. Where each row is segment composed of n values
identified by a attribut names.
The diarization file is the most important file in the toolkit. All programs
are driven by a diarization file and most of them generate a diarization
file (trainer generate gmm).
A diarization stores a list of ''segments'' composed of attributes.
A diarization could draw data from several shows. It is very useful in a batch
mode context (training of GMM, computing log likelihood ratio, cross-show
diarization, etc.).
Example
-------
>>> diarization[0] //get the first segment
['20041006_0700_0800_CLASSIQUE', 'Emmanuel_Cugny', 'speaker', 164, 1170]
>>> diarization[0]['show']
'20041006_0700_0800_CLASSIQUE'
>>> diarization[0]['cluster']
'Emmanuel_Cugny'
>>> diarization[0]['start']
164
>>> diarization[0]['stop']
1170
where:
* attribut 0 named ''show'': ''20041006_0700_0800_CLASSIQUE'' = the show speaker
* attribut 1 named ''cluster'' : ''Emmanuel_Cugny'' the speaker speaker
* attribut 2 named ''type'' : ''speaker'' contains the cluster type (speaker or head)
* attribut 3 named ''start'': ''164'' is the index of the first feature of the segment
* attribut 4 named ''stop'': ''1170'' is the index of the last feature of the segment
How to
------
* Read a diarization:
::
from s4d.diarization import Diar
diarization = Diar.read_seg('foo.seg') //LIUM Spk Diarization format
diarization = Diar.read_mdtm('foo.seg') //MDTM format
diarization = Diar.read_rttm('foo.seg') //RTTM format
diarization = Diar.read_uem('foo.seg') //UEM format
* Get a segment:
::
seg = diarization[0]
* Get a attribut of a segment
::
seg['cluster']
* Write a diarization:
::
diarization = Diar.write_seg('foo.seg', diarization) //LIUM format
* Add or remove an attribut:
::
diarization.add_attribut(speaker='gender', default=None) // add an attribut named 'gender, the default value is None
diarization.del_attribut('gender') // remove the attribut
* Sort a diarization:
::
diarization.sort()
* Create a new segment:
::
diarization.append(show='foo', cluster='speaker', start=0, stop=100)
Add all the segments of diar2 into diar1:
::
diar1.append_diar(diar2)
Modules
-------
"""
from sidekit.sidekit_wrappers import *
import re as regexp
import copy
import re
import logging
import os
from sidekit.bosaris.idmap import IdMap
import sys
from six import string_types
from s4d.utils import str2str_normalize
from sortedcontainers import SortedDict as dict
#try:
# from sortedcontainers import SortedDict as dict
#except ImportError:
# pass
[docs]class Diar():
"""
The diarization class.
:attr _attributes: a AttributeNames object storing the attribut definitions
:attr cluster_types: a list object
:attr segments: a list of Segment object
"""
def __init__(self):
self._attributes = AttributeNames()
self._attributes.initialize({'show': 0, 'cluster': 1, 'cluster_type': 2,
'start': 3, 'stop': 4},
['empty', 'empty', 'speaker', 0, 0])
self.cluster_types = ['speaker', 'head']
self.segments = list()
[docs] def copy_structure(self):
"""
Copy the internal structure of the diarization, ie the attribut names
and the cluster types. The data is not copy.
:return: a Diar object
"""
tmp_diarization = Diar()
tmp_diarization._attributes = copy.deepcopy(self._attributes)
tmp_diarization.cluster_types = copy.deepcopy(self.cluster_types)
return tmp_diarization
[docs] def del_all(self, attribute, value):
"""
Delete all segments satisfing the boolean expression [attribute = value]
:param attribute: speaker of the attribute to delete
:param value:
:return:
"""
lst = list()
for segment in self.segments:
if segment[attribute] != value:
lst.append(segment)
self.segments = lst
[docs] def remove_overlap(self):
"""
remove overlap zone
:return: a new diarization without overlap
"""
diarization_out = self.copy_structure()
#shows = self.unique('show')
shows = self.make_index(['show'])
for show in shows:
logging.info('rm overlap show: '+show)
diar_show = shows[show]
length = diar_show.last_feature_index()
cluster_list = diar_show.unique('cluster')
mat = numpy.zeros(length)
features_index = diar_show.features_by_cluster()
for i, cluster in enumerate(cluster_list):
mat[features_index[cluster]] += 1
uem = set(numpy.where(mat == 1)[0].tolist())
diar_tmp = self.copy_structure()
for i, cluster in enumerate(cluster_list):
lst = sorted(uem.intersection(set(features_index[cluster])))
if len(lst) > 0:
c = lst[0];
diar_tmp.append(show=show, start=c, stop=c+1, cluster=cluster)
l = 0
for j in range(1, len(lst)):
p = c
c = lst[j]
if c == p + 1:
l += 1
else:
diar_tmp[-1]['stop'] += l
diar_tmp.append(show=show, start=c, stop=c+1, cluster=cluster)
l = 0
if l > 0:
diar_tmp[-1]['stop'] += l
diarization_out.segments += diar_tmp.segments
return diarization_out
[docs] def filter(self, attribute, operator, value):
"""
build a new diarization whose segments satisfy the boolean expression
[attribute operator value]
:param attribute: a attribute speaker (str)
:param operator: a comperator opertor (> < >= <= in == !=)
:param value: the value (int, float, str, list...)
:return: a Diar object
"""
tmp_diarization = self.copy_structure()
tmp_diarization.segments = list()
if attribute == "length" or attribute == "duration":
str = "seg.duration() {:s} {}".format(operator, value)
elif isinstance(value, string_types):
str = "seg['{:s}'] {:s} '{:s}'".format(attribute, operator, value)
else:
str = "seg['{:s}'] {:s} {} ".format(attribute, operator, value)
# print(ch)
logging.debug(str)
for seg in self.segments:
# print(ch, seg.length())
if eval(str):
tmp_diarization.segments.append(copy.deepcopy(seg))
return tmp_diarization
[docs] def rename(self, attribute, old_values, new_value):
"""
Rename all values in list old_values into the new value new_value
:param attribute: speaker of the attribute
:param old_values: list of old values
:param new_value: new value
"""
for segment in self.segments:
if segment[attribute] in old_values or len(old_values) == 0 :
segment[attribute] = new_value
def _iofi(self, index, attributes, segment):
"""
recursive fonction to add a segment into the n level keys dictionary
:param index: dict object of level n
:param attributes: list of attribut attributes
:param segment: a segment
:return: a dictornary of level n that contains sub diarization. Segments
are not copy.
"""
# removes and gets the last attribut speaker
attribut = attributes.pop()
# takes the values of this attribut speaker
value = segment[attribut]
# if there is no more attribut attributes
if len(attributes) <= 0:
if value in index:
# add the segment to the list
index[value].append_seg(segment)
else:
# create a list and add the segment
ldiar = self.copy_structure()
ldiar.append_seg(segment)
index[value] = ldiar
return index
else:
# recursion to the level n-1 until attributes is empty
self._iofi(index[value], attributes, segment)
[docs] def make_index(self, attributes):
"""
Build a n level key dictionary (dictionary of dictionaries of
dictionaries...) based on Index.
Index is an implementation of perl's autovivification feature.
The values contains a list of row.
example :
d = make_index(['show', 'gender', 'cluster'])
print(d['show1']['M']['speaker'])
:param attributes: a list of attribut _attributes corresponding to the key indexes
:return: a dictionary of sub diarization. Segments are not copy.
"""
index = Index()
for segment in self.segments:
self._iofi(index, attributes[::-1], segment)
return index
[docs] def unique(self, attibute):
"""
:param attibute: the attibute of the attribut
:return: a list object of unique value of the attribut
"""
dic = dict()
lst = list()
for seg in self.segments:
dic[seg[attibute]] = 0
for value in dic.keys():
lst.append(value)
return lst
[docs] def sort(self, attributes=['show', 'start'], reverse=False):
"""
Sort the segments
:param attributes: a list of attribut names
:param reverse: if true, make a reverse sort
"""
attributes.reverse()
for attribute in attributes:
if attribute not in self._attributes:
raise Exception("This attribut don't exits : " + attribute)
self.segments = sorted(self.segments, key=lambda x: x[self._attributes[attribute]],
reverse=reverse)
[docs] def clear(self):
"""
remove all the segments
:return:
"""
self.segments = list()
[docs] def add_attribut(self, new_attribut, default=''):
"""
Add a attribut
:param new_attribut: the speaker of the new attribut
:param default: the default value of the attribut
"""
self._attributes.add(new_attribut, default)
for seg in self.segments:
seg.append(default)
[docs] def del_attribut(self, attribut):
"""
Delete a attribut
:param attribut: the speaker of the attribut to detele
"""
if attribut not in self._attributes:
raise Exception("This attribut don't exits : " + attribut)
else:
i = self._attributes[attribut]
for seg in self.segments:
del seg[i]
self._attributes.delete(attribut)
def _new_row(self, **kwargs):
"""
Create a new segment initialized with kwargs
:param kwargs: the values
:return:
"""
seg = Segment(self._attributes.defaults, self._attributes)
for key, value in kwargs.items():
seg[self._attributes[key]] = value
return seg
[docs] def append(self, **kwargs):
"""
Transforme a list of values into a segment and append the segmnt into
the existing segment list.
:param kwargs: the values
:return:
"""
self.segments.append(self._new_row(**kwargs))
[docs] def append_seg(self, segment):
"""
Append a Segment object into the existing segment list.
:param segment: a Segment object
"""
self.segments.append(segment)
[docs] def append_list(self, segment_lst):
"""
Append a list of segments into the existing segment segment_lst.
:param segment_lst: a list of segments
"""
self.segments += segment_lst
[docs] def append_diar(self, out_diarization):
"""
Append a diarization.
:param out_diarization: a diarization object
"""
self.segments += out_diarization.segments
[docs] def insert(self, i, **kwargs):
"""
Insert values into the list at offset index
:param i: This is the Index where the object obj need to be inserted.
:param kwargs: the values
"""
self.segments.insert(i, self._new_row(**kwargs))
def __iter__(self):
"""
This method is called when an iterator is required for a container.
:return: an iterator
"""
return self.segments.__iter__()
def __reversed__(self):
"""
Called (if present) by the reversed() built-in to implement reverse iteration.
:return: a Diar object
"""
return self.segments.__reversed__()
def __delitem__(self, index):
"""
Called to implement deletion of self[index]
:param index: a int
"""
del self.segments[index]
def __getitem__(self, index):
"""
Called to implement evaluation of self[index]
:param index: a int
:return: a Segment object
"""
return self.segments[index]
def __setitem__(self, index, value):
"""
Called to implement evaluation of self[index] = value
:param index: a int
:param value: a Segment object
"""
self.segments[index] = value
def __len__(self):
"""
:return: the number of segments
"""
return len(self.segments)
def __eq__(self, diarization): # real signature unknown
if len(self.segments) != len(diarization.segments):
return False
self.sort(attributes=['show', 'start', 'stop'])
diarization.sort(attributes=['show', 'start', 'stop'])
for i in range(len(self.segments)):
if self[i] != diarization[i]:
return False
return True
def __ne__(self, diarization): # real signature unknown
return not self.__eq__(diarization)
def __repr__(self):
"""
:return: a string version of the diarization
"""
string = ' attribut definition : ['
index = 0
lst = self._attributes.sorted()
#print(lst)
for attribute in lst:
string += "'" + attribute[0] + "', "
string = regexp.sub(', $', '', string) + ']\n'
for segment in self.segments:
line = ''
for attribute in segment:
line += attribute.__repr__() + ', '
string += ' row ' + string(index) + ': [' + regexp.sub(', $', '',
line) + ']\n'
index += 1
return '[\n' + string + ']'
def __add__(self, diarization):
diarization_copy = copy.deepcopy(self)
diarization_copy.segments += diarization.segments
return diarization_copy
def __iadd__(self, diarization):
self.segments += diarization.segments
return self
[docs] def id_map(self, id_attribut='cluster', show_attribut='show',
prefix_id_attrubut=None, suffix_show_attribut=None):
"""
Generate a IdMap object for the StatServer
:param id_attribut: speaker id_attribut attribut
:param show_attribut: show_attribut attribut
:param prefix_id_attrubut: prefix string of id_attribut
:param suffix_show_attribut: suffix string of id_attribut
:param out_diarization: a diarization object
:return: a IdMap object
"""
id_map = IdMap()
id_map.leftids = numpy.empty(len(self.segments), dtype="|O")
id_map.rightids = numpy.empty(len(self.segments), dtype="|O")
id_map.start = numpy.empty(len(self.segments), dtype="|O")
id_map.stop = numpy.empty(len(self.segments), dtype="|O")
i = 0
for segment in self.segments:
if prefix_id_attrubut is not None:
id_map.leftids[i] = segment[prefix_id_attrubut] + '/' + segment[id_attribut]
else:
id_map.leftids[i] = segment[id_attribut]
if suffix_show_attribut is not None:
id_map.rightids[i] = segment[show_attribut] + '/' + segment[suffix_show_attribut]
else:
id_map.rightids[i] = segment[show_attribut]
id_map.start[i] = segment['start']
id_map.stop[i] = segment['stop']
i += 1
return id_map
[docs] def features_by_cluster(self, show=None, maximum_length=None):
"""
Generate the indexes of a show
:param show: the speaker of the show
:param maximum_length: maximum length of the show
:return: a dict object (keys are the cluster_list)
"""
if show == None:
l = self.unique('show')
if len(l) > 1:
raise Exception('diarization address sevreal shows, set show parameter')
else:
show = l[0]
dic = dict()
for segment in self.segments:
if show == segment['show']:
cluster = segment['cluster']
start = segment['start']
stop = segment['stop']
if maximum_length is not None:
start = min(segment['start'], maximum_length)
stop = min(segment['stop'], maximum_length)
if cluster not in dic:
dic[cluster] = []
dic[cluster] += [i for i in range(start, stop)]
return dic
[docs] def features(self, show=None, maximum_length=None):
"""
Generate the index features of a show
:param show: a string corresponding to the speaker of the show
:param maximum_length: maximum length of the show
:return: a list object of indexes
"""
if show is None:
lst = self.unique('show')
if len(lst) > 1:
raise Exception('diarization address sevreal shows, set show parameter')
else:
show = lst[0]
lst = list()
for segment in self.segments:
if show == segment['show']:
start = segment['start']
stop = segment['stop']
if maximum_length is not None:
start = min(segment['start'], maximum_length)
stop = min(segment['stop'], maximum_length)
lst += [i for i in range(start, stop)]
return lst
[docs] def pack(self, epsilon=0):
"""
merge segments with a gap less than epsilon
:param epsilon: a int value
"""
index = self.make_index(['show', 'cluster'])
lst = list()
for show in index:
for cluster in index[show]:
index[show][cluster].sort(['start'])
diar = index[show][cluster]
i = 0
while i < len(diar.segments) - 1:
l = Segment.gap(diar.segments[i], diar.segments[i + 1]).duration()
if l <= epsilon:
diar.segments[i]['stop'] = max(diar.segments[i]['stop'],
diar.segments[i + 1]['stop'])
del diar.segments[i + 1]
else:
i += 1
lst += diar.segments
self.segments = lst
[docs] def pad(self, epsilon=0):
"""
Add epsilon frames to the start and stop of each segment
:param epsilon: the int value to remove
:return:
"""
self.sort(['start'])
i = 0
if len(self.segments) > 1:
self.segments[i]['stop'] = min(max(self.segments[i + 1]['start'] - (epsilon // 2), 0), self.segments[i]['stop'] + epsilon)
i += 1
while i < len(self.segments)-1:
self.segments[i]['start'] = max(self.segments[i - 1]['stop'], self.segments[i]['start'] - epsilon, 0)
self.segments[i]['stop'] = min(max(self.segments[i + 1]['start'] - (epsilon // 2), 0), self.segments[i]['stop'] + epsilon)
i += 1
[docs] def collar(self, epsilon=0, warning=False):
"""
Apply a collar on each segment. A collar is the no-score zone around
reference speaker segment boundaries. (Speaker Diarization output is
not evaluated within +/- collar seconds of a reference speaker segment
boundary.)
:param epsilon: the int value to add
"""
self.sort(['start'])
rm = False
for segment in self.segments:
segment['stop'] -= epsilon
segment['start'] += epsilon
if segment['start'] < 0:
segment['start'] = 0
if segment['start'] > segment['stop']:
segment['start'] = segment['stop']
rm = True
if warning:
logging.warning('no more segment: '+str(segment['start']-epsilon))
if rm:
self.segments = [seg for seg in self.segments if seg.duration() > 0]
[docs] def duration(self):
"""
:return: the sum of the segment duration
"""
l = 0
for segment in self.segments:
l += segment.duration()
return l
def last_feature_index(self):
last = 0
for segment in self.segments:
if segment['stop'] > last:
last = segment['stop']
return last
@classmethod
[docs] def read_seg(cls, filename, normalize_cluster=False):
"""
Read a segmentation file
:param filename: the str input filename
:param normalize_cluster: normalize the cluster speaker by removing upper
case and accents
:return: a diarization object
"""
fic = open(filename, 'r', encoding="utf8")
diarization = Diar()
if not diarization._attributes.exist('gender'):
diarization.add_attribut(new_attribut='gender', default='U')
if not diarization._attributes.exist('env'):
diarization.add_attribut(new_attribut='env', default='U')
if not diarization._attributes.exist('channel'):
diarization.add_attribut(new_attribut='channel', default='U')
try:
for line in fic:
line = re.sub('\s+',' ',line)
line = line.strip()
# logging.debug(line)
if line.startswith('#') or line.startswith(';;'):
continue
# split line into fields
show, tmp, start, length, gender, channel, environment, name = line.split()
if normalize_cluster:
name = str2str_normalize(name)
# print(show, tmp, start, length, gender, channel, env, speaker)
diarization.append(show=show, cluster=name, start=int(start),
stop=int(length) + int(start), env=environment,
channel=channel,
gender=gender)
except Exception as e:
logging.error(sys.exc_info()[0])
# logging.error(line)
fic.close()
return diarization
@classmethod
[docs] def read_ctm(cls, filename, normalize_cluster=False):
"""
Read a segmentation file
:param filename: the str input filename
:param normalize_cluster: normalize the cluster by removing upper case
and accents
:return: a diarization object
"""
fic = open(filename, 'r', encoding="utf8")
diarization = Diar()
try:
for line in fic:
line = re.sub('\s+',' ',line)
line = line.strip()
# logging.debug(line)
if line.startswith('#') or line.startswith(';;'):
continue
# split line into fields
show, tmp, start, length, word = line.split()
if normalize_cluster:
word = str2str_normalize(word)
# print(show, tmp, start, length, gender, channel, env, speaker)
diarization.append(show=show, cluster=word, start=int(start),
stop=int(length) + int(start))
except Exception as e:
logging.error(sys.exc_info()[0])
# logging.error(line)
fic.close()
return diarization
@classmethod
[docs] def read_mdtm(cls, filename, normalize_cluster=False):
"""
Read a MDTM file
:param filename: the str input filename
:param normalize_cluster: normalize the cluster by removing upper case
and accents
:return: a diarization object
"""
fic = open(filename, 'r', encoding="utf8")
diarization = Diar()
if not diarization._attributes.exist('gender'):
diarization.add_attribut(new_attribut='gender', default='U')
try:
for line in fic:
line = line.strip()
line = re.sub('\s+',' ',line)
logging.debug(line)
if line.startswith('#') or line.startswith(';;'):
continue
# split line into fields
show, tmp, start_str, length, t, score, gender, cluster = line.split()
start = int(round(float(start_str)*100, 0))
l = float(start_str)+float(length)
stop = int(round(l*100, 0))
if normalize_cluster:
cluster = str2str_normalize(cluster)
# print(show, tmp, start, length, gender, channel, env, speaker)
diarization.append(show=show, cluster=cluster, start=start,
stop=stop, gender=gender)
except Exception as e:
logging.error(sys.exc_info()[0])
logging.error(line+' nb:'+str(len(line.split())))
fic.close()
return diarization
@classmethod
[docs] def read_uem(cls, filename):
"""
Read a UEM file
:param filename: the str input filename
:return: a diarization object
"""
fic = open(filename, 'r', encoding="utf8")
diarization = Diar()
if not diarization._attributes.exist('gender'):
diarization.add_attribut(new_attribut='gender', default='U')
try:
name = "uem"
for line in fic:
line = re.sub('\s+',' ',line)
line = line.strip()
# logging.debug(line)
if line.startswith('#') or line.startswith(';;'):
continue
# split line into fields
show, tmp, start_str, stop_str = line.split()
start = int(round(float(start_str)*100, 0))
stop = int(round(float(stop_str)*100, 0))
# stop = start+int(round(float(length)*100, 0))
diarization.append(show=show, cluster=name, start=start, stop=stop)
except Exception as e:
logging.error(sys.exc_info()[0])
logging.error(line)
fic.close()
return diarization
@classmethod
[docs] def read_rttm(cls, filename, normalize_cluster=False):
"""
Read rttm file
:param filename: str input filename
:param normalize_cluster: normalize the cluster by removing upper case and accents
:return: a diarization object
"""
fic = open(filename, 'r', encoding="utf8")
diarization = Diar()
if not diarization._attributes.exist('gender'):
diarization.add_attribut(new_attribut='gender', default='U')
try:
for line in fic:
line = re.sub('\s+',' ',line)
line = line.strip()
if line.startswith('#') or line.startswith(';;'):
continue
# split line into fields
spk, show, tmp0, start_str, length, tmp1, tmp2, cluster, tmp3 = line.split()
if spk == "SPEAKER":
start = int(round(float(start_str)*100, 0))
stop = start+int(round(float(length)*100, 0))
if normalize_cluster:
cluster = str2str_normalize(cluster)
diarization.append(show=show, cluster=cluster, start=start, stop=stop)
except Exception as e:
logging.error(sys.exc_info()[0])
logging.error(line)
fic.close()
return diarization
@classmethod
[docs] def to_string_seg(cls, diar):
"""
transform a diarization into a string
:param diar: a diarization
:return: a string
"""
lst = []
for segment in diar:
gender = 'U'
if diar._attributes.exist('gender'):
gender = segment['gender']
env = 'U'
if diar._attributes.exist('env'):
env = segment['env']
channel = 'U'
if diar._attributes.exist('channel'):
channel = segment['channel']
lst.append('{:s} 1 {:d} {:d} {:s} {:s} {:s} {:s}\n'.format(
segment['show'], segment['start'], segment['stop'] - segment['start'], gender,
channel, env, segment['cluster']))
return lst
@classmethod
[docs] def intersection(cls, diarization1, diarization2):
"""
Compute the intersection between two diarization
:param diarization1: first diarization
:param diarization2: second diarization
:return: a diarization object
"""
diarization = Diar()
first = True
for segment1 in diarization1:
for segment2 in diarization2:
inter = Segment.intersection(segment1, segment2)
if inter is not None :
diarization.append_seg(inter)
if first:
diarization._attributes = segment1._attributes
first = False
return diarization
@classmethod
[docs] def write_seg(cls, filename, diarization):
"""
Write diarization to a segmentation file
:param filename: the str output filename
:param diarization: the diarization to write
"""
diarization.sort(['show', 'start'])
fic = open(filename, 'w', encoding="utf8")
for line in Diar.to_string_seg(diarization):
fic.write(line)
fic.close()
@classmethod
[docs] def write_lbl(cls, diarization, label_dir='', label_file_extension='.lbl'):
"""
Write diarization to label file
:param diarization: the diarization to write
:param label_dir: the string directory of the ouput filename
:param label_file_extension: the string extension of the output filename
"""
diarization.sort(['show', 'start'])
old_show = ''
fic = None
for segment in diarization:
if old_show != segment['show']:
if fic is not None:
fic.close()
filename = os.path.join(label_dir, segment['show']
+ label_file_extension)
fic = open(filename, 'w')
old_show = segment['show']
fic.write('{:d} {:d} {:s}\n'.format(
segment['start'], segment['stop'], segment['cluster']))
fic.close()
[docs]class Segment(list):
"""
Class to store the segment informations.
:attr _attributes: is the list of attribut names
:attr data: the data associated to each attribut
"""
def __init__(self, data, attributes):
"""
Called after the instance has been created (by __new__()), but before it is returned to the caller.
:param data: copy the row data
:param attributes: the names of the attributs
"""
list.__init__(self)
self._attributes = attributes
for item in data:
self.append(item)
def _get_attr(self, attr_name):
"""
Called to implement evaluation of self[attr_name].
:param attr_name: a string
:return: the value
"""
return self[self._attributes[attr_name]]
def _set_attr(self, attr_name, value):
"""
Called to implement assignment to self[attr_name].
:param attr_name: a str
:param value: the value to set
"""
self[self._attributes[attr_name]] = value
def __getitem__(self, index):
"""
Called to implement evaluation of self[index].
:param index: a int
:return: the value
"""
if isinstance(index, str):
return self._get_attr(index)
else:
return list.__getitem__(self, index)
def __setitem__(self, index, value):
"""
Called to implement assignment to self[index].
:param index: a int
:param value: the value to set
:return: the item
"""
if isinstance(index, str):
return self._set_attr(index, value)
else:
return list.__setitem__(self, index, value)
def __eq__(self, segment): # real signature unknown
if segment is not None:
l = len(segment)
if l != len(self):
return False
for i in range(l):
if self[i] != segment[i]:
return False
return True
else:
return False
def __ne__(self, segment): # real signature unknown
return not self.__eq__(segment)
[docs] def duration(self):
"""
:return: the duration of the segment
"""
return self['stop'] - self['start']
[docs] def seg_features(self, features):
"""
Given a FeatureServer, returns a list of feature index corresponding to
the segment.
:param features: a FeatureServer
:return: a list of int
"""
return features[self['start']:self['stop'], :]
@classmethod
[docs] def gap(cls, segment1, segment2):
"""
Returns the inter segment gap between 2 segments.
:param segment1: a Segment object
:param segment2: a Segment object
:return: a Segment object
Examples
--------
>>> from s4d.diarization import Diar, Segment
>>> diarization=Diar()
>>> diarization.append(show='empty', start=0, stop=100, cluster='spk1')
>>> diarization.append(show='empty', start=50, stop=150, cluster='spk2')
>>> s = Segment.intersection(diarization[0], diarization[1])
>>> s
['empty', 'spk1', 'speaker', 100, 50]
>>> s.duration()
- 50
>>> diarization.append(show='empty', start=200, stop=250, cluster='spk1')
>>> Segment.gap(diarization[0], diarization[2])
['empty', 'spk1', 'speaker', 100, 200]
"""
if segment1['show'] != segment2['show']:
raise Exception('not the same show')
segment = Segment(segment1, segment1._attributes)
segment['start'] = segment1['stop']
segment['stop'] = segment2['start']
return segment
@classmethod
[docs] def intersection(cls, segment1, segment2):
"""
Intersection between 2 segments. Return None if the intersection is empty.
:param segment1: a Segment object
:param segment2: a Segment object
:return: a Segment object
Examples
--------
>>> from s4d.diarization import Diar, Segment
>>> diarization=Diar()
>>> diarization.append(show='empty', start=0, stop=100, cluster='spk1')
>>> diarization.append(show='empty', start=50, stop=150, cluster='spk2')
>>> Segment.intersection(diarization[0], diarization[1])
['empty', 'spk1 / spk2', 'speaker', 50, 100]
>>> diarization.append(show='empty', start=50, stop=75, cluster='spk1')
>>> Segment.intersection(diarization[0], diarization[2])
['empty', 'spk1 / spk1', 'speaker', 50, 75]
>>> diarization.append(show='empty', start=200, stop=250, cluster='spk1')
>>> s = Segment.intersection(diarization[0], diarization[3])
>>> s is None
True
"""
if segment1['show'] != segment2['show']:
raise Exception(
'not the same show ' + segment1['show'] + ' != ' + segment2['show'])
segment = Segment(segment1, segment1._attributes)
segment['cluster'] += '##' + segment2['cluster']
segment['start'] = max(segment1['start'], segment2['start'])
segment['stop'] = min(segment1['stop'], segment2['stop'])
if segment.duration() > 0:
return segment
return None
@classmethod
[docs] def diff(cls, segment1, segment2):
"""
The difference between the two segments. Returns one or two segments and the
source of the new segment: 1 means segment1, 2 means segment2.
:param segment1: a Segment object
:param segment2: a Segment object
:return: a list of segments
"""
lst_row = list()
lst_source = list()
if segment1['show'] != segment2['show']:
return lst_row, lst_source
segment = Segment(segment1, segment1._attributes)
segment['start'] = min(segment1['start'], segment2['start'])
segment['stop'] = max(segment1['start'], segment2['start'])
if segment.duration() > 0:
lst_row.append(segment)
if segment['start'] == segment1['start']:
lst_source.append(1)
else:
lst_source.append(2)
segment = Segment(segment1, segment1._attributes)
segment['start'] = min(segment1['stop'], segment2['stop'])
segment['stop'] = max(segment1['stop'], segment2['stop'])
if segment.duration() > 0:
lst_row.append(segment)
if segment['stop'] == segment1['stop']:
lst_source.append(1)
else:
lst_source.append(2)
return lst_row, lst_source
@classmethod
[docs] def union(cls, segment1, segment2):
"""
Union between 2 segments.
:param segment1: a Segment object
:param segment2: a Segment object
:return: a Segment object
Examples
--------
>>> from s4d.diarization import Diar, Segment
>>> diarization=Diar()
>>> diarization.append(show='empty', start=0, stop=100, cluster='spk1')
>>> diarization.append(show='empty', start=50, stop=150, cluster='spk2')
>>> Segment.union(diarization[0], diarization[1])
['empty', 'spk1', 'speaker', 0, 150]
>>> diarization.append(show='empty', start=50, stop=75, cluster='spk1')
>>> Segment.union(diarization[0], diarization[2])
['empty', 'spk1', 'speaker', 0, 100]
>>> diarization.append(show='empty', start=200, stop=250, cluster='spk1')
>>> Segment.union(diarization[0], diarization[3])
['empty', 'spk1', 'speaker', 0, 250]
True
"""
if segment1['show'] != segment2['show']:
raise Exception('not the same show ' +
segment1['show'] + ' != ' + segment2['show'])
segment = Segment(segment1, segment1._attributes)
segment['start'] = min(segment1['start'], segment2['start'])
segment['stop'] = max(segment1['stop'], segment2['stop'])
return segment
[docs]class Index(dict):
"""Implementation of perl's autovivification feature.
Thanks to http://stackoverflow.com/questions/651794/whats-the-best-way-to-initialize-a-dict-of-dicts-in-python
"""
def __getitem__(self, item):
try:
return dict.__getitem__(self, item)
except KeyError:
value = self[item] = type(self)()
return value
[docs]class AttributeNames:
"""
Class AttributeNames defines a list of column names
"""
def __init__(self):
self.names = dict
self.defaults = list
[docs] def index_of(self, name):
"""
:param name: the speaker of the column
:return: the position (a int) of the speaker
"""
return self.names[name]
[docs] def exist(self, name):
"""
Test the existing value of speaker in the list of column names
:param name: the speaker of the column
:return: a boolean
"""
return name in self.names
[docs] def initialize(self, names, defaults):
"""
Initialaze AttributeNames object
:param names: a list of column names
:param defaults: the list of default values
"""
self.names = names
self.defaults = defaults
def __getitem__(self, index):
"""
Called to implement evaluation of self[index].
:param index: a int
:return: the value
"""
return self.names[index]
def __setitem__(self, index, value):
"""
Called to implement assignment to self[index].
:param index: a int
:param value: the value to set
:return: the item
"""
return self.names.__setitem__(index, value)
def __iter__(self):
"""
This method is called when an iterator is required for a container.
:return: a iterator on the column names
"""
return self.names.__iter__()
def __len__(self):
"""
Get the nomber of column names
:return:
"""
return self.defaults.__len__()
[docs] def sorted(self):
"""
sort the column names
:return:
"""
return sorted(self.names.items(), key=lambda x: x[1])
[docs] def add(self, name, default=''):
"""
a a column speaker
:param name: a str
:param default: the default value
"""
if name in self.names:
raise Exception('This attribut exits : ')
else:
self[name] = len(self.defaults)
self.defaults.append(default)
[docs] def delete(self, name):
"""
remove a column and the default value given it column speaker
:param name:
"""
if name not in self.names:
raise Exception("This attribut don't exits : " + name)
else:
i = self[name]
del self.defaults[i]
del self.names[name]
for k in self.names:
if self.names[k] > i:
self.names[k] -= 1
[docs]def rolling_window(a, window):
"""
Make an ndarray with a rolling window of the last dimension
Examples
--------
>>> x=numpy.arange(10).reshape((2,5))
>>> rolling_window(x, 3)
array([[[0, 1, 2], [1, 2, 3], [2, 3, 4]],
[[5, 6, 7], [6, 7, 8], [7, 8, 9]]])
Calculate rolling mean of last dimension:
>>> numpy.mean(rolling_window(x, 3), -1)
array([[ 1., 2., 3.],
[ 6., 7., 8.]])
:param a: Array to add rolling window to
:param window: Size of rolling window
:return: Array that is a view of the original array with a added dimension
of size w.
"""
if window < 1:
raise (ValueError, "`window` must be at least 1.")
if window > a.shape[-1]:
raise (ValueError, "`window` is too long.")
shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
strides = a.strides + (a.strides[-1],)
return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)