-
Notifications
You must be signed in to change notification settings - Fork 3
/
4-estimate-leading-char.R
141 lines (117 loc) · 2.84 KB
/
4-estimate-leading-char.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
library(dplyr)
library(luz)
library(purrr)
library(readr)
library(stringr)
library(tibble)
library(yardstick)
library(torch)
source("alpha-char-model.R")
# Regular-expression lookup table: a diagnosis code is assigned to the
# group whose pattern it matches, and the group label is the pattern's
# position (1..21) in this vector. The letter ranges appear to follow
# ICD-10 chapter boundaries — confirm against the coding scheme used
# upstream.
ccc <- c(
  "^[AB].*",
  "(^C|^D[0-4]).*",
  "^D[5-8].*",
  "^E[0-8][0-9].*",
  "^F.*",
  "^G.*",
  "^H[0-5][0-9].*",
  "^H[6-9][0-9].*",
  "^I.*",
  "^J.*",
  "^K.*",
  "^L.*",
  "^M.*",
  "^N.*",
  "^O[0-9].*",
  "^P.*",
  "^Q.*",
  "^R.*",
  "^[ST].*",
  "^[UVWXY].*",
  "^[Z].*"
)
# Map one diagnosis code to the index (or indices) of the pattern(s) in
# `ccc` that match it. Returns an integer vector — length 1 when exactly
# one pattern matches, length 0 or >1 otherwise.
#
# Uses base vapply()/grepl() rather than purrr::map_lgl(): purrr was
# never attached by this script, so the original only worked if purrr
# happened to be loaded elsewhere. USE.NAMES = FALSE keeps the result
# unnamed, exactly as map_lgl() produced.
get_short_code_impl <- function(code) {
  which(vapply(ccc, grepl, logical(1), x = code, USE.NAMES = FALSE))
}
# Vectorized wrapper over get_short_code_impl(): maps each code in
# `code` to its single group index. vapply() with integer(1) errors if
# any code matches zero or multiple patterns — the same strictness
# purrr::map_int() gave the original. Rewritten in base R because purrr
# was never attached by this script. USE.NAMES = FALSE matches
# map_int()'s unnamed output.
get_short_code <- function(code) {
  vapply(code, get_short_code_impl, integer(1), USE.NAMES = FALSE)
}
# Enumerate the 2019 embedding CSVs and record, for each file, the
# 4-digit embedding dimension encoded in its name (the digits between
# "-" and the file extension).
emb_data_dir <- "embedding-data"
embedding_files <- file.path(
  emb_data_dir,
  dir(emb_data_dir) |> str_subset("2019")
)
params <- tibble(
  embedding_files = embedding_files,
  emb_dim = embedding_files |>
    str_extract("-\\d{4}\\.") |>
    str_extract("\\d{4}")
)
# Train one supervised classifier per embedding file and collect its
# held-out performance metrics.
# NOTE(review): dir.create() warns if the directory already exists;
# harmless on re-runs but noisy.
dir.create("luz-supervised-models")
# Accumulates one yardstick metrics tibble per embedding dimension.
ms = list()
for (i in seq_len(nrow(params))) {
# Load the precomputed embeddings for this dimension. Presumably one
# row per code with embedding columns plus a `code` column of
# diagnosis codes — confirm against the embedding files.
aced = params$embedding_files[i]|> read_csv()
# Random 90/10 train/test split over rows.
traini = sample.int(nrow(aced), round(0.9 * nrow(aced)))
testi = setdiff(seq_len(nrow(aced)), traini)
# Replace each raw code with its group index (1..21) from `ccc`.
aced$code = get_short_code(aced$code)
# AlphaCharEmbedding comes from alpha-char-model.R (sourced at the
# top); presumably a torch dataset constructor taking the data and
# the full sorted label set — confirm there. Both splits get the
# label set from the whole of `aced` so they agree.
train = AlphaCharEmbedding(aced[traini, ], sort(unique(aced$code)))
test = AlphaCharEmbedding(aced[testi, ], sort(unique(aced$code)))
# Layer widths: input width, two hidden layers of 100, 21 classes
# (matching the 21 patterns in `ccc`).
layers = c(train$width(), 100, 100, 21)
batch_size = 64
epochs = 30
# Cross entropy
# Hand-rolled cross-entropy; the 1e-16 offset guards against log(0).
# NOTE(review): assumes `input` is already probability-normalized and
# `target` is one-hot encoded — confirm in AlphaCodeEstimator.
loss = function(input, target) {
torch_mean(-torch_sum(target * torch_log(input + 1e-16), 2))
}
# Fit with luz, retaining the weights from the best validation epoch
# via luz_callback_keep_best_model().
luz_model = AlphaCodeEstimator |>
setup(
loss = loss, #nn_cross_entropy_loss(26),
optimizer = optim_adam
) |>
set_hparams(layers = layers) |>
fit(
data = dataloader(
train,
batch_size = batch_size,
shuffle = TRUE,
num_workers = 4,
worker_packages = c("torch", "dplyr")
),
epochs = epochs,
valid_data = dataloader(
test,
batch_size = batch_size,
shuffle = TRUE,
num_workers = 4,
worker_packages = c("torch", "dplyr")
),
callbacks = list(
luz_callback_keep_best_model()
)
)
# Persist the fitted model, file name keyed by embedding dimension.
luz_save(
luz_model,
file.path("luz-supervised-models",
sprintf("luz-model-%s.pt", params$emb_dim[i]))
)
# Predictions on the held-out set (no shuffling, so row order lines
# up with `testi`).
preds =
predict(
luz_model,
dataloader(
test,
batch_size = batch_size,
num_workers = 4,
worker_packages = c("torch", "dplyr")
)
)
# Observed group vs. row-wise argmax of the predictions, both coerced
# to factors over the full 1..21 level set so yardstick scores every
# class even if some are absent from this test split.
comp = tibble(
obs = aced[testi,]$code |>
factor(levels = 1:21),
pred = preds |>
torch_tensor(device = "cpu") |>
as.matrix() |>
apply(1, which.max) |>
factor(levels = 1:21)
)
# Append this run's accuracy and balanced accuracy as one tibble.
ms = c(ms,
list(
metric_set(accuracy, bal_accuracy)(comp, truth = obs, estimate = pred)
)
)
# Progress: show all metrics collected so far.
print(ms)
}
# Summarize per-dimension performance and persist it.
#
# Each element of `ms` is one yardstick metrics tibble (columns
# .metric / .estimator / .estimate, rows "accuracy" and
# "bal_accuracy") — the loop above appends exactly one per row of
# `params`. The original indexed ms[[3]], ms[[6]], ms[[9]], ms[[12]],
# which overruns a 4-element list, and `ms[[k]][1]` extracted the
# first *column* (.metric, character) rather than the numeric
# estimate. Extract .estimate by metric name instead.
get_metric <- function(m, name) m$.estimate[m$.metric == name]
params$accuracy <- vapply(ms, get_metric, numeric(1), name = "accuracy")
params$bal_accuracy <- vapply(ms, get_metric, numeric(1), name = "bal_accuracy")
saveRDS(params |> select(-embedding_files), "sup-model-perf.rds")