Skip to content

Commit

Permalink
Clean-up and allow updating test data
Browse files Browse the repository at this point in the history
  • Loading branch information
oktaal committed Mar 8, 2024
1 parent fcdc1f5 commit 050b1c0
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 16 deletions.
31 changes: 21 additions & 10 deletions mwe_query/canonicalform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
to generate queries from them and to search using these queries.
"""

from typing import Dict, Iterable, List, Sequence, Optional, Set, Tuple, TypeVar
from typing import cast, Dict, Iterable, List, Sequence, Optional, Set, Tuple, TypeVar
from sastadev.sastatypes import SynTree
import re
import sys
Expand Down Expand Up @@ -328,7 +328,9 @@ def all_leaves(stree: SynTree, annotations: List[Annotation], allowedannotations

def headmodifiable(stree: SynTree, mwetop: int, annotations: List[int]):
head = getchild(stree, 'hd')
if terminal(head):
if head is None:
return False
elif terminal(head):
beginint = int(gav(head, 'begin'))
if 0 <= beginint < len(annotations):
if mwetop == notop:
Expand Down Expand Up @@ -397,6 +399,8 @@ def zullenheadclause(stree: SynTree) -> bool:
if stree.tag == 'node':
cat = gav(stree, 'cat')
head = getchild(stree, 'hd')
if head is None:
return False
headlemma = gav(head, 'lemma')
headpt = gav(head, 'pt')
result = cat in {
Expand Down Expand Up @@ -456,6 +460,8 @@ def transformtree(stree: SynTree, annotations: List[Annotation], mwetop=notop, a
return results
elif cat in {'smain', 'sv1'}:
head = getchild(stree, 'hd')
if head is None:
return []
lemma = gav(head, 'lemma')
vc = getchild(stree, 'vc')
# predm, if present, must be moved downwards here
Expand Down Expand Up @@ -938,8 +944,9 @@ def lowerpredm(stree: SynTree) -> SynTree:
for predmnodeid in predmnodeids:
predmnode = find1(newstree, f'.//node[@id="{predmnodeid}"]')
predmparent = predmnode.getparent()
predmparent.remove(predmnode)
lowestvcnode.append(predmnode)
if predmparent is not None:
predmparent.remove(predmnode)
lowestvcnode.append(predmnode)
# print('lowerpredm: newstree')
# ET.dump(newstree)
return newstree
Expand All @@ -966,7 +973,8 @@ def newgenvariants(stree: SynTree) -> List[SynTree]:
vblsu = find1(newstree, f'.//node[@rel="su" and {vblnode}]')
if vblsu is not None:
parent = vblsu.getparent()
parent.remove(vblsu)
if parent is not None:
parent.remove(vblsu)

# move predm down not needed already done in transformtree
# newstree = lowerpredm(newstree)
Expand Down Expand Up @@ -997,17 +1005,20 @@ def newgenvariants(stree: SynTree) -> List[SynTree]:
newvcnode1 = nodecopy(vcnode)
newvcnode2 = nodecopy(vcnode)
parent = obj1node.getparent()
parent.remove(obj1node)
alternativesnode = mkalternativesnode(
[[obj1node], [newvcnode1], [newpobj1node, newvcnode2]])
parent.append(alternativesnode)
if parent is not None:
parent.remove(obj1node)
alternativesnode = mkalternativesnode(
[[obj1node], [newvcnode1], [newpobj1node, newvcnode2]])
parent.append(alternativesnode)

vblppnodeids = globalresult.xpath(vblppnodeidxpath)
for vblppnodeid in vblppnodeids:
ppnode = find1(newstree, f'//node[@id="{vblppnodeid}"]')
newpobj1node1 = nodecopy(pobj1node)
newvcnode1 = nodecopy(vcnode)
parent = ppnode.getparent()
if parent is None:
continue
parent.remove(ppnode)
newppnode1 = copy.copy(ppnode)
for child in newppnode1:
Expand Down Expand Up @@ -1395,7 +1406,7 @@ def relpronsubst(stree: SynTree) -> SynTree:
if govprep is not None:
govprep.attrib['vztype'] = 'init'
govprep.attrib['lemma'] = adaptvzlemma_inv(
govprep.attrib['lemma'])
cast(str, govprep.attrib['lemma']))
# ET.dump(newstree)

elif rhdframe.startswith('waar_adverb'):
Expand Down
3 changes: 1 addition & 2 deletions mwe_query/mwestats.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def displayfullstats(stats: MWEstats, outfile, header=''):
rows: List[str] = []
for clemmas, cwords, utt in compliststats.data:
rows.append(f'{clemmas}: {cwords}: {utt}'.strip())

rows.sort()

for row in rows:
Expand Down Expand Up @@ -570,7 +570,6 @@ def displayfullstats(stats: MWEstats, outfile, header=''):
for row in rows:
print(row, file=outfile)

allcompnodes = stats.compnodes
modstats = stats.modstats
displaystats('Modification', modstats, outfile)

Expand Down
63 changes: 59 additions & 4 deletions tests/update_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from alpino_query import parse_sentence # type: ignore
import sys
from os import path
from os import listdir, path
import glob
import lxml.etree as ET

Expand All @@ -15,7 +15,19 @@
# import this implementation
sys.path.insert(0, path.join(testdir, ".."))
from mwe_query import Mwe
from mwe_query.canonicalform import preprocess_MWE, transformtree
from mwe_query.canonicalform import (
preprocess_MWE,
transformtree,
generatemwestructures,
generatequeries,
applyqueries,
)

from mwe_query.mwestats import (
displayfullstats,
getstats,
gettreebank,
)


def datapath(dirname, filename):
Expand Down Expand Up @@ -57,11 +69,51 @@ def update_generate(basename):

def gettopnode(stree):
for child in stree:
if child.tag == 'node':
if child.tag == "node":
return child
return None


def update_full_mwe_stats(treebank_name: str, mwe: str):
dotbfolder = datapath("mwetreebanks", treebank_name)
rawtreebankfilenames = listdir(dotbfolder)
selcond = lambda _: True
treebankfilenames = [
path.join(dotbfolder, fn)
for fn in rawtreebankfilenames
if fn[-4:] == ".xml" and selcond(fn)
]
treebank = gettreebank(treebankfilenames)

mwestructures = generatemwestructures(mwe)
for i, mweparse in enumerate(mwestructures):
mwequery, nearmissquery, supersetquery = generatequeries(mwe)
queryresults = applyqueries(
treebank, mwe, mwequery, nearmissquery, supersetquery, verbose=False
)

fullmwestats = getstats(mwe, queryresults, treebank)

filename = f"full_mwe_stats_{treebank_name}_{i}.txt"
outputfilename = datapath(path.join("mwetreebanks", "expected"), filename)

with open(outputfilename, "w", encoding="utf8") as outfile:

displayfullstats(
fullmwestats.mwestats, outfile, header="*****MWE statistics*****"
)
displayfullstats(
fullmwestats.nearmissstats,
outfile,
header="*****Near-miss statistics*****",
)
displayfullstats(
fullmwestats.diffstats,
outfile,
header="*****Near-miss - MWE statistics*****",
)


def update_transform():
mwes = read("transform", "mwes.txt").splitlines()

Expand All @@ -82,8 +134,11 @@ def update_transform():
i += 1


input_files = glob.glob(path.join(datadir, "generate", '*.txt'))
input_files = glob.glob(path.join(datadir, "generate", "*.txt"))
for input in input_files:
head, ext = path.splitext(path.basename(input))
update_generate(head)

update_full_mwe_stats("dansontspringena", "iemand zal de dans ontspringen")
update_full_mwe_stats("hartbreken", "iemand zal iemands hart breken")
update_transform()

0 comments on commit 050b1c0

Please sign in to comment.