Skip to content

Commit

Permalink
ROUGE: fixed sentences calculation and some minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
rssdev10 committed Sep 12, 2023
1 parent 7cc7ab2 commit 66d7657
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 55 deletions.
25 changes: 11 additions & 14 deletions src/evaluation_metrics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@ See also: [`rouge_l_sentence`](@ref), [`rouge_l_summary`](@ref)
"""
function rouge_n(references, candidate, n; avg = true, lang = Languages.English())
ng_candidate = ngramize(lang, candidate, n)
ng_refs = [ngramize(lang, ref, n) for ref in references]

rouge_recall = Array{Float64,1}()
for ref in ng_refs
push!(rouge_recall, rouge_match_score(keys(ref), ng_candidate) / sum(values(ref)) )
rouge_recall = map(references) do ref
ng_ref = ngramize(lang, ref, n)
rouge_match_score(keys(ng_ref), ng_candidate) / sum(values(ng_ref))
end

avg == true && return jackknife_avg(rouge_recall)
Expand All @@ -22,9 +20,9 @@ end

function rouge_match_score(ref, candidate::Dict)
matches = 0
for p in keys(candidate)
for (p, v) in candidate
p ∉ ref && continue
matches += candidate[p]
matches += v
end
return matches
end
Expand All @@ -40,12 +38,12 @@ See also: [`rouge_n`](@ref), [`rouge_l_summary`](@ref)
"""
function rouge_l_sentence(references, candidate, β=8, average = true)
ngram_cand = tokenize(Languages.English(), candidate)
rouge_l_list = []
rouge_l_list = Float64[]

for ref in references
ngram_ref = tokenize(Languages.English(), ref)
r_lcs = weighted_lcs(ngram_ref, ngram_cand, true, false, sqrt) / length(ngram_ref)
p_lcs = weighted_lcs(ngram_ref, ngram_cand, true, false, sqrt) / length(ngram_cand)
r_lcs = weighted_lcs(ngram_ref, ngram_cand, true, sqrt) / length(ngram_ref)
p_lcs = weighted_lcs(ngram_ref, ngram_cand, true, sqrt) / length(ngram_cand)
score = fmeasure_lcs(r_lcs, p_lcs, β)
push!(rouge_l_list, score)
end
Expand All @@ -66,7 +64,7 @@ See [Rouge: A package for automatic evaluation of summaries](http://www.aclweb.o
See also: [`rouge_l_sentence()`](@ref), [`rouge_l_summary`](@ref)
"""
function rouge_l_summary(references, candidate, β, averaging=true)
rouge_l_list = []
rouge_l_list = Float64[]
cand_sent_list = split_sentences(candidate)

for ref in references
Expand All @@ -75,11 +73,10 @@ function rouge_l_summary(references, candidate, β, averaging=true)

for ref_sent in ref_sent_list
l_ = []
arg1 = tokenize(Languages.English(), ref)
arg1 = tokenize(Languages.English(), ref_sent)
for cand_sent in cand_sent_list
arg2 = tokenize(Languages.English(), cand_sent)
d = tokenize(Languages.English(), weighted_lcs(arg1, arg2, false, true, sqrt))
append!(l_,d)
append!(l_, weighted_lcs_tokens(arg1, arg2, false, sqrt))
end
sum_value += length(unique(l_))
end
Expand Down
81 changes: 40 additions & 41 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,66 +28,65 @@ end
Compute the Weighted Longest Common Subsequence of X and Y.
"""
function weighted_lcs(X, Y, weighted = true, returns_string = false, f = sqrt)
m, n = length(X), length(Y)
c_table = [zeros(n+1) for i in 1:m+1]
w_table = [zeros(n+1) for i in 1:m+1]
function weighted_lcs(X, Y, weighted=true, f=sqrt)
result = weighted_lcs_inner(X, Y, weighted, f)

for i in 2:(m+1)
for j in 2:(n+1)
if X[i-1] == Y[j-1]
k = w_table[i-1][j-1]
if weighted == true
increment = (f(k+1)) - (f(k))
else
increment = 1
end
c_table[i][j] = c_table[i-1][j-1] + increment
w_table[i][j] = k + 1
else
if c_table[i-1][j] > c_table[i][j-1]
c_table[i][j] = c_table[i-1][j]
w_table[i][j] = 0 # no match at i,j
else
c_table[i][j] = c_table[i][j-1]
w_table[i][j] = 0 # no match at i,j
end
end
end
end
return result.lcs_length
end

lcs_length = (c_table[m+1][n+1])
if returns_string == false
return lcs_length
end
function weighted_lcs_tokens(X, Y, weighted=true, f=sqrt)
m, n, c_table, w_table, lcs_length = weighted_lcs_inner(X, Y, weighted, f)

if weighted == true
lcs_length = c_table[m][n]^(2)
end
# if weighted == true
# lcs_length = c_table[m, n]^(2) # ?....
# end

lcs_length = round(Int64, lcs_length)
lcs_length = convert(Int64, lcs_length)
lcs = ["" for i in 1:(lcs_length+1)]
lcs[lcs_length+1] = ""
lcs = ["" for i in 1:(lcs_length+1)]
i = m + 1
j = n + 1

while i>1 && j>1
while i > 1 && j > 1
if X[i-1] == Y[j-1]
lcs[lcs_length+1] = X[i-1]
i -= 1
j -= 1
lcs_length -= 1
elseif c_table[i-1][j] > c_table[i][j-1]
elseif c_table[i-1, j] > c_table[i, j-1]
i -= 1
else
j -= 1
end
end

return (join(lcs, " ")) # the lcs string
return lcs # the lcs string
end

"""
    weighted_lcs_inner(X, Y, weighted=true, f=sqrt)

Fill the dynamic-programming tables used to compute the (weighted) longest common
subsequence of the token sequences `X` and `Y`.

# Arguments
- `X`, `Y`: indexable sequences of tokens compared with `==`.
- `weighted`: when `true`, a match that extends a run of `k` consecutive matches
  contributes `f(k + 1) - f(k)` to the score; otherwise every match contributes `1`.
- `f`: weighting function applied to the length of the current run of matches.

# Returns
A named tuple `(m, n, c_table, w_table, lcs_length)` where `m` and `n` are the lengths
of `X` and `Y`, `c_table` is the `(m + 1) × (n + 1)` score table, `w_table` holds the
length of the run of consecutive matches ending at each cell, and `lcs_length` is the
final score `c_table[m + 1, n + 1]`.
"""
function weighted_lcs_inner(X, Y, weighted=true, f=sqrt)
    m, n = length(X), length(Y)
    # Weighted scores are fractional (e.g. sqrt(2) - sqrt(1)), so the score table has to
    # be floating point; an integer table raises InexactError on the second consecutive
    # match. Unweighted scores are whole numbers and are represented exactly.
    c_table = zeros(Float64, m + 1, n + 1)
    # Run lengths are always whole numbers.
    w_table = zeros(Int32, m + 1, n + 1)

    for i in 2:(m + 1)
        for j in 2:(n + 1)
            if X[i-1] == Y[j-1]
                k = w_table[i-1, j-1]
                increment = weighted == true ? f(k + 1) - f(k) : 1
                c_table[i, j] = c_table[i-1, j-1] + increment
                w_table[i, j] = k + 1
            else
                c_table[i, j] = max(c_table[i-1, j], c_table[i, j-1])
                w_table[i, j] = 0 # no match at i,j, so the run of matches is broken
            end
        end
    end

    return (m=m, n=n, c_table=c_table, w_table=w_table, lcs_length=c_table[m+1, n+1])
end


"""
fmeasure_lcs(RLCS, PLCS, β)
Expand All @@ -101,7 +100,7 @@ Compute the F-measure based on WLCS.
"""
function fmeasure_lcs(RLCS, PLCS, β=1)
try
return ((1 + β ^ 2) * RLCS * PLCS) / (RLCS + (β ^ 2) * PLCS)
return ((1 + β^2) * RLCS * PLCS) / (RLCS + (β^2) * PLCS)
catch ex
if ex isa DivideError
return 0
Expand Down

0 comments on commit 66d7657

Please sign in to comment.