This repository has been archived by the owner on Jul 23, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add an example task for frequency based mathematics
- Loading branch information
1 parent
6b95173
commit d0a0e5a
Showing
6 changed files
with
1,009 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
53 changes: 53 additions & 0 deletions
53
src/genbench/tasks/frequency_based_mathematics/config.jsonnet
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
{ | ||
name: 'Frequency based mathematics', | ||
description: 'This sample submission measures generalisation in the domain of mathematical | ||
questions, by quantifying the extent to which correctness depends on the | ||
frequency of the underlying terms. A model is said to be a stronger at | ||
generalisation if its answers are less dependent on the term frequencies. | ||
This test is inspired by the work of Razeghi et al (2022). | ||
and the performance. | ||
', | ||
keywords: [ | ||
'mathematics', | ||
'term_frequencies', | ||
'prompting', | ||
'LLMs' | ||
], | ||
|
||
authors: [ | ||
'Dieuwke Hupkes', | ||
], | ||
|
||
data_source: { | ||
type: 'manual', | ||
test: 'https://raw.githubusercontent.com/dieuwkehupkes/genbench_cbt_sample_submission/frequency_math/src/genbench/tasks/frequency_based_mathematics/test_data.jsonl', | ||
}, | ||
|
||
has_validation_set: false, | ||
has_train_set: false, | ||
|
||
task_type: 'free_form', | ||
|
||
evaluation_metrics: [ | ||
{ | ||
hf_id: 'exact_match', | ||
git_commit_sha: "758135da6a37ce962b7bc38c6dd5eab672d2b742", | ||
best_score: 1.0, | ||
} | ||
], | ||
|
||
preparation_strategies: { | ||
// A recipe for preparing the model to perform the task by configuring its prompt. | ||
// This recipe is suitable for generative LMs such as GPT-3, OPT, T5, etc. | ||
// We provide a few options for configuring the prompt. But, the task creator can | ||
// also provide a custom prompt preparation in the task's Python class. | ||
prompt_based_testing: { | ||
prompt_builder: { | ||
instruction_zero_shot: '', // Left empty because the prompt is in the data | ||
instruction_few_shot: '', // Left empty because the prompt is in the data | ||
input_prefix: 'Q: ', | ||
output_prefix: '\nA: ', | ||
} | ||
}, | ||
}, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Frequency based mathematics -- an example task | ||
|
||
This is an example task for measuring generalisation in LLMs, given the frequency | ||
of terms in their pretraining corpus. It quantifies memorisation in a simple | ||
multiplication task, by computing the extent to which the accuracy of a model | ||
depends on the frequency of the terms in the pretraining corpus. In particular, | ||
it defines generalisation as `1 - abs(cor(accuracy, pretraining_term_frequencies)`. | ||
We estimate the accuracy for a particular term by multiplying it with 20 other | ||
terms, and computing the average score. | ||
This measure is inspired by the work of [Razehgi et al (2022)](https://aclanthology.org/2022.findings-emnlp.59/), | ||
who measure the correlation between the pretraining term frequencies and a models performance. | ||
|
||
*NB* This task serves illustrative purposes and is meant to provide inspiration. The measure has | ||
not been tested on real data, and there are no empirical results (yet?) that prove that this | ||
measure is useful in practice. | ||
|
||
## Abstract | ||
Because this is an example task, there is no paper abstract. | ||
|
||
## Examples | ||
|
||
The samples in this task are simple arithmetics problems, e.g.: | ||
|
||
``` | ||
What is ten times 31? | ||
``` | ||
|
||
## Usage | ||
|
||
This task requires the user to compute the unigram frequencies of the numbers ten to fifty (spelled out). | ||
Those terms should be provided in a dictionary, with as keys the terms (e.g. `twenty`) and as value the | ||
(potentially normalised) frequency in the pretraining corpus. | ||
|
||
``` | ||
# Load the task | ||
task = load_task("frequency_based_mathematics") | ||
ds = task.get_prepared_datasets( | ||
PreparationStrategy.PROMPT_BASED_TESTING, | ||
shot_list=[0])[0] | ||
# Load your pretraining frequencies and model predictions | ||
pretraining_freqs = ... | ||
preds = ... | ||
for pred_type, preds in preds.items(): | ||
for freq_type, pretraining_freq in pretraining_freqs.items(): | ||
scores = task.evaluate_predictions( | ||
predictions=preds, | ||
gold=ds, | ||
term_freqs=pretraining_freq | ||
) | ||
print(f'Scores: {scores}') | ||
``` | ||
|
||
## Data Source | ||
The (dummy) test data is hosted at [https://github.com/dieuwkehupkes/genbench_cbt_sample_submission/blob/template/src/genbench/tasks/frequency_based_mathematics/test_data.jsonl], under an Apache 2.0 license. | ||
|
||
## Limitations and Bias | ||
This is an example task for illustrative purposes, it has not been tested empirically. | ||
|
||
## GenBench eval card | ||
This test can be used to test generalisation in LLMs (pretrain - test locus). | ||
It is designed to better understand how LLMs generalise (intrinsic motivation), and can be used to assess | ||
compositional generalisation or -- more generally -- robustness. | ||
Because the test (input) samples differ in their frequency distribution wrt the training corpus, | ||
we assume that the shift is a covariate shift. | ||
|
||
![GenBench Eval Card](eval_card.png) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from typing import Any, List, Dict | ||
|
||
import datasets | ||
from scipy.stats import pearsonr | ||
from sklearn.metrics import accuracy_score | ||
|
||
from genbench.utils.logging import get_logger | ||
from genbench import Task | ||
|
||
logger = get_logger(__name__) | ||
|
||
|
||
class FrequencyBasedMathematicsTask(Task): | ||
pass | ||
"""Python implementation of the FrequencyBasedMathematics task. | ||
This LLM task assess to what extent the computation of mathematical expressions | ||
depends on the frequency of the individual terms during pretraining. | ||
The test is inspired by the work of Razeghi et al (2022). | ||
It defines a generalisation score as 1 - abs(pearsonsR(term_freq, score)) | ||
To facilitate this, it reimplements the following default functions: | ||
- TODO list functions to be reimplemented | ||
- __init__ : this function checks if a file with term frequences is | ||
added by the user, and warns the user if it is not | ||
- evaluate_predictions: this function implements the generalisation | ||
metric used for this task | ||
""" | ||
|
||
def evaluate_predictions( | ||
self, | ||
*, | ||
predictions: List[Dict[str, Any]] = None, | ||
gold: datasets.Dataset = None, | ||
term_freqs: Dict[str, int] = None, | ||
) -> Dict[str, float]: | ||
"""Evaluate the predictions of the model against the gold data. | ||
For this task, the evaluation metric returns the minimum accuracy | ||
among different prompts (i.e. if one of the prompts returns a wrong | ||
answer, the score is 0 for that example). | ||
Args: | ||
predictions: A list of dictionaries, where each dictionary contains the predicted | ||
values for an example. The keys are strings and the values the model predictions. | ||
gold: A HuggingFace `datasets.Dataset` object containing the ground truth data for the task. | ||
term_freqs: A dictionary that maps the terms 10-30 to their frequencies in the pretraining corpus | ||
Returns: | ||
A dictionary containing key-value pairs for the evaluation metric(s) computed on the predicted | ||
values. The keys are strings representing the name of the evaluation metric and the values are | ||
floating-point numbers. | ||
""" | ||
|
||
if term_freqs is None: | ||
raise Exception("This evaluation requires information about term frequencies in the pretraining corpus, that should be provided by the user. For more information, we refer to the readme of this task.") | ||
|
||
scores, freqs = [], [] | ||
|
||
preds_list = [pred["target"] for pred in predictions] | ||
refs_list = [g["target"] for g in gold] | ||
|
||
ref_type = type(refs_list[0]) | ||
pred_type = type(preds_list[0]) | ||
|
||
# Make sure predictions and gold are the same type | ||
if pred_type != ref_type: | ||
if pred_type == str and ref_type == int: | ||
logger.warning("Predictions are strings, but references are ints. Converting predictions to ints.") | ||
elif pred_type == int and ref_type == str: | ||
logger.warning("Predictions are ints, but references are strings. Converting references to ints.") | ||
|
||
# In the data, there are 20 examples for each term | ||
# in range (10, 50), we compute the avg accuracy per term | ||
# by averaging over all samples with the term as first number | ||
|
||
for n in range(40): | ||
# Fetch all accuracy scores for the predictions | ||
for prediction in predictions: | ||
acc = accuracy_score(preds_list[n:n+20], refs_list[n:n+20]) | ||
freq = term_freqs[n+10] | ||
scores.append(acc) | ||
freqs.append(freq) | ||
|
||
corr = pearsonr(scores, freqs)[0] | ||
|
||
return {'gen_score': 1-abs(corr), 'accuracy': sum(scores)/len(scores)} |
Oops, something went wrong.