# na_attention.py
#
# NOTE: the web-page residue (notification banner, file-size line and the
# line-number gutter) that preceded the code in this copy was removed; it
# was not part of the script.
import shutil
from diagnnose.config.arg_parser import create_arg_parser
from diagnnose.config.setup import create_config_dict
from diagnnose.corpus.import_corpus import import_corpus
from diagnnose.decompositions.attention import CDAttention
from diagnnose.models.import_model import import_model
from diagnnose.models.lm import LanguageModel
from diagnnose.typedefs.corpus import Corpus
from diagnnose.vocab import get_vocab_from_config
from diagnnose.extractors.base_extractor import Extractor
if __name__ == "__main__":
    # Script entry point: parse the configuration, extract the model
    # activations for the corpus, and create contextual-decomposition
    # attention plots for the SP case, the PS case and one example sentence.
    arg_groups = {
        "activations",
        "corpus",
        "decompose",
        "extract",
        "init_states",
        "model",
        "plot_attention",
        "vocab",
    }
    arg_parser, required_args = create_arg_parser(arg_groups)
    config_dict = create_config_dict(arg_parser, required_args, arg_groups)

    model: LanguageModel = import_model(config_dict)
    corpus: Corpus = import_corpus(
        vocab_path=get_vocab_from_config(config_dict), **config_dict["corpus"]
    )

    # Extract the model activations
    extractor = Extractor(model, corpus, **config_dict["activations"])
    extractor.extract(**config_dict["extract"])

    # May be None when no activations directory was configured; in that case
    # nothing is removed at the end.
    activations_dir = config_dict["activations"].get("activations_dir")

    try:
        # Append the wrong verb to each sentence: every even-indexed example
        # receives the final token of the example that directly follows it.
        # NOTE(review): this assumes the corpus holds an even number of
        # examples ordered in pairs; an unpaired last example raises
        # IndexError -- confirm against the corpus.
        for idx in range(0, len(corpus.examples), 2):
            corpus.examples[idx].sen += [corpus.examples[idx + 1].sen[-1]]

        # Only fill in a default; a value supplied via the config is kept.
        config_dict["decompose"].setdefault("fix_shapley", False)

        attention = CDAttention(
            model,
            corpus,
            cd_config=config_dict["decompose"],
            plot_config=config_dict["plot_attention"],
        )

        # (case name, sentence ids to average over, token labels for the axes)
        # NOTE(review): the hard-coded id ranges presumably match the layout
        # of the corpus this script was written for -- confirm.
        averaged_cases = (
            (
                "SP",
                slice(1200, 2400, 2),
                [
                    "The",
                    "N$_{sing}$",
                    "PREP",
                    "the",
                    "N$_{plur}$",
                    "V$_{sing}$",
                    "V$_{plur}$",
                ],
            ),
            (
                "PS",
                slice(2400, 3600, 2),
                [
                    "The",
                    "N$_{plur}$",
                    "PREP",
                    "the",
                    "N$_{sing}$",
                    "V$_{plur}$",
                    "V$_{sing}$",
                ],
            ),
        )
        for case, sen_ids, labels in averaged_cases:
            print(f"Creating plot for {case} case...")
            # x labels drop the first token, y labels drop the two verb labels.
            attention.plot_config["xtext"] = labels[1:]
            attention.plot_config["ytext"] = labels[:-2]
            attention.plot_by_sen_id(
                sen_ids,
                avg_decs=True,
                activations_dir=activations_dir,
                extra_classes=[-2],
            )

        # The axis labels set for the last averaged case are still in
        # plot_config here, exactly as in the original statement order.
        print("Creating example plot")
        attention.plot_by_sen_id(
            [1500],
            activations_dir=activations_dir,
            extra_classes=[-2],
        )
    finally:
        # The activations are re-extracted on every run, so remove them even
        # when plotting fails instead of leaving them behind on disk.
        if activations_dir is not None:
            shutil.rmtree(activations_dir)