Skip to content

Commit

Permalink
feat: add DER torchmetrics (#909)
Browse files Browse the repository at this point in the history
Co-authored-by: Hervé BREDIN <[email protected]>
  • Loading branch information
FrenchKrab and hbredin authored Mar 14, 2022
1 parent 4a618ee commit 1dc198d
Show file tree
Hide file tree
Showing 6 changed files with 361 additions and 0 deletions.
36 changes: 36 additions & 0 deletions pyannote/audio/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# MIT License
#
# Copyright (c) 2022- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


from .audio.diarization_error_rate import (
DiarizationErrorRate,
FalseAlarmRate,
MissedDetectionRate,
SpeakerConfusionRate,
)

__all__ = [
"DiarizationErrorRate",
"FalseAlarmRate",
"MissedDetectionRate",
"SpeakerConfusionRate",
]
36 changes: 36 additions & 0 deletions pyannote/audio/torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# MIT License
#
# Copyright (c) 2022- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


from .diarization_error_rate import (
DiarizationErrorRate,
FalseAlarmRate,
MissedDetectionRate,
SpeakerConfusionRate,
)

__all__ = [
"DiarizationErrorRate",
"SpeakerConfusionRate",
"MissedDetectionRate",
"FalseAlarmRate",
]
119 changes: 119 additions & 0 deletions pyannote/audio/torchmetrics/audio/diarization_error_rate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# MIT License
#
# Copyright (c) 2022- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch
from torchmetrics import Metric

from pyannote.audio.torchmetrics.functional.audio.diarization_error_rate import (
_der_compute,
_der_update,
)


class DiarizationErrorRate(Metric):
"""Diarization error rate
Parameters
----------
threshold : float, optional
Threshold used to binarize predictions. Defaults to 0.5.
Notes
-----
While pyannote.audio conventions is to store speaker activations with
(batch_size, num_frames, num_speakers)-shaped tensors, this torchmetrics metric
expects them to be shaped as (batch_size, num_speakers, num_frames) tensors.
"""

higher_is_better = False
is_differentiable = False

def __init__(self, threshold: float = 0.5):
super().__init__()

self.threshold = threshold

self.add_state("false_alarm", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state(
"missed_detection", default=torch.tensor(0.0), dist_reduce_fx="sum"
)
self.add_state(
"speaker_confusion", default=torch.tensor(0.0), dist_reduce_fx="sum"
)
self.add_state("speech_total", default=torch.tensor(0.0), dist_reduce_fx="sum")

def update(
self,
preds: torch.Tensor,
target: torch.Tensor,
) -> None:
"""Compute and accumulate components of diarization error rate
Parameters
----------
preds : torch.Tensor
(batch_size, num_speakers, num_frames)-shaped continuous predictions.
target : torch.Tensor
(batch_size, num_speakers, num_frames)-shaped (0 or 1) targets.
Returns
-------
false_alarm : torch.Tensor
missed_detection : torch.Tensor
speaker_confusion : torch.Tensor
speech_total : torch.Tensor
Diarization error rate components accumulated over the whole batch.
"""

false_alarm, missed_detection, speaker_confusion, speech_total = _der_update(
preds, target, threshold=self.threshold
)
self.false_alarm += false_alarm
self.missed_detection += missed_detection
self.speaker_confusion += speaker_confusion
self.speech_total += speech_total

def compute(self):
return _der_compute(
self.false_alarm,
self.missed_detection,
self.speaker_confusion,
self.speech_total,
)


class SpeakerConfusionRate(DiarizationErrorRate):
def compute(self):
# TODO: handler corner case where speech_total == 0
return self.speaker_confusion / self.speech_total


class FalseAlarmRate(DiarizationErrorRate):
def compute(self):
# TODO: handler corner case where speech_total == 0
return self.false_alarm / self.speech_total


class MissedDetectionRate(DiarizationErrorRate):
def compute(self):
# TODO: handler corner case where speech_total == 0
return self.missed_detection / self.speech_total
21 changes: 21 additions & 0 deletions pyannote/audio/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# MIT License
#
# Copyright (c) 2022- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
21 changes: 21 additions & 0 deletions pyannote/audio/torchmetrics/functional/audio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# MIT License
#
# Copyright (c) 2022- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
128 changes: 128 additions & 0 deletions pyannote/audio/torchmetrics/functional/audio/diarization_error_rate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# MIT License
#
# Copyright (c) 2022- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


from typing import Tuple

import torch

from pyannote.audio.utils.permutation import permutate


def _der_update(
preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute components of diarization error rate
Parameters
----------
preds : torch.Tensor
(batch_size, num_speakers, num_frames)-shaped continuous predictions.
target : torch.Tensor
(batch_size, num_speakers, num_frames)-shaped (0 or 1) targets.
threshold : float, optional
Threshold used to binarize predictions. Defaults to 0.5.
Returns
-------
false_alarm : torch.Tensor
missed_detection : torch.Tensor
speaker_confusion : torch.Tensor
speech_total : torch.Tensor
Diarization error rate components accumulated over the whole batch.
"""

# TODO: consider doing the permutation before the binarization
# in order to improve robustness to mis-calibration.
preds_bin = (preds > threshold).float()

# convert to/from "permutate" expected shapes
hypothesis, _ = permutate(
torch.transpose(target, 1, 2), torch.transpose(preds_bin, 1, 2)
)
hypothesis = torch.transpose(hypothesis, 1, 2)

detection_error = torch.sum(hypothesis, 1) - torch.sum(target, 1)
false_alarm = torch.maximum(detection_error, torch.zeros_like(detection_error))
missed_detection = torch.maximum(
-detection_error, torch.zeros_like(detection_error)
)

speaker_confusion = torch.sum((hypothesis != target) * hypothesis, 1) - false_alarm

false_alarm = torch.sum(false_alarm)
missed_detection = torch.sum(missed_detection)
speaker_confusion = torch.sum(speaker_confusion)
speech_total = 1.0 * torch.sum(target)

return false_alarm, missed_detection, speaker_confusion, speech_total


def _der_compute(
false_alarm: torch.Tensor,
missed_detection: torch.Tensor,
speaker_confusion: torch.Tensor,
speech_total: torch.Tensor,
) -> torch.Tensor:
"""Compute diarization error rate from its components
Parameters
----------
false_alarm : torch.Tensor
missed_detection : torch.Tensor
speaker_confusion : torch.Tensor
speech_total : torch.Tensor
Diarization error rate components, in number of frames.
Returns
-------
der : torch.Tensor
Diarization error rate.
"""

# TODO: handle corner case where speech_total == 0
return (false_alarm + missed_detection + speaker_confusion) / speech_total


def diarization_error_rate(
preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5
) -> torch.Tensor:
"""Compute diarization error rate
Parameters
----------
preds : torch.Tensor
(batch_size, num_speakers, num_frames)-shaped continuous predictions.
target : torch.Tensor
(batch_size, num_speakers, num_frames)-shaped (0 or 1) targets.
threshold : float, optional
Threshold to binarize predictions. Defaults to 0.5.
Returns
-------
der : torch.Tensor
Aggregated diarization error rate
"""
false_alarm, missed_detection, speaker_confusion, speech_total = _der_update(
preds, target, threshold=threshold
)
return _der_compute(false_alarm, missed_detection, speaker_confusion, speech_total)

0 comments on commit 1dc198d

Please sign in to comment.