From 1dc198dba860602cbc54938ca79b714001f58357 Mon Sep 17 00:00:00 2001 From: FrenchKrab Date: Mon, 14 Mar 2022 15:59:04 +0100 Subject: [PATCH] feat: add DER torchmetrics (#909) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Hervé BREDIN --- pyannote/audio/torchmetrics/__init__.py | 36 +++++ pyannote/audio/torchmetrics/audio/__init__.py | 36 +++++ .../audio/diarization_error_rate.py | 119 ++++++++++++++++ .../audio/torchmetrics/functional/__init__.py | 21 +++ .../torchmetrics/functional/audio/__init__.py | 21 +++ .../audio/diarization_error_rate.py | 128 ++++++++++++++++++ 6 files changed, 361 insertions(+) create mode 100644 pyannote/audio/torchmetrics/__init__.py create mode 100644 pyannote/audio/torchmetrics/audio/__init__.py create mode 100644 pyannote/audio/torchmetrics/audio/diarization_error_rate.py create mode 100644 pyannote/audio/torchmetrics/functional/__init__.py create mode 100644 pyannote/audio/torchmetrics/functional/audio/__init__.py create mode 100644 pyannote/audio/torchmetrics/functional/audio/diarization_error_rate.py diff --git a/pyannote/audio/torchmetrics/__init__.py b/pyannote/audio/torchmetrics/__init__.py new file mode 100644 index 000000000..27513524b --- /dev/null +++ b/pyannote/audio/torchmetrics/__init__.py @@ -0,0 +1,36 @@ +# MIT License +# +# Copyright (c) 2022- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +from .audio.diarization_error_rate import ( + DiarizationErrorRate, + FalseAlarmRate, + MissedDetectionRate, + SpeakerConfusionRate, +) + +__all__ = [ + "DiarizationErrorRate", + "FalseAlarmRate", + "MissedDetectionRate", + "SpeakerConfusionRate", +] diff --git a/pyannote/audio/torchmetrics/audio/__init__.py b/pyannote/audio/torchmetrics/audio/__init__.py new file mode 100644 index 000000000..e500a2655 --- /dev/null +++ b/pyannote/audio/torchmetrics/audio/__init__.py @@ -0,0 +1,36 @@ +# MIT License +# +# Copyright (c) 2022- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +from .diarization_error_rate import ( + DiarizationErrorRate, + FalseAlarmRate, + MissedDetectionRate, + SpeakerConfusionRate, +) + +__all__ = [ + "DiarizationErrorRate", + "SpeakerConfusionRate", + "MissedDetectionRate", + "FalseAlarmRate", +] diff --git a/pyannote/audio/torchmetrics/audio/diarization_error_rate.py b/pyannote/audio/torchmetrics/audio/diarization_error_rate.py new file mode 100644 index 000000000..c294eb3a3 --- /dev/null +++ b/pyannote/audio/torchmetrics/audio/diarization_error_rate.py @@ -0,0 +1,119 @@ +# MIT License +# +# Copyright (c) 2022- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
import torch
from torchmetrics import Metric

from pyannote.audio.torchmetrics.functional.audio.diarization_error_rate import (
    _der_compute,
    _der_update,
)


def _safe_rate(errors: torch.Tensor, speech_total: torch.Tensor) -> torch.Tensor:
    """Divide `errors` by `speech_total` without ever producing NaN or infinity

    Parameters
    ----------
    errors : torch.Tensor
        Accumulated number of erroneous frames.
    speech_total : torch.Tensor
        Accumulated number of reference speech frames.

    Returns
    -------
    rate : torch.Tensor
        `errors / speech_total` wherever `speech_total` is positive.
        Where there is no reference speech, the rate is 0 if there is no error
        either, and 1 otherwise.
    """

    has_speech = speech_total > 0

    # use a dummy denominator of 1 where there is no speech so that the division
    # is always well-defined; those entries are overwritten by the fallback below
    denominator = torch.where(has_speech, speech_total, torch.ones_like(speech_total))
    rate = errors / denominator

    return torch.where(has_speech, rate, (errors > 0).to(rate.dtype))


class DiarizationErrorRate(Metric):
    """Diarization error rate

    Parameters
    ----------
    threshold : float, optional
        Threshold used to binarize predictions. Defaults to 0.5.

    Notes
    -----
    While the pyannote.audio convention is to store speaker activations as
    (batch_size, num_frames, num_speakers)-shaped tensors, this torchmetrics metric
    expects them to be shaped as (batch_size, num_speakers, num_frames) tensors.
    """

    higher_is_better = False
    is_differentiable = False

    def __init__(self, threshold: float = 0.5):
        super().__init__()

        self.threshold = threshold

        # every component is a scalar frame count, summed across processes
        self.add_state("false_alarm", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state(
            "missed_detection", default=torch.tensor(0.0), dist_reduce_fx="sum"
        )
        self.add_state(
            "speaker_confusion", default=torch.tensor(0.0), dist_reduce_fx="sum"
        )
        self.add_state("speech_total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(
        self,
        preds: torch.Tensor,
        target: torch.Tensor,
    ) -> None:
        """Compute and accumulate components of diarization error rate

        The components computed on this batch (false alarm, missed detection,
        speaker confusion and total speech) are added to the metric states.
        Nothing is returned.

        Parameters
        ----------
        preds : torch.Tensor
            (batch_size, num_speakers, num_frames)-shaped continuous predictions.
        target : torch.Tensor
            (batch_size, num_speakers, num_frames)-shaped (0 or 1) targets.
        """

        false_alarm, missed_detection, speaker_confusion, speech_total = _der_update(
            preds, target, threshold=self.threshold
        )
        self.false_alarm += false_alarm
        self.missed_detection += missed_detection
        self.speaker_confusion += speaker_confusion
        self.speech_total += speech_total

    def compute(self) -> torch.Tensor:
        """Compute diarization error rate from the accumulated components

        Returns
        -------
        der : torch.Tensor
            Value returned by `_der_compute` for the accumulated components.
        """
        return _der_compute(
            self.false_alarm,
            self.missed_detection,
            self.speaker_confusion,
            self.speech_total,
        )


class SpeakerConfusionRate(DiarizationErrorRate):
    """Speaker confusion component of the diarization error rate"""

    def compute(self) -> torch.Tensor:
        """Ratio of accumulated speaker confusion to accumulated speech

        Returns 0 when there is no speech and no confusion, 1 when there is
        confusion but no speech.
        """
        return _safe_rate(self.speaker_confusion, self.speech_total)


class FalseAlarmRate(DiarizationErrorRate):
    """False alarm component of the diarization error rate"""

    def compute(self) -> torch.Tensor:
        """Ratio of accumulated false alarm to accumulated speech

        Returns 0 when there is no speech and no false alarm, 1 when there is
        false alarm but no speech.
        """
        return _safe_rate(self.false_alarm, self.speech_total)


class MissedDetectionRate(DiarizationErrorRate):
    """Missed detection component of the diarization error rate"""

    def compute(self) -> torch.Tensor:
        """Ratio of accumulated missed detection to accumulated speech

        Returns 0 when there is no speech and no missed detection, 1 when there
        is missed detection but no speech.
        """
        return _safe_rate(self.missed_detection, self.speech_total)
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. diff --git a/pyannote/audio/torchmetrics/functional/audio/__init__.py b/pyannote/audio/torchmetrics/functional/audio/__init__.py new file mode 100644 index 000000000..67b544284 --- /dev/null +++ b/pyannote/audio/torchmetrics/functional/audio/__init__.py @@ -0,0 +1,21 @@ +# MIT License +# +# Copyright (c) 2022- CNRS +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
# MIT License
#
# Copyright (c) 2022- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


from typing import Tuple

import torch


def _der_update(
    preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute components of diarization error rate

    Parameters
    ----------
    preds : torch.Tensor
        (batch_size, num_speakers, num_frames)-shaped continuous predictions.
    target : torch.Tensor
        (batch_size, num_speakers, num_frames)-shaped (0 or 1) targets.
    threshold : float, optional
        Threshold used to binarize predictions. Defaults to 0.5.

    Returns
    -------
    false_alarm : torch.Tensor
    missed_detection : torch.Tensor
    speaker_confusion : torch.Tensor
    speech_total : torch.Tensor
        Diarization error rate components accumulated over the whole batch.
    """

    # imported here because this is the only function relying on it: the
    # pure-torch helpers of this module can then be used on their own
    from pyannote.audio.utils.permutation import permutate

    # TODO: consider doing the permutation before the binarization
    # in order to improve robustness to mis-calibration.
    preds_bin = (preds > threshold).float()

    # "permutate" works on (batch_size, num_frames, num_speakers)-shaped tensors:
    # swap the last two dimensions on the way in, and back on the way out
    hypothesis, _ = permutate(
        torch.transpose(target, 1, 2), torch.transpose(preds_bin, 1, 2)
    )
    hypothesis = torch.transpose(hypothesis, 1, 2)

    # per-frame difference between number of predicted and reference speakers:
    # its positive part is false alarm, its negative part is missed detection
    detection_error = torch.sum(hypothesis, 1) - torch.sum(target, 1)
    false_alarm = torch.clamp(detection_error, min=0)
    missed_detection = torch.clamp(-detection_error, min=0)

    # active predictions that disagree with the reference are either false alarm
    # or speaker confusion: remove the former to keep the latter
    speaker_confusion = torch.sum((hypothesis != target) * hypothesis, 1) - false_alarm

    false_alarm = torch.sum(false_alarm)
    missed_detection = torch.sum(missed_detection)
    speaker_confusion = torch.sum(speaker_confusion)
    speech_total = 1.0 * torch.sum(target)

    return false_alarm, missed_detection, speaker_confusion, speech_total


def _der_compute(
    false_alarm: torch.Tensor,
    missed_detection: torch.Tensor,
    speaker_confusion: torch.Tensor,
    speech_total: torch.Tensor,
) -> torch.Tensor:
    """Compute diarization error rate from its components

    Parameters
    ----------
    false_alarm : torch.Tensor
    missed_detection : torch.Tensor
    speaker_confusion : torch.Tensor
    speech_total : torch.Tensor
        Diarization error rate components, in number of frames.

    Returns
    -------
    der : torch.Tensor
        Diarization error rate. When `speech_total` is zero, the ratio is not
        defined: 0 is returned if there is no error either, 1 otherwise.
    """

    errors = false_alarm + missed_detection + speaker_confusion
    has_speech = speech_total > 0

    # use a dummy denominator of 1 where there is no speech so that the division
    # never yields NaN or infinity; those entries are overwritten just below
    denominator = torch.where(has_speech, speech_total, torch.ones_like(speech_total))
    der = errors / denominator

    return torch.where(has_speech, der, (errors > 0).to(der.dtype))


def diarization_error_rate(
    preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5
) -> torch.Tensor:
    """Compute diarization error rate

    Parameters
    ----------
    preds : torch.Tensor
        (batch_size, num_speakers, num_frames)-shaped continuous predictions.
    target : torch.Tensor
        (batch_size, num_speakers, num_frames)-shaped (0 or 1) targets.
    threshold : float, optional
        Threshold to binarize predictions. Defaults to 0.5.

    Returns
    -------
    der : torch.Tensor
        Aggregated diarization error rate
    """
    false_alarm, missed_detection, speaker_confusion, speech_total = _der_update(
        preds, target, threshold=threshold
    )
    return _der_compute(false_alarm, missed_detection, speaker_confusion, speech_total)