Skip to content

Commit

Permalink
Add AudioMNIST scenario (#3093)
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanmai authored Oct 24, 2024
1 parent 338d4bc commit e3c3366
Show file tree
Hide file tree
Showing 3 changed files with 252 additions and 0 deletions.
53 changes: 53 additions & 0 deletions src/helm/benchmark/run_specs/audio_run_specs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Run spec functions for audio scenarios."""

from typing import List, Optional
from helm.benchmark.adaptation.adapter_spec import (
AdapterSpec,
)
from helm.benchmark.adaptation.adapters.adapter_factory import ADAPT_GENERATION_MULTIMODAL
from helm.benchmark.metrics.common_metric_specs import (
get_classification_metric_specs,
get_exact_match_metric_specs,
)
from helm.benchmark.run_spec import RunSpec, run_spec_function
from helm.benchmark.scenarios.scenario import ScenarioSpec


def _get_multimodal_generation_adapter_spec(
    max_tokens: int,
    instructions: str = "",
    max_train_instances: int = 0,
    temperature: float = 0.0,
    stop_sequences: Optional[List[str]] = None,
) -> AdapterSpec:
    """Build an AdapterSpec for multimodal generation scenarios.

    All prompt affixes are left empty because the multimodal prompt is
    assembled from the instance's media objects; only the instructions,
    decoding parameters, and (optional) stop sequences vary per scenario.
    """
    # Normalize the optional stop sequences to a concrete list (avoids a
    # mutable default argument on the function signature).
    resolved_stop_sequences: List[str] = [] if stop_sequences is None else stop_sequences
    return AdapterSpec(
        method=ADAPT_GENERATION_MULTIMODAL,
        instructions=instructions,
        input_prefix="",
        input_suffix="",
        output_prefix="",
        output_suffix="",
        instance_prefix="",
        max_train_instances=max_train_instances,
        num_outputs=1,
        max_tokens=max_tokens,
        temperature=temperature,
        stop_sequences=resolved_stop_sequences,
    )


@run_spec_function("audio_mnist")
def get_audio_mnist_run_spec() -> RunSpec:
    """Run spec for zero-shot spoken-digit classification on AudioMNIST."""
    return RunSpec(
        name="audio_mnist",
        scenario_spec=ScenarioSpec(class_name="helm.benchmark.scenarios.audio_scenarios.AudioMNISTScenario"),
        adapter_spec=_get_multimodal_generation_adapter_spec(
            instructions="Classify the spoken digit. Respond with only a single digit.",
            max_tokens=5,
        ),
        metric_specs=get_exact_match_metric_specs() + get_classification_metric_specs(),
        groups=["audio_mnist"],
    )
64 changes: 64 additions & 0 deletions src/helm/benchmark/scenarios/audio_scenarios.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Scenarios for audio models"""

from typing import List

from helm.benchmark.scenarios.scenario import (
Scenario,
Instance,
Reference,
TEST_SPLIT,
CORRECT_TAG,
Input,
Output,
)
from helm.common.media_object import MediaObject, MultimediaObject


class AudioMNISTScenario(Scenario):
    """AudioMNIST

    The AudioMNIST (Becker et al, 2023) dataset consists of 30,000 audio samples of
    spoken digits (0-9) of 60 different speakers. The task is to classify the digit from the
    audio sample.

    Paper: https://arxiv.org/abs/1807.03418
    Code: https://github.com/soerenab/AudioMNIST

    Citation:
    @article{audiomnist2023,
    title = {AudioMNIST: Exploring Explainable Artificial Intelligence for audio analysis on a simple benchmark},
    journal = {Journal of the Franklin Institute},
    year = {2023},
    issn = {0016-0032},
    doi = {https://doi.org/10.1016/j.jfranklin.2023.11.038},
    url = {https://www.sciencedirect.com/science/article/pii/S0016003223007536},
    author = {Sören Becker and Johanna Vielhaben and Marcel Ackermann and Klaus-Robert Müller and Sebastian Lapuschkin and Wojciech Samek},
    keywords = {Deep learning, Neural networks, Interpretability, Explainable artificial intelligence, Audio classification, Speech recognition},
    }
    """  # noqa: E501

    NUM_SPEAKERS = 60
    NUM_TRIALS = 50
    # Raw wav files are served from the upstream AudioMNIST repository, pinned to a
    # specific commit so the data cannot change underneath us.
    WAV_URL_TEMPLATE = r"https://github.com/soerenab/AudioMNIST/raw/544b0f4bc65227e54332e665d5e02c24be6732c2/data/{speaker_id}/{digit}_{speaker_id}_{trial_index}.wav"  # noqa: E501

    name = "audio_mnist"
    description = "Classify an audio sample of a spoken digit"
    tags = ["audio", "classification"]

    def get_instances(self, output_path: str) -> List[Instance]:
        """Enumerate one test instance per (digit, speaker, trial) combination.

        Instances reference the audio by URL rather than downloading it; the single
        correct reference is the spoken digit rendered as a string. All instances go
        to the test split because the scenario is evaluated zero-shot.
        """
        instances: List[Instance] = []
        for digit in range(10):
            for speaker_index in range(AudioMNISTScenario.NUM_SPEAKERS):
                # Speaker directories in the upstream repository are numbered
                # "01" through "60" (there is no "00"), so the zero-based loop
                # index must be shifted by one before zero-padding.
                speaker_id = str(speaker_index + 1).zfill(2)
                for trial_index in range(AudioMNISTScenario.NUM_TRIALS):
                    wav_url = AudioMNISTScenario.WAV_URL_TEMPLATE.format(
                        digit=digit, speaker_id=speaker_id, trial_index=trial_index
                    )
                    input = Input(
                        multimedia_content=MultimediaObject([MediaObject(content_type="audio/wav", location=wav_url)])
                    )
                    references = [Reference(Output(text=str(digit)), tags=[CORRECT_TAG])]
                    # Don't need train split because we're using zero-shot
                    instance = Instance(input=input, references=references, split=TEST_SPLIT)
                    instances.append(instance)
        return instances
135 changes: 135 additions & 0 deletions src/helm/benchmark/static/schema_audio.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
---
############################################################
metrics:
# Infrastructure metrics:
- name: num_perplexity_tokens
display_name: '# tokens'
description: Average number of tokens in the predicted output (for language modeling, the input too).
- name: num_bytes
display_name: '# bytes'
description: Average number of bytes in the predicted output (for language modeling, the input too).

- name: num_references
display_name: '# ref'
description: Number of references.
- name: num_train_trials
display_name: '# trials'
description: Number of trials, where in each trial we choose an independent, random set of training instances.
- name: estimated_num_tokens_cost
display_name: 'cost'
description: An estimate of the number of tokens (including prompt and output completions) needed to perform the request.
- name: num_prompt_tokens
display_name: '# prompt tokens'
description: Number of tokens in the prompt.
- name: num_prompt_characters
display_name: '# prompt chars'
description: Number of characters in the prompt.
- name: num_completion_tokens
display_name: '# completion tokens'
description: Actual number of completion tokens (over all completions).
- name: num_output_tokens
display_name: '# output tokens'
description: Actual number of output tokens.
- name: max_num_output_tokens
display_name: 'Max output tokens'
description: Maximum number of output tokens (overestimate since we might stop earlier due to stop sequences).
- name: num_requests
display_name: '# requests'
description: Number of distinct API requests.
- name: num_instances
display_name: '# eval'
description: Number of evaluation instances.
- name: num_train_instances
display_name: '# train'
description: Number of training instances (e.g., in-context examples).
- name: prompt_truncated
display_name: truncated
description: Fraction of instances where the prompt itself was truncated (implies that there were no in-context examples).
- name: finish_reason_length
display_name: finish b/c length
description: Fraction of instances where the output was terminated because of the max tokens limit.
- name: finish_reason_stop
display_name: finish b/c stop
description: Fraction of instances where the output was terminated because of the stop sequences.
- name: finish_reason_endoftext
display_name: finish b/c endoftext
description: Fraction of instances where the output was terminated because the end of text token was generated.
- name: finish_reason_unknown
display_name: finish b/c unknown
description: Fraction of instances where the output was terminated for unknown reasons.
- name: num_completions
display_name: '# completions'
description: Number of completions.
- name: predicted_index
display_name: Predicted index
description: Integer index of the reference (0, 1, ...) that was predicted by the model (for multiple-choice).

# Accuracy metrics:
- name: exact_match
display_name: Exact match
short_display_name: EM
description: Fraction of instances that the predicted output matches a correct reference exactly.
lower_is_better: false

############################################################
perturbations: []

############################################################
metric_groups:
- name: accuracy
display_name: Accuracy
hide_win_rates: true
metrics:
- name: ${main_name}
split: ${main_split}

- name: efficiency
display_name: Efficiency
metrics:
- name: inference_runtime
split: ${main_split}

- name: general_information
display_name: General information
hide_win_rates: true
metrics:
- name: num_instances
split: ${main_split}
- name: num_train_instances
split: ${main_split}
- name: prompt_truncated
split: ${main_split}
- name: num_prompt_tokens
split: ${main_split}
- name: num_output_tokens
split: ${main_split}

############################################################

run_groups:
- name: audio_scenarios
display_name: Audio Scenarios
description: Audio Scenarios
category: All scenarios
subgroups:
- audio_mnist

- name: audio_mnist
display_name: AudioMNIST
description: >
The AudioMNIST ([Becker et al, 2023](https://arxiv.org/abs/1807.03418)) dataset consists of 30,000 audio samples of
spoken digits (0-9) of 60 different speakers. The task is to classify the digit from the
audio sample.
metric_groups:
- accuracy
- efficiency
- general_information
environment:
main_name: exact_match
main_split: test
taxonomy:
task: audio classification
what: audio samples of spoken digits (0-9)
who: 60 different speakers
when: "2018"
language: English

0 comments on commit e3c3366

Please sign in to comment.