Skip to content

Commit

Permalink
correction evaluation updates in correcttreebank
Browse files Browse the repository at this point in the history
  • Loading branch information
JanOdijk committed Sep 12, 2024
1 parent e1e98b8 commit 13682ae
Showing 1 changed file with 87 additions and 98 deletions.
185 changes: 87 additions & 98 deletions src/sastadev/correcttreebank.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from collections import defaultdict
from copy import copy, deepcopy
from dataclasses import dataclass
import os
from typing import Dict, List, Optional, Set, Tuple
from typing import Callable, Dict, List, Optional, Set, Tuple

from lxml import etree

Expand Down Expand Up @@ -36,6 +37,9 @@

ampersand = '&'

positive = +1
negative = -1

corr0, corr1, corrn = '0', '1', 'n'
validcorroptions = [corr0, corr1, corrn]

Expand All @@ -55,55 +59,29 @@
ParsedCorrection = Tuple[List[str], SynTree, List[Meta]]
TupleNint = Tuple[19 * (int,)]

altpropertiesheader = ['penalty', 'dpcount', 'dhyphencount', 'mainclausecount', 'topclause', 'complsucount', 'dimcount', 'compcount', 'supcount',
'compoundcount', 'unknownwordcount', 'wrongposwordcount', 'smainsucount', 'sucount', 'svaokcount', 'deplusneutcount', 'badcatcount',
'hyphencount', 'lonelytoecount', 'basicreplaceecount', 'ambigcount', 'subjunctivecount', 'unknownnouncount',
'unknownnamecount', 'dezebwcount', 'noun1c_count']

errorwbheader = ['Sample', 'User1', 'User2', 'User3'] + \
['Status', 'Uttid', 'Origutt', 'Origsent'] + \
['altid', 'altsent', 'score'] + \
altpropertiesheader

smartreplacepairs = [('me', 'mijn'), ('ze', 'zijn')]
smartreplacedict = {w1: w2 for w1, w2 in smartreplacepairs}



class Criterion():
def __init__(self, name, getfunction, polarity, description):
self.name: str = name
self. getfunction: Callable[[SynTree], bool] = getfunction
self.polarity: int = polarity
self.description: str = description



class Alternative():
def __init__(self, stree, altid, altsent, penalty, dpcount, dhyphencount, mainclausecount, topclause, complsucount, dimcount,
compcount, supcount, compoundcount, unknownwordcount, wrongposwordcount, smainsucount, sucount, svaok, deplusneutcount, badcatcount,
hyphencount, lonelytoecount,
basicreplaceecount, ambigcount, subjunctivecount, unknownnouncount, unknownnamecount,
dezebwcount, noun1c_count):
def __init__(self, stree, altid, altsent, penalty, criteria):
self.stree: SynTree = stree
self.altid: AltId = altid
self.altsent: str = altsent
self.penalty: Penalty = int(penalty)
self.dpcount: int = int(dpcount)
self.dhyphencount: int = int(dhyphencount)
self.mainclausecount: int = int(mainclausecount)
self.topclause: int = int(topclause)
self.complsucount: int = int(complsucount)
self.dimcount: int = int(dimcount)
self.compcount: int = int(compcount)
self.supcount: int = int(supcount)
self.compoundcount: int = int(compoundcount)
self.unknownwordcount: int = int(unknownwordcount)
self.wrongposwordcount: int = int(wrongposwordcount)
self.smainsucount: int = int(smainsucount)
self.sucount: int = int(sucount)
self.svaok: int = int(svaok)
self.deplusneutcount: int = int(deplusneutcount)
self.badcatcount: int = int(badcatcount)
self.hyphencount: int = int(hyphencount)
self.lonelytoecount : int = int(lonelytoecount)
self.basicreplaceecount: int = int(basicreplaceecount)
self.ambigcount: int = int(ambigcount)
self.subjunctivecount = int(subjunctivecount)
self.unknownnouncount = int(unknownnouncount)
self.unknownnamecount = int(unknownnamecount)
self.dezebwcount = int(dezebwcount)
self.noun1c_count : int = int(noun1c_count)
self.criteria = criteria + [penalty]

def alt2row(self, uttid: UttId, base: str, user1: str = '', user2: str = '', user3: str = '',
bestaltids: List[AltId] = [],
Expand All @@ -117,13 +95,7 @@ def alt2row(self, uttid: UttId, base: str, user1: str = '', user2: str = '', use
scores.append('IDENTICAL')
score = ampersand.join(scores)
part4 = list(
map(str, [self.altid, self.altsent, score, self.penalty, self.dpcount, self.dhyphencount,
self.mainclausecount, self.topclause, self.complsucount,
self.dimcount, self.compcount, self.supcount, self.compoundcount, self.unknownwordcount,
self.wrongposwordcount, self.smainsucount, self.sucount,
self.svaok, self.deplusneutcount, self.badcatcount, self.hyphencount, self.lonelytoecount,
self.basicreplaceecount, self.ambigcount, self.subjunctivecount, self.unknownnouncount,
self.unknownnamecount, self.dezebwcount, self.noun1c_count]))
map(str, [self.altid, self.altsent, score] + self.criteria))
therow: list = [base, user1, user2, user3] + \
['Alternative', uttid] + 2 * [''] + part4

Expand Down Expand Up @@ -975,14 +947,7 @@ def oldgetuttid(stree: SynTree) -> UttId:


def scorefunction(obj: Alternative) -> TupleNint:
return (-obj.unknownwordcount, -obj.wrongposwordcount,-obj.unknownnouncount, -obj.unknownnamecount, -obj.ambigcount, -obj.dpcount,
-obj.dhyphencount, -obj.mainclausecount, obj.topclause,
-obj.complsucount, -obj.badcatcount,
-obj.basicreplaceecount, -obj.ambigcount, -obj.hyphencount, -obj.lonelytoecount,
-obj.subjunctivecount, obj.smainsucount, obj.dimcount,
obj.compcount, obj.supcount, obj.compoundcount, obj.sucount, obj.svaok,
-obj.deplusneutcount,
-obj.dezebwcount, -obj.noun1c_count, -obj.penalty)
return tuple(obj.criteria)


def getbestaltids(alts: Dict[AltId, Alternative]) -> List[AltId]:
Expand Down Expand Up @@ -1064,6 +1029,33 @@ def getwrongposwordcount(nt: SynTree) -> int:
result += len(matches)
return result


getdpcount = lambda nt: countav(nt, 'rel', 'dp')
getdhyphencount = lambda nt: countav(nt, 'rel', '--')
getdimcount = lambda nt: countav(nt, 'graad', 'dim')
getcompcount = lambda nt: countav(nt, 'graad', 'comp')
getsupcount = lambda nt: countav(nt, 'graad', 'sup')
getsucount = lambda nt: countav(nt, 'rel', 'su')
getbadcatcount = lambda nt: len(
[node for node in nt.xpath('.//node[@cat and (@cat="du") and node[@rel="dp"]]')])
gethyphencount = lambda nt: len(
[node for node in nt.xpath('.//node[contains(@word, "-")]')])
getbasicreplaceecount = lambda nt: len([node for node in nt.xpath('.//node[@word]')
if getattval(node, 'word').lower() in basicreplacements])
getsubjunctivecount = lambda nt: len(
[node for node in nt.xpath('.//node[@pvtijd="conj"]')])
getunknownnouncount = lambda nt: len([node for node in nt.xpath(
'.//node[@pt="n" and @frame="noun(both,both,both)"]')])
getunknownnamecount = lambda nt: len([node for node in nt.xpath(
'.//node[@pt="n" and @frame="proper_name(both)"]')])
complsuxpath = expandmacros(""".//node[node[(@rel="ld" or @rel="pc") and
@end<=../node[@rel="su"]/@begin and @begin >= ../node[@rel="hd"]/@end] and
not(node[%Rpronoun%])]""")
getcomplsucount = lambda nt: len([node for node in nt.xpath(complsuxpath)])
getdezebwcount = lambda nt: len([node for node in nt.xpath(dezebwxpath)])
getnoun1c_count = lambda nt: len([node for node in nt.xpath(noun1cxpath)])


def selectcorrection(stree: SynTree, ptmds: List[ParsedCorrection], corr: CorrectionMode) -> Tuple[
ParsedCorrection, OrigandAlts]:
# to be implemented@@
Expand All @@ -1077,57 +1069,17 @@ def selectcorrection(stree: SynTree, ptmds: List[ParsedCorrection], corr: Correc
for cw, nt, md in ptmds:
altsent = space.join(cw)
penalty = compute_penalty(md)
dpcount = countav(nt, 'rel', 'dp')
dhyphencount = countav(nt, 'rel', '--')
dimcount = countav(nt, 'graad', 'dim')
compcount = countav(nt, 'graad', 'comp')
supcount = countav(nt, 'graad', 'sup')
compoundcount = getcompoundcount(nt)
unknownwordcount = getunknownwordcount(nt)
wrongposwordcount = getwrongposwordcount(nt)
sucount = countav(nt, 'rel', 'su')
lonelytoecount = getlonelytoecount(nt)
mainclausecount = getmainclausecount(nt)
topclause = gettopclause(nt)
smainsucount = countsmainsu(nt)
svaokcount = getsvaokcount(nt)
deplusneutcount = getdeplusneutcount(nt)
badcatcount = len(
[node for node in nt.xpath('.//node[@cat and (@cat="du") and node[@rel="dp"]]')])
hyphencount = len(
[node for node in nt.xpath('.//node[contains(@word, "-")]')])
basicreplaceecount = len([node for node in nt.xpath('.//node[@word]')
if getattval(node, 'word').lower() in basicreplacements])
ambigwordcount = countambigwords(nt)
subjunctivecount = len(
[node for node in nt.xpath('.//node[@pvtijd="conj"]')])
unknownnouncount = len([node for node in nt.xpath(
'.//node[@pt="n" and @frame="noun(both,both,both)"]')])
unknownnamecount = len([node for node in nt.xpath(
'.//node[@pt="n" and @frame="proper_name(both)"]')])
complsuxpath = expandmacros(""".//node[node[(@rel="ld" or @rel="pc") and
@end<=../node[@rel="su"]/@begin and @begin >= ../node[@rel="hd"]/@end] and
not(node[%Rpronoun%])]""")
complsucount = len([node for node in nt.xpath(complsuxpath)])
dezebwcount = len([node for node in nt.xpath(dezebwxpath)])
noun1c_count = len([node for node in nt.xpath(noun1cxpath)])
# overregcount but these will mostly be unknown words
# mwunamecount well maybe unknownpropernoun first

alt = Alternative(stree, altid, altsent, penalty, dpcount, dhyphencount, mainclausecount, topclause, complsucount, dimcount, compcount,
supcount,
compoundcount, unknownwordcount, wrongposwordcount, smainsucount, sucount, svaokcount, deplusneutcount, badcatcount,
hyphencount, lonelytoecount,
basicreplaceecount, ambigwordcount, subjunctivecount, unknownnouncount,
unknownnamecount, dezebwcount, noun1c_count)

criteriavalues = [criterion.getfunction(nt) * criterion.polarity for criterion in criteria]

alt = Alternative(stree, altid, altsent, penalty * negative, criteriavalues)
alts[altid] = alt
altid += 1
orandalts = OrigandAlts(orig, alts)

if corr == corr1:
orandalts.selected = altid - 1
elif corr == corrn:
# @@to be implemented@@
bestaltids = getbestaltids(alts)
if bestaltids != []:
bestaltid = bestaltids[0] # or perhaps better the last one?
Expand Down Expand Up @@ -1204,3 +1156,40 @@ def compute_penalty(md: List[Meta]) -> Penalty:
for meta in md:
totalpenalty += meta.penalty
return totalpenalty

# The constant *criteria* is a list of objects of class *Criterion* that are used, in the order given, to evaluate parses
criteria = [
Criterion("unknownwordcount", getunknownwordcount, negative, "Number of unknown words"),
Criterion("wrongposwordcount", getwrongposwordcount, negative, "Numbe rof words with the wrong part of speech"),
Criterion("unknownnouncount", getunknownnouncount, negative, "Count of unknown nouns according to Alpino"),
Criterion("unknownnamecount", getunknownnamecount, negative, "Count of unknown names"),
Criterion("ambigcount", countambigwords, negative, "Number of ambiguous words"),
Criterion("dpcount", getdpcount, negative, "Number of nodes with relation dp"),
Criterion("dhyphencount", getdhyphencount, negative, "Number of nodes with relation --"),
Criterion("mainclausecount", getmainclausecount, negative, "Number of main clauses"),
Criterion("topclause", gettopclause, positive, "Single clause under top"),
Criterion("complsucount", getcomplsucount, negative, ""),
Criterion("badcatcount", getbadcatcount, negative, "Count of bad categories: du that contains a node with relation dp"),
Criterion("basicreplaceecount", getbasicreplaceecount, negative, "Number of words from the basic replacements"),
Criterion("hyphencount", gethyphencount, negative, "Number rof words that contain hyphens"),
Criterion("lonelytoecount", getlonelytoecount, negative, "Number of occurrences of lonely 'toe'"),
Criterion("subjunctivecount", getsubjunctivecount, negative, "Number of subjunctive verb forms"),
Criterion("smainsucount", countsmainsu, positive, "Count of smain nodes that contain a subject"),
Criterion("dimcount", getdimcount, positive, "Number of words that are diminutives"),
Criterion("compcount", getcompcount, positive, "Number of words that are comparatives"),
Criterion("supcount", getsupcount, positive, "Number of words that are superlatives"),
Criterion("compoundcount", getcompoundcount, positive, "Number of nouns that are compounds"),
Criterion("sucount", getsucount, positive, "Number of subjects"),
Criterion("svaok", getsvaokcount, positive, "Numbe rof time subject verb agreement is OK"),
Criterion("deplusneutcount", getdeplusneutcount, negative, "Number of deviant configuratios with de-determeine + neuiter noun"),
Criterion("dezebwcount", getdezebwcount, negative, "Count of 'deze' as adverb"),
Criterion("noun1c_count", getnoun1c_count, negative, "Number of nouns that consist of a single character"),
# Criterion("penalty", compute_penalty, negative, "Penalty for the changes made") # not in here, added later in Alternative
]

altpropertiesheader = [criterion.name for criterion in criteria] + ['penalty']

errorwbheader = ['Sample', 'User1', 'User2', 'User3'] + \
['Status', 'Uttid', 'Origutt', 'Origsent'] + \
['altid', 'altsent', 'score'] + \
altpropertiesheader

0 comments on commit 13682ae

Please sign in to comment.