`learnable-strf` is an open-source package to replicate the experiments in [^1].

The main dependencies of `learnable-strf` are:
- NumPy (>= 1.10)
- pytorch (== 1.3.0)
- nnaudio (== 1.3.0)
- pyannote.core (>= 4.1)
- pyannote.audio (== 2.0a1+60.gc683897) (installed with the shell)
- pyannote.database (== 4.0.1+5.g8394991)
- pyannote.pipeline (== 1.5)
- pyannote.metrics (== 2.3)
The Learnable STRF can easily be implemented in PyTorch; our implementation, reproduced below, is inspired by the implementation from this package. We used the nnAudio package to obtain the log Mel filterbanks.
```python
from typing import Optional
from typing import Text
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from nnAudio import Spectrogram
from torch.nn.utils.rnn import PackedSequence
from torch.nn.modules.conv import _ConvNd
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair


class STRFConv2d(_ConvNd):
    """2D convolution whose kernels are Gabor-like spectro-temporal
    receptive fields (STRFs) parameterized by learnable scales,
    orientations and modulation frequencies."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=False,
                 padding_mode='zeros',
                 device=None,
                 n_features=64):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(STRFConv2d,
              self).__init__(in_channels, out_channels,
                             kernel_size, stride, padding, dilation, False,
                             _pair(0), groups, bias, padding_mode)
        self.n_features = n_features

        # Learnable STRF parameters, one set per (output, input) channel pair.
        self.theta = np.random.vonmises(0, 0, (out_channels, in_channels))
        self.gamma = np.random.vonmises(0, 0, (out_channels, in_channels))
        self.psi = np.random.vonmises(0, 0, (out_channels, in_channels))
        self.gamma = nn.Parameter(torch.Tensor(self.gamma))
        self.psi = nn.Parameter(torch.Tensor(self.psi))
        self.freq = (np.pi / 2) * 1.41**(
            -np.random.uniform(0, 5, size=(out_channels, in_channels)))
        self.freq = nn.Parameter(torch.Tensor(self.freq))
        self.theta = nn.Parameter(torch.Tensor(self.theta))
        self.sigma_x = 2 * 1.41**(np.random.uniform(
            0, 6, (out_channels, in_channels)))
        self.sigma_x = nn.Parameter(torch.Tensor(self.sigma_x))
        self.sigma_y = 2 * 1.41**(np.random.uniform(
            0, 6, (out_channels, in_channels)))
        self.sigma_y = nn.Parameter(torch.Tensor(self.sigma_y))

        # Center of the kernel grid along the frequency and time axes.
        self.f0 = torch.ceil(torch.Tensor([self.kernel_size[0] / 2]))[0]
        self.t0 = torch.ceil(torch.Tensor([self.kernel_size[1] / 2]))[0]

    def forward(self, sequences, use_real=True):
        packed_sequences = isinstance(sequences, PackedSequence)
        if packed_sequences:
            device = sequences.data.device
        else:
            device = sequences.device
            sequences = sequences.reshape(
                sequences.size(0), 1, self.n_features, -1)

        # Frequency/time grid on which the STRF kernels are evaluated.
        grid = [
            torch.linspace(-self.f0 + 1, self.f0, self.kernel_size[0]),
            torch.linspace(-self.t0 + 1, self.t0, self.kernel_size[1])
        ]
        f, t = torch.meshgrid(grid)
        f = f.to(device)
        t = t.to(device)

        # Build the Gabor-like kernels from the current parameter values.
        weight = torch.empty(self.weight.shape, requires_grad=False)
        for i in range(self.out_channels):
            for j in range(self.in_channels):
                sigma_x = self.sigma_x[i, j].expand_as(t)
                sigma_y = self.sigma_y[i, j].expand_as(t)
                freq = self.freq[i, j].expand_as(t)
                theta = self.theta[i, j].expand_as(t)
                gamma = self.gamma[i, j].expand_as(t)
                psi = self.psi[i, j].expand_as(t)
                rotx = t * torch.cos(theta) + f * torch.sin(theta)
                roty = -t * torch.sin(theta) + f * torch.cos(theta)
                rot_gamma = t * torch.cos(gamma) + f * torch.sin(gamma)
                # Gaussian envelope ...
                g = torch.exp(-0.5 * ((f**2) / (sigma_x + 1e-3)**2 +
                                      (t**2) / (sigma_y + 1e-3)**2))
                # ... modulated by the real or imaginary part of a
                # complex exponential along the rotated axis.
                if use_real:
                    g = g * torch.cos(freq * rot_gamma)
                else:
                    g = g * torch.sin(freq * rot_gamma)
                g = g / (2 * np.pi * sigma_x * sigma_y)
                weight[i, j] = g
                # Keep a copy of the generated filters in self.weight
                # (the convolution below uses the freshly built `weight`).
                self.weight.data[i, j] = g
        weight = weight.to(device)

        return F.conv2d(sequences, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
```
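Below is a minimal usage sketch, not taken from the paper's training scripts: it computes log Mel filterbanks with nnAudio and feeds them to `STRFConv2d`. The sample rate, FFT size, hop length, kernel size and number of STRF channels are illustrative placeholders.

```python
import torch
from nnAudio import Spectrogram

# Illustrative values (not the paper's exact settings).
sample_rate, n_mels = 16000, 64

# Log Mel filterbanks computed on-the-fly with nnAudio.
mel = Spectrogram.MelSpectrogram(sr=sample_rate, n_fft=512,
                                 hop_length=160, n_mels=n_mels)

# Learnable STRF layer: one input channel (the log-Mel "image"),
# 32 spectro-temporal kernels of size 9 (frequency) x 111 (time).
strf = STRFConv2d(in_channels=1, out_channels=32,
                  kernel_size=(9, 111), padding=(4, 55),
                  n_features=n_mels)

waveform = torch.randn(2, sample_rate)        # batch of 1-second signals
logmel = torch.log(mel(waveform) + 1e-6)      # (batch, n_mels, frames)
features = strf(logmel)                       # (batch, 32, n_mels, frames)
print(features.shape)
```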
All the experiments for Speech Activity Detection are run with the pyannote
ecosystem.
The AMI corpus can be obtained freely from https://groups.inf.ed.ac.uk/ami/corpus/.
The CHiME5 corpus can be obtained freely from url.
The train/dev/test protocols are AMI.SpeakerDiarization.MixHeadset and CHiME5.SpeakerDiarization.U01, both handled with pyannote.database. The AMI protocol can be installed with pip; for CHiME5, add the following lines to your .pyannote/database.yml:
```yaml
CHiME5:
  SpeakerDiarization:
    U01:
      train:
        annotation: /export/fs01/jsalt19/databases/CHiME5/train/allU01_train.rttm
        annotated: /export/fs01/jsalt19/databases/CHiME5/train/allU01_train.uem
      development:
        annotation: /export/fs01/jsalt19/databases/CHiME5/dev/allU01_dev.rttm
        annotated: /export/fs01/jsalt19/databases/CHiME5/dev/allU01_dev.uem
      test:
        annotation: /export/fs01/jsalt19/databases/CHiME5/test/allU01_test.rttm
        annotated: /export/fs01/jsalt19/databases/CHiME5/test/allU01_test.uem
```
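As a sketch, assuming the AMI protocol plugin is installed and the CHiME5 entry above is in place, both protocols can then be loaded with `pyannote.database.get_protocol`:

```python
from pyannote.database import get_protocol

# Load the two protocols used in the Speech Activity Detection experiments.
ami = get_protocol('AMI.SpeakerDiarization.MixHeadset')
chime = get_protocol('CHiME5.SpeakerDiarization.U01')

# Iterate over the training files and their reference annotations.
for current_file in ami.train():
    uri = current_file['uri']               # file identifier
    reference = current_file['annotation']  # pyannote.core.Annotation
    print(uri, reference.get_timeline().duration())
```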
We followed the protocol from J. M. Coria et al. and injected the STRFTDNN network in place of SincTDNN.
We followed the protocol from Arnault et al. and only modified the PANN architecture by injecting the STRFConv2d on top of the log Mel filterbanks, as sketched below.
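For illustration only (this is not the actual PANN or pyannote.audio code), the injection point can be sketched as a small wrapper module; the class name `STRFFrontEnd`, the downstream encoder and all hyperparameters are placeholders:

```python
import torch.nn as nn


class STRFFrontEnd(nn.Module):
    """Schematic front-end: Learnable STRF applied to log Mel filterbanks
    before an arbitrary downstream encoder (e.g. a PANN-style CNN or a TDNN).
    Hyperparameters are illustrative, not the paper's exact settings."""

    def __init__(self, n_mels=64, n_kernels=32, downstream=None):
        super().__init__()
        self.strf = STRFConv2d(in_channels=1, out_channels=n_kernels,
                               kernel_size=(9, 111), padding=(4, 55),
                               n_features=n_mels)
        # Downstream encoder left abstract; identity by default.
        self.downstream = downstream or nn.Identity()

    def forward(self, logmel):
        # logmel: (batch, n_mels, frames) -> (batch, n_kernels, n_mels, frames)
        strf_maps = self.strf(logmel)
        return self.downstream(strf_maps)
```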
The models used to run the Speech Activity Detection and Speaker Identification experiments are in the file models.py. This file replaces models.py in the pyannote.audio package so that the Learnable STRF is used.
We are very grateful to the authors of pyannote, nnAudio, the urban sound package, Theunissen's group, and Shamma's group for the open-source packages and datasets that made this work possible.
[^1]: Riad R., Karadayi J., Bachoud-Lévi A.-C., Dupoux E. Learning spectro-temporal representations of complex sounds with parameterized neural networks. The Journal of the Acoustical Society of America.