
Commit

updated metrics
WillemSpek committed Aug 7, 2023
1 parent ef45b68 commit 71c3083
Showing 2 changed files with 414 additions and 15 deletions.
52 changes: 37 additions & 15 deletions relevance_maps_properties/metrics/metrics.py
@@ -10,7 +10,7 @@
from numpy.typing import NDArray
from tqdm import tqdm
from sklearn.metrics import auc
-from scipy.stats import mode
+from scipy.stats import mode, pearsonr, spearmanr
from copy import copy
from PIL.Image import Image
from torchtext.vocab import Vectors
@@ -297,10 +297,44 @@ def __init__(self, model: Union[Callable, str],
self.tokenizer = get_function(tokenizer, preprocess_function=None)
self.vocab = Vectors(word_vectors, cache=os.path.dirname(word_vectors))
self.max_filter_size = max_filter_size
self.pad_token = pad_token
self.unk_token = unk_token

def __call__(self,
input_text: str,
salient_batch: list[list[tuple[str, str, float]]],
normalise: bool = False,
normalise_fn: Optional[Callable] = None,
impute_value: str = '<unk>',
p_thresh: float = .05
):

results = defaultdict(list)
for salience_map in salient_batch:
            scores, init_score = self.evaluate(salience_map, input_text,
                                               impute_value=impute_value)
scores = init_score - scores
_, _, relevances = self.sort_salience_map(salience_map)
if normalise:
relevances = normalise_fn(np.array(relevances))

pearson = pearsonr(relevances, scores, alternative='greater')
spearman = spearmanr(relevances, scores, alternative='greater')
            # Skip this explanation if either correlation is not significantly positive.
            if pearson[1] > p_thresh or spearman[1] > p_thresh:
                continue
results['scores'].append(scores.tolist())
results['pearson'].append(pearson[0])
results['spearman'].append(spearman[0])

if not results['scores']:
            raise RuntimeError(
                'Could not find reliable correlation estimates in the given batch of '
                'explanations. This is likely due to a poor choice of hyperparameters. '
                'Please recompute the explanations.')
results['init_score'] = float(init_score)
return results
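
The loop in __call__ keeps an explanation only when both correlation tests are significantly positive. A standalone sketch of that filter on made-up numbers (the arrays and threshold below are illustrative, not taken from the repository):

    import numpy as np
    from scipy.stats import pearsonr, spearmanr

    relevances = np.array([0.9, 0.6, 0.3, 0.1])       # token relevances, most salient first
    score_drops = np.array([0.45, 0.30, 0.12, 0.02])  # init_score - perturbed scores
    p_thresh = 0.05

    pearson = pearsonr(relevances, score_drops, alternative='greater')
    spearman = spearmanr(relevances, score_drops, alternative='greater')
    if pearson[1] <= p_thresh and spearman[1] <= p_thresh:
        print('kept:', pearson[0], spearman[0])       # both correlations significantly positive
    else:
        print('skipped: not significant at p <', p_thresh)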

def evaluate(self,
salience_map: list[tuple[str, str, float]],
input_sentence: str,
@@ -320,35 +354,26 @@ def evaluate(self,
Perturbed sentence scores and initial sentence score
'''
        # Tokenize sentence.
tokenized = self._preprocess_sentence(input_sentence)
eval_sentence = copy(tokenized)
_, indices, _ = self.sort_salience_map(salience_map)

# Get original sentence score.
init_pred = self.model([eval_sentence], **model_kwargs)
init_score = init_pred.max()
init_lbl = init_pred.argmax()

impute_value = self.vocab.stoi[impute_value]
scores = np.empty(len(salience_map))

for i, token_idx in enumerate(indices):
# Perturb sentence and score model.
tmp = eval_sentence[token_idx]
eval_sentence[token_idx] = impute_value
score = self.model([eval_sentence], **model_kwargs).flatten()[init_lbl]
eval_sentence[token_idx] = tmp
scores[i] = score
return scores, init_score
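
For reference, a self-contained sketch of the leave-one-out perturbation performed in evaluate, with a toy scoring function standing in for the wrapped model (all names and numbers below are illustrative):

    import numpy as np

    def toy_model(batch):                     # stand-in classifier: two class scores per input
        return np.array([[0.2, 0.8] for _ in batch])

    sentence = [12, 7, 431, 0, 0]             # embedded token indices, already padded
    unk_idx = 1                               # index used to impute ('delete') a word
    order = [2, 0, 1]                         # token positions sorted by relevance

    init_pred = toy_model([sentence])
    init_score, init_lbl = init_pred.max(), init_pred.argmax()

    scores = np.empty(len(order))
    for i, pos in enumerate(order):
        saved = sentence[pos]
        sentence[pos] = unk_idx               # perturb a single token
        scores[i] = toy_model([sentence]).flatten()[init_lbl]
        sentence[pos] = saved                 # restore before the next perturbation
    # the drops init_score - scores feed the correlation test in __call__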

def _preprocess_sentence(self, input_sentence: str) -> list:
'''Tokenize and embed sentence.
@@ -360,8 +385,6 @@ def _preprocess_sentence(self, input_sentence: str) -> list:
tokens = self.tokenizer(input_sentence)
if len(tokens) < self.max_filter_size:
tokens += [self.pad_token] * (self.max_filter_size - len(tokens))

embedded = [self.vocab.stoi[token] if token in self.vocab.stoi
else self.vocab.stoi[self.unk_token] for token in tokens]
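
A self-contained sketch of the padding and out-of-vocabulary handling above, using a toy string-to-index table in place of the torchtext vocabulary (values below are illustrative):

    stoi = {'<pad>': 0, '<unk>': 1, 'a': 2, 'great': 3, 'film': 4}
    tokens = ['a', 'great', 'film']
    max_filter_size = 5

    if len(tokens) < max_filter_size:
        tokens += ['<pad>'] * (max_filter_size - len(tokens))

    embedded = [stoi[t] if t in stoi else stoi['<unk>'] for t in tokens]
    print(embedded)    # [2, 3, 4, 0, 0]; unknown words would map to stoi['<unk>'] == 1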
@@ -386,7 +409,6 @@ def visualize(self, salience_map: tuple[str, str, float],
'''
assert len(scores) >= len(salience_map)
words, indices, relevances = self.sort_salience_map(salience_map)

fig, ax1 = plt.subplots()

