From 13682ae2c8fca3b9a03edc5f8d3888892041e242 Mon Sep 17 00:00:00 2001 From: Jan Odijk Date: Thu, 12 Sep 2024 09:21:40 +0200 Subject: [PATCH] correction evaluation updates in correcttreebank --- src/sastadev/correcttreebank.py | 185 +++++++++++++++----------------- 1 file changed, 87 insertions(+), 98 deletions(-) diff --git a/src/sastadev/correcttreebank.py b/src/sastadev/correcttreebank.py index 4214552..be8b1ff 100644 --- a/src/sastadev/correcttreebank.py +++ b/src/sastadev/correcttreebank.py @@ -1,7 +1,8 @@ from collections import defaultdict from copy import copy, deepcopy +from dataclasses import dataclass import os -from typing import Dict, List, Optional, Set, Tuple +from typing import Callable, Dict, List, Optional, Set, Tuple from lxml import etree @@ -36,6 +37,9 @@ ampersand = '&' +positive = +1 +negative = -1 + corr0, corr1, corrn = '0', '1', 'n' validcorroptions = [corr0, corr1, corrn] @@ -55,55 +59,29 @@ ParsedCorrection = Tuple[List[str], SynTree, List[Meta]] TupleNint = Tuple[19 * (int,)] -altpropertiesheader = ['penalty', 'dpcount', 'dhyphencount', 'mainclausecount', 'topclause', 'complsucount', 'dimcount', 'compcount', 'supcount', - 'compoundcount', 'unknownwordcount', 'wrongposwordcount', 'smainsucount', 'sucount', 'svaokcount', 'deplusneutcount', 'badcatcount', - 'hyphencount', 'lonelytoecount', 'basicreplaceecount', 'ambigcount', 'subjunctivecount', 'unknownnouncount', - 'unknownnamecount', 'dezebwcount', 'noun1c_count'] -errorwbheader = ['Sample', 'User1', 'User2', 'User3'] + \ - ['Status', 'Uttid', 'Origutt', 'Origsent'] + \ - ['altid', 'altsent', 'score'] + \ - altpropertiesheader smartreplacepairs = [('me', 'mijn'), ('ze', 'zijn')] smartreplacedict = {w1: w2 for w1, w2 in smartreplacepairs} + +class Criterion(): + def __init__(self, name, getfunction, polarity, description): + self.name: str = name + self. getfunction: Callable[[SynTree], bool] = getfunction + self.polarity: int = polarity + self.description: str = description + + + class Alternative(): - def __init__(self, stree, altid, altsent, penalty, dpcount, dhyphencount, mainclausecount, topclause, complsucount, dimcount, - compcount, supcount, compoundcount, unknownwordcount, wrongposwordcount, smainsucount, sucount, svaok, deplusneutcount, badcatcount, - hyphencount, lonelytoecount, - basicreplaceecount, ambigcount, subjunctivecount, unknownnouncount, unknownnamecount, - dezebwcount, noun1c_count): + def __init__(self, stree, altid, altsent, penalty, criteria): self.stree: SynTree = stree self.altid: AltId = altid self.altsent: str = altsent self.penalty: Penalty = int(penalty) - self.dpcount: int = int(dpcount) - self.dhyphencount: int = int(dhyphencount) - self.mainclausecount: int = int(mainclausecount) - self.topclause: int = int(topclause) - self.complsucount: int = int(complsucount) - self.dimcount: int = int(dimcount) - self.compcount: int = int(compcount) - self.supcount: int = int(supcount) - self.compoundcount: int = int(compoundcount) - self.unknownwordcount: int = int(unknownwordcount) - self.wrongposwordcount: int = int(wrongposwordcount) - self.smainsucount: int = int(smainsucount) - self.sucount: int = int(sucount) - self.svaok: int = int(svaok) - self.deplusneutcount: int = int(deplusneutcount) - self.badcatcount: int = int(badcatcount) - self.hyphencount: int = int(hyphencount) - self.lonelytoecount : int = int(lonelytoecount) - self.basicreplaceecount: int = int(basicreplaceecount) - self.ambigcount: int = int(ambigcount) - self.subjunctivecount = int(subjunctivecount) - self.unknownnouncount = int(unknownnouncount) - self.unknownnamecount = int(unknownnamecount) - self.dezebwcount = int(dezebwcount) - self.noun1c_count : int = int(noun1c_count) + self.criteria = criteria + [penalty] def alt2row(self, uttid: UttId, base: str, user1: str = '', user2: str = '', user3: str = '', bestaltids: List[AltId] = [], @@ -117,13 +95,7 @@ def alt2row(self, uttid: UttId, base: str, user1: str = '', user2: str = '', use scores.append('IDENTICAL') score = ampersand.join(scores) part4 = list( - map(str, [self.altid, self.altsent, score, self.penalty, self.dpcount, self.dhyphencount, - self.mainclausecount, self.topclause, self.complsucount, - self.dimcount, self.compcount, self.supcount, self.compoundcount, self.unknownwordcount, - self.wrongposwordcount, self.smainsucount, self.sucount, - self.svaok, self.deplusneutcount, self.badcatcount, self.hyphencount, self.lonelytoecount, - self.basicreplaceecount, self.ambigcount, self.subjunctivecount, self.unknownnouncount, - self.unknownnamecount, self.dezebwcount, self.noun1c_count])) + map(str, [self.altid, self.altsent, score] + self.criteria)) therow: list = [base, user1, user2, user3] + \ ['Alternative', uttid] + 2 * [''] + part4 @@ -975,14 +947,7 @@ def oldgetuttid(stree: SynTree) -> UttId: def scorefunction(obj: Alternative) -> TupleNint: - return (-obj.unknownwordcount, -obj.wrongposwordcount,-obj.unknownnouncount, -obj.unknownnamecount, -obj.ambigcount, -obj.dpcount, - -obj.dhyphencount, -obj.mainclausecount, obj.topclause, - -obj.complsucount, -obj.badcatcount, - -obj.basicreplaceecount, -obj.ambigcount, -obj.hyphencount, -obj.lonelytoecount, - -obj.subjunctivecount, obj.smainsucount, obj.dimcount, - obj.compcount, obj.supcount, obj.compoundcount, obj.sucount, obj.svaok, - -obj.deplusneutcount, - -obj.dezebwcount, -obj.noun1c_count, -obj.penalty) + return tuple(obj.criteria) def getbestaltids(alts: Dict[AltId, Alternative]) -> List[AltId]: @@ -1064,6 +1029,33 @@ def getwrongposwordcount(nt: SynTree) -> int: result += len(matches) return result + +getdpcount = lambda nt: countav(nt, 'rel', 'dp') +getdhyphencount = lambda nt: countav(nt, 'rel', '--') +getdimcount = lambda nt: countav(nt, 'graad', 'dim') +getcompcount = lambda nt: countav(nt, 'graad', 'comp') +getsupcount = lambda nt: countav(nt, 'graad', 'sup') +getsucount = lambda nt: countav(nt, 'rel', 'su') +getbadcatcount = lambda nt: len( + [node for node in nt.xpath('.//node[@cat and (@cat="du") and node[@rel="dp"]]')]) +gethyphencount = lambda nt: len( + [node for node in nt.xpath('.//node[contains(@word, "-")]')]) +getbasicreplaceecount = lambda nt: len([node for node in nt.xpath('.//node[@word]') + if getattval(node, 'word').lower() in basicreplacements]) +getsubjunctivecount = lambda nt: len( + [node for node in nt.xpath('.//node[@pvtijd="conj"]')]) +getunknownnouncount = lambda nt: len([node for node in nt.xpath( + './/node[@pt="n" and @frame="noun(both,both,both)"]')]) +getunknownnamecount = lambda nt: len([node for node in nt.xpath( + './/node[@pt="n" and @frame="proper_name(both)"]')]) +complsuxpath = expandmacros(""".//node[node[(@rel="ld" or @rel="pc") and + @end<=../node[@rel="su"]/@begin and @begin >= ../node[@rel="hd"]/@end] and + not(node[%Rpronoun%])]""") +getcomplsucount = lambda nt: len([node for node in nt.xpath(complsuxpath)]) +getdezebwcount = lambda nt: len([node for node in nt.xpath(dezebwxpath)]) +getnoun1c_count = lambda nt: len([node for node in nt.xpath(noun1cxpath)]) + + def selectcorrection(stree: SynTree, ptmds: List[ParsedCorrection], corr: CorrectionMode) -> Tuple[ ParsedCorrection, OrigandAlts]: # to be implemented@@ @@ -1077,49 +1069,10 @@ def selectcorrection(stree: SynTree, ptmds: List[ParsedCorrection], corr: Correc for cw, nt, md in ptmds: altsent = space.join(cw) penalty = compute_penalty(md) - dpcount = countav(nt, 'rel', 'dp') - dhyphencount = countav(nt, 'rel', '--') - dimcount = countav(nt, 'graad', 'dim') - compcount = countav(nt, 'graad', 'comp') - supcount = countav(nt, 'graad', 'sup') - compoundcount = getcompoundcount(nt) - unknownwordcount = getunknownwordcount(nt) - wrongposwordcount = getwrongposwordcount(nt) - sucount = countav(nt, 'rel', 'su') - lonelytoecount = getlonelytoecount(nt) - mainclausecount = getmainclausecount(nt) - topclause = gettopclause(nt) - smainsucount = countsmainsu(nt) - svaokcount = getsvaokcount(nt) - deplusneutcount = getdeplusneutcount(nt) - badcatcount = len( - [node for node in nt.xpath('.//node[@cat and (@cat="du") and node[@rel="dp"]]')]) - hyphencount = len( - [node for node in nt.xpath('.//node[contains(@word, "-")]')]) - basicreplaceecount = len([node for node in nt.xpath('.//node[@word]') - if getattval(node, 'word').lower() in basicreplacements]) - ambigwordcount = countambigwords(nt) - subjunctivecount = len( - [node for node in nt.xpath('.//node[@pvtijd="conj"]')]) - unknownnouncount = len([node for node in nt.xpath( - './/node[@pt="n" and @frame="noun(both,both,both)"]')]) - unknownnamecount = len([node for node in nt.xpath( - './/node[@pt="n" and @frame="proper_name(both)"]')]) - complsuxpath = expandmacros(""".//node[node[(@rel="ld" or @rel="pc") and - @end<=../node[@rel="su"]/@begin and @begin >= ../node[@rel="hd"]/@end] and - not(node[%Rpronoun%])]""") - complsucount = len([node for node in nt.xpath(complsuxpath)]) - dezebwcount = len([node for node in nt.xpath(dezebwxpath)]) - noun1c_count = len([node for node in nt.xpath(noun1cxpath)]) - # overregcount but these will mostly be unknown words - # mwunamecount well maybe unknownpropernoun first - - alt = Alternative(stree, altid, altsent, penalty, dpcount, dhyphencount, mainclausecount, topclause, complsucount, dimcount, compcount, - supcount, - compoundcount, unknownwordcount, wrongposwordcount, smainsucount, sucount, svaokcount, deplusneutcount, badcatcount, - hyphencount, lonelytoecount, - basicreplaceecount, ambigwordcount, subjunctivecount, unknownnouncount, - unknownnamecount, dezebwcount, noun1c_count) + + criteriavalues = [criterion.getfunction(nt) * criterion.polarity for criterion in criteria] + + alt = Alternative(stree, altid, altsent, penalty * negative, criteriavalues) alts[altid] = alt altid += 1 orandalts = OrigandAlts(orig, alts) @@ -1127,7 +1080,6 @@ def selectcorrection(stree: SynTree, ptmds: List[ParsedCorrection], corr: Correc if corr == corr1: orandalts.selected = altid - 1 elif corr == corrn: - # @@to be implemented@@ bestaltids = getbestaltids(alts) if bestaltids != []: bestaltid = bestaltids[0] # or perhaps better the last one? @@ -1204,3 +1156,40 @@ def compute_penalty(md: List[Meta]) -> Penalty: for meta in md: totalpenalty += meta.penalty return totalpenalty + +# The constant *criteria* is a list of objects of class *Criterion* that are used, in the order given, to evaluate parses +criteria = [ + Criterion("unknownwordcount", getunknownwordcount, negative, "Number of unknown words"), + Criterion("wrongposwordcount", getwrongposwordcount, negative, "Numbe rof words with the wrong part of speech"), + Criterion("unknownnouncount", getunknownnouncount, negative, "Count of unknown nouns according to Alpino"), + Criterion("unknownnamecount", getunknownnamecount, negative, "Count of unknown names"), + Criterion("ambigcount", countambigwords, negative, "Number of ambiguous words"), + Criterion("dpcount", getdpcount, negative, "Number of nodes with relation dp"), + Criterion("dhyphencount", getdhyphencount, negative, "Number of nodes with relation --"), + Criterion("mainclausecount", getmainclausecount, negative, "Number of main clauses"), + Criterion("topclause", gettopclause, positive, "Single clause under top"), + Criterion("complsucount", getcomplsucount, negative, ""), + Criterion("badcatcount", getbadcatcount, negative, "Count of bad categories: du that contains a node with relation dp"), + Criterion("basicreplaceecount", getbasicreplaceecount, negative, "Number of words from the basic replacements"), + Criterion("hyphencount", gethyphencount, negative, "Number rof words that contain hyphens"), + Criterion("lonelytoecount", getlonelytoecount, negative, "Number of occurrences of lonely 'toe'"), + Criterion("subjunctivecount", getsubjunctivecount, negative, "Number of subjunctive verb forms"), + Criterion("smainsucount", countsmainsu, positive, "Count of smain nodes that contain a subject"), + Criterion("dimcount", getdimcount, positive, "Number of words that are diminutives"), + Criterion("compcount", getcompcount, positive, "Number of words that are comparatives"), + Criterion("supcount", getsupcount, positive, "Number of words that are superlatives"), + Criterion("compoundcount", getcompoundcount, positive, "Number of nouns that are compounds"), + Criterion("sucount", getsucount, positive, "Number of subjects"), + Criterion("svaok", getsvaokcount, positive, "Numbe rof time subject verb agreement is OK"), + Criterion("deplusneutcount", getdeplusneutcount, negative, "Number of deviant configuratios with de-determeine + neuiter noun"), + Criterion("dezebwcount", getdezebwcount, negative, "Count of 'deze' as adverb"), + Criterion("noun1c_count", getnoun1c_count, negative, "Number of nouns that consist of a single character"), + # Criterion("penalty", compute_penalty, negative, "Penalty for the changes made") # not in here, added later in Alternative +] + +altpropertiesheader = [criterion.name for criterion in criteria] + ['penalty'] + +errorwbheader = ['Sample', 'User1', 'User2', 'User3'] + \ + ['Status', 'Uttid', 'Origutt', 'Origsent'] + \ + ['altid', 'altsent', 'score'] + \ + altpropertiesheader