Skip to content

Commit

Permalink
Adaptations for the form and the annotation input
Browse files Browse the repository at this point in the history
  • Loading branch information
JanOdijk committed Oct 31, 2023
1 parent da43ee4 commit 7dc410a
Show file tree
Hide file tree
Showing 3 changed files with 367 additions and 258 deletions.
311 changes: 53 additions & 258 deletions src/sastadev/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,41 +148,35 @@
from lxml import etree

from sastadev import compounds
from sastadev.ASTApostfunctions import getastamaxsamplesizeuttidsandcutoff
from sastadev.allresults import AllResults, ExactResultsDict, MatchesDict, ResultsKey, mkresultskey, showreskey, \
scores2counts
from sastadev.asta_queries import astalemmafunction
from sastadev.conf import settings
from sastadev.constants import (bronzefolder, formsfolder, intreebanksfolder,
loggingfolder, outtreebanksfolder,
resultsfolder, silverfolder, silverpermfolder)
from sastadev.correcttreebank import (correcttreebank, corrn, errorwbheader,
validcorroptions)
from sastadev.correcttreebank import (corrn, errorwbheader, validcorroptions)
from sastadev.counterfunctions import counter2liststr
from sastadev.external_functions import str2functionmap
from sastadev.goldcountreader import get_goldcounts
from sastadev.macros import expandmacros
from sastadev.methods import astamethods, Method, SampleSize, defaultfilters, stapmethods, supported_methods, \
tarspmethods, treatmethod
from sastadev.mismatches import exactmismatches, getmarkposition, literalmissedmatches
from sastadev.methods import supported_methods, treatmethod
from sastadev.mismatches import exactmismatches, literalmissedmatches
from sastadev.mksilver import getsilverannotations, permprefix
from sastadev.query import (Query, form_process, is_core, is_literal, is_pre, is_preorcore,
from sastadev.query import (Query, is_preorcore,
post_process, query_exists, query_inform)
from sastadev.readmethod import itemseppattern, read_method
from sastadev.reduceresults import exact2results, reduceallresults, reduceexactgoldscores, reduceresults
from sastadev.reduceresults import exact2results, reduceexactgoldscores, reduceresults
from sastadev.rpf1 import getevalscores, getscores, sumfreq
from sastadev.SAFreader import (get_golddata, richexact2global,
richscores2scores)
from sastadev.sastacore import dopostqueries, getreskey, isxpathquery, SastaCoreParameters, sastacore
from sastadev.SRFreader import read_referencefile
from sastadev.sasta_explanation import finalexplanation_adapttreebank
from sastadev.sastatypes import (AltCodeDict, ExactResult, ExactResults, ExactResultsDict, FileName,
GoldTuple, Match, Matches, MatchesDict, MethodName, Position, PositionStr, QId,
QIdCount, QueryDict, ResultsCounter, ResultsDict,
ResultsDict, SampleSizeTuple, SynTree, UttId)
from sastadev.stringfunctions import getallrealwords
from sastadev.targets import get_mustbedone, get_targets
from sastadev.treebankfunctions import (getattval, getnodeendmap, getuttid,
getxmetatreepositions, getxselseuttid,
from sastadev.sastatypes import (AltCodeDict, ExactResultsDict, FileName,
GoldTuple, MatchesDict, MethodName, QId,
QIdCount, QueryDict, ResultsCounter,
SynTree, UttId)
from sastadev.treebankfunctions import (getattval, getuttid,
getxmetatreepositions,
getyield, showtree)
from sastadev.xlsx import mkworkbook

Expand Down Expand Up @@ -419,94 +413,6 @@ def getcompounds(syntree: SynTree) -> List[SynTree]:
return results


def isxpathquery(query: str) -> bool:
cleanquery = query.lstrip()
return cleanquery.startswith('//')


def getreskey(qid: QId, m: SynTree, queries: QueryDict) -> ResultsKey:
if m is None:
return mkresultskey(qid)
thequery = queries[qid]
if is_literal(thequery):
litfunc = str2functionmap[thequery.literal]
thevalue = litfunc(m)
return mkresultskey(qid, thevalue)
else:
return mkresultskey(qid)


def doqueries(syntree: SynTree, queries: QueryDict, exactresults: ExactResultsDict, allmatches: MatchesDict,
criterion: Callable[[Query], bool]):
global invalidqueries
uttid = getuttid(syntree)
# uttid = getuttidorno(syntree)
omittedwordpositions = getxmetatreepositions(syntree, 'Omitted Word', poslistname='annotatedposlist')
# print(uttid)
# core queries
junk = 0
for queryid in queries: ## @@ dit aanpassen voor literals en voor Resultskey; check read_referencefile
# if queryid not in exactresults: # not needed becaysetaken care of below
# exactresults[queryid] = []
thequeryobj = queries[queryid]
if criterion(thequeryobj):
if query_exists(thequeryobj):
thelistedquery = thequeryobj.query
if isxpathquery(thelistedquery):
expandedquery = expandmacros(thelistedquery)
thequery = "." + expandedquery
try:
matches = syntree.xpath(thequery)
except etree.XPathEvalError as e:
invalidqueries[queryid] = e
matches = []
else:
thef = str2functionmap[thelistedquery]
matches = thef(syntree)
else:
matches = []
exactresults[mkresultskey(queryid)] = []
# matchingids = [uttid for x in matches]
for m in matches:
# showtree(m)
reskey = getreskey(queryid, m, queries)
if m is None:
showtree(syntree, text='in doqueries: Nonematch')
if (reskey, uttid) in allmatches:
allmatches[(reskey, uttid)].append((m, syntree))
else:
allmatches[(reskey, uttid)] = [(m, syntree)]
exactresult = (uttid, int(getattval(m, 'begin')) + 1)
if reskey in exactresults:
exactresults[reskey].append(exactresult)
else:
exactresults[reskey] = [exactresult]
# if queryid in results:
# results[queryid].update(matchingids)
# else:
# results[queryid] = Counter(matchingids)


def docorequeries(syntree: SynTree, queries: QueryDict, results: ExactResultsDict, allmatches: MatchesDict):
doqueries(syntree, queries, results, allmatches, is_core)


def doprequeries(syntree: SynTree, queries: QueryDict, results: ExactResultsDict, allmatches: MatchesDict):
doqueries(syntree, queries, results, allmatches, is_pre)


def dopostqueries(allresults: AllResults, postquerylist: List[QId], queries: QueryDict):
# post queries
for queryid in postquerylist:
thequeryobj = queries[queryid]
if query_exists(thequeryobj):
thelistedquery = thequeryobj.query

# it is assumed that these are all python functions
thef = str2functionmap[thelistedquery]
result = thef(allresults, queries)
allresults.postresults[queryid] = result


def codeadapt(c: str) -> str:
result = c
Expand Down Expand Up @@ -736,65 +642,11 @@ def getexactresults(allmatches: MatchesDict) -> ExactResultsDict:
return result


def adaptpositions(rawexactresults: ExactResultsDict, nodeendmap) -> ExactResultsDict:
newexactresults: ExactResultsDict = {}
for qid in rawexactresults:
newlist = []
for (uttid, position) in rawexactresults[qid]:
newposition = getmarkposition(position, nodeendmap, uttid)
newtuple = (uttid, newposition)
newlist.append(newtuple)
newexactresults[qid] = newlist
return newexactresults


def passfilter(rawexactresults: ExactResultsDict, method: Method) -> ExactResultsDict:
"""
let only those through that satisfy the filter
:param rawexactresults: dictionary with ResultsKey as key and a Counter as value, exact results
:param method: Method object
:return: a filtered version of rawexactresults: results that pass the filter
"""
# exactresults: ExactResultsDict = defaultdict(list) # hiermee ontstaat een probleem: dictionary size changed in iteration
exactresults: ExactResultsDict = {}
queries = method.queries
for reskey in rawexactresults:
queryid = reskey[0]
query = queries[queryid]
queryfilter = query.filter
thefilter = method.defaultfilter if queryfilter is None or queryfilter == '' else str2functionmap[queryfilter]
exactresults[reskey] = [r for r in rawexactresults[reskey] if reskey in rawexactresults and
thefilter(query, rawexactresults, r)]
return exactresults


def getmaxsamplesizeuttidsandcutoff(allresults: AllResults) -> Tuple[List[UttId], int, Position]:
cutoffpoint = None
words = getallrealwords(allresults)
cumwordcount = 0
wordcounts: Dict[UttId, Tuple[int, int, int]] = {}
uttidlist = []
for uttid in allresults.allutts:
basewordcount = sum(words[uttid].values())
ignorewordcount = 0 # getignorewordcount(allresults, uttid)
wordcount = basewordcount - ignorewordcount
wordcounts[uttid] = (basewordcount, ignorewordcount, wordcount)
uttidlist.append(uttid)
cumwordcount += wordcount
result = (uttidlist, cumwordcount, cutoffpoint)
return result


def getsamplesizefunction(methodname: MethodName) -> Callable:
if methodname in astamethods:
result = getastamaxsamplesizeuttidsandcutoff
elif methodname in tarspmethods:
# @@to be implemented
result = getmaxsamplesizeuttidsandcutoff
elif methodname in stapmethods:
# @@to be implemented
result = getmaxsamplesizeuttidsandcutoff
return result





# defaulttarsp = r"TARSP Index Current.xlsx"
Expand Down Expand Up @@ -1012,102 +864,62 @@ def main():
analysedtrees: List[SynTree] = []
nodeendmap = {}

# @vanaf nu gaat het om een treebank, dus hier een if statement toevoegen-done
if annotationinput:
allutts, richexactscores = get_golddata(options.infilename, themethod.item2idmap, altcodes, themethod.queries,
options.includeimplies)
allutts, richexactscores = get_golddata(options.infilename, themethod.item2idmap, themethod.altcodes,
themethod.queries, options.includeimplies)
uttcount = len(allutts)
exactresults = richscores2scores(richexactscores)
annotatedfileresults = AllResults(uttcount=uttcount,
coreresults={},
exactresults=exactresults,
postresults={},
allmatches={},
filename=options.infilename,
analysedtrees=[],
allutts=allutts,
annotationinput=annotationinput)
origtreebank = None
else:
tree = etree.parse(options.infilename)
origtreebank = tree.getroot()
annotatedfileresults = None
if origtreebank.tag != 'treebank':
settings.LOGGER.error("Input treebank file does not contain a treebank element")
exit(-1)
allutts = {}
uttcount = 0
# determine targets
targets = get_targets(origtreebank)

# for tree in origtreebank:
# showtree(tree, 'voor fexplanations')

# deal with final explanations
fexplanations = True
if fexplanations:
treebank1 = finalexplanation_adapttreebank(origtreebank)
else:
treebank1 = origtreebank

# for tree in treebank1:
# showtree(tree, 'na fexplanations')
treebank, errordict, allorandalts = correcttreebank(treebank1, targets, options.methodname, options.corr)

# for tree in treebank:
# showtree(tree, 'na correcties')
scp = SastaCoreParameters(annotationinput, options.corr, themethod,
options.includeimplies, options.infilename)

# create the new treebank
fulltreebank = etree.ElementTree(treebank)
newtreebankfullname = os.path.join(outtreebankspath, corefilename + '_corrected' + '.xml')
fulltreebank.write(newtreebankfullname, encoding="UTF8", xml_declaration=False,
pretty_print=True)

# create error file
errorreportfilename = os.path.join(resultspath, corefilename + '_errorreport' + '.xlsx')
mkerrorreport(errordict, errorreportfilename)

# create error logging
errorloggingfullname = os.path.join(loggingpath, corefilename + '_errorlogging' + '.xlsx')

allerrorrows: List[str] = []
for orandalts in allorandalts:
if orandalts is not None:
allerrorrows += orandalts.OrigandAlts2rows(corefilename)
errorwb = mkworkbook(errorloggingfullname, [errorwbheader], allerrorrows, freeze_panes=(1, 1))
errorwb.close()
allresults, treebank, errordict, allorandalts, samplesizetuple = sastacore(origtreebank, annotatedfileresults, scp)

# analysedtrees consists of (uttid, syntree) pairs in the order in which they come in
analysedtrees: List[(UttId, SynTree)] = []
for syntree in treebank:
temputtid = getuttid(syntree)
uttcount += 1
exactresults = allresults.exactresults
allutts = allresults.allutts
uttcount = allresults.uttcount
allmatches = allresults.allmatches

# if temputtid == '118':
# showtree(syntree, 'tree 118')
# settings.LOGGER.error('uttcount={}'.format(uttcount))
mustbedone = get_mustbedone(syntree, targets)
if mustbedone:
# uttid = getuttid(syntree)
# analysedtrees consists of (uttid, syntree) pairs in order
uttid = getxselseuttid(syntree)
analysedtrees.append((uttid, syntree))

doprequeries(syntree, themethod.queries, rawexactresults, allmatches)
docorequeries(syntree, themethod.queries, rawexactresults, allmatches)

# showtree(syntree)
if uttid in nodeendmap:
settings.LOGGER.error('Duplicate uttid in sample: {}'.format(uttid))
nodeendmap[uttid] = getnodeendmap(syntree)

# uttno = getuttno(syntree)
# allutts[uttno] = getyield(syntree)
allutts[uttid] = getyield(syntree)

# determine exactresults and apply the filter to catch interdependencies between prequeries and corequeries
# rawexactresults = getexactresults(allmatches)
rawexactresults2 = passfilter(rawexactresults, themethod)
exactresults = adaptpositions(rawexactresults2, nodeendmap)

# pas hier de allutts en de rawexactresults2 aan om expansies te ontdoen, gebseerd op de nodeendmap
# @@to be implemented @@ of misschien in de loop hierboven al?
# create the new treebank
if treebank is not None:
fulltreebank = etree.ElementTree(treebank)
newtreebankfullname = os.path.join(outtreebankspath, corefilename + '_corrected' + '.xml')
fulltreebank.write(newtreebankfullname, encoding="UTF8", xml_declaration=False,
pretty_print=True)

# @ en vanaf hier kan het weer gemeenschappelijk worden; er met dus ook voor de annotatiefile een exactresults opgeleverd worden
# @d epostfunctions for lemma's etc moeten mogelijk wel aangepast worden
# create error file
errorreportfilename = os.path.join(resultspath, corefilename + '_errorreport' + '.xlsx')
mkerrorreport(errordict, errorreportfilename)

# adapt the exactresults positions to the reference
# create error logging
errorloggingfullname = os.path.join(loggingpath, corefilename + '_errorlogging' + '.xlsx')

coreresults = exact2results(exactresults)
allerrorrows: List[str] = []
for orandalts in allorandalts:
if orandalts is not None:
allerrorrows += orandalts.OrigandAlts2rows(corefilename)
errorwb = mkworkbook(errorloggingfullname, [errorwbheader], allerrorrows, freeze_panes=(1, 1))
errorwb.close()

platinuminfilefound = False
if os.path.exists(options.platinuminfilename):
Expand All @@ -1123,20 +935,6 @@ def main():
# platinumcheckfilename = base + platinumchecksuffix + txtext
platinumcheckfile = open(platinumcheckfilename, 'w', encoding='utf8')

postresults: Dict[ResultsKey, Any] = {}
allresults = AllResults(uttcount, coreresults, exactresults, postresults, allmatches, options.infilename,
analysedtrees,
allutts, annotationinput)

samplesizefunction = getsamplesizefunction(options.methodname)
samplesizetuple: SampleSizeTuple = samplesizefunction(allresults)

postquerylist: List[QId] = [q for q in themethod.postquerylist if themethod.queries[q].process == post_process]
formquerylist: List[QId] = [q for q in themethod.postquerylist if themethod.queries[q].process == form_process]

# we assume the reduction must be done before the postqueries
allresults = reduceallresults(allresults, samplesizetuple, options.methodname)

# bronze reduction
exactgoldscores = reduceexactgoldscores(exactgoldscores, samplesizetuple, options.methodname) # ongoing
goldscores = exact2results(exactgoldscores) # ongoing
Expand All @@ -1145,10 +943,6 @@ def main():
# silver / platinumreduction
platinumresults: Dict[ResultsKey, Counter] = reduceresults(platinumresults, samplesizetuple, options.methodname)

dopostqueries(allresults, postquerylist, themethod.queries)

dopostqueries(allresults, formquerylist, themethod.queries)

(base, ext) = os.path.splitext(options.infilename)
outputfullname = os.path.join(resultspath, corefilename + "_analysis" + tsvext + txtext)
outfile = open(outputfullname, 'w', encoding='utf8')
Expand Down Expand Up @@ -1265,6 +1059,7 @@ def main():
allgoldresults = AllResults(uttcount, goldcounters, exactgoldscores, goldpostresults, allgoldmatches, reffilename,
[],
allannutts, annotationinput)
postquerylist: List[QId] = [q for q in themethod.postquerylist if themethod.queries[q].process == post_process]
dopostqueries(allgoldresults, postquerylist, themethod.queries)

# compute the platinum postresults
Expand Down
Loading

0 comments on commit 7dc410a

Please sign in to comment.