diff --git a/mwe_query/canonicalform.py b/mwe_query/canonicalform.py index 585058c..361648f 100644 --- a/mwe_query/canonicalform.py +++ b/mwe_query/canonicalform.py @@ -3,7 +3,7 @@ to generate queries from them and to search using these queries. """ -from typing import Dict, Iterable, List, Sequence, Optional, Set, Tuple, TypeVar +from typing import cast, Dict, Iterable, List, Sequence, Optional, Set, Tuple, TypeVar from sastadev.sastatypes import SynTree import re import sys @@ -328,7 +328,9 @@ def all_leaves(stree: SynTree, annotations: List[Annotation], allowedannotations def headmodifiable(stree: SynTree, mwetop: int, annotations: List[int]): head = getchild(stree, 'hd') - if terminal(head): + if head is None: + return False + elif terminal(head): beginint = int(gav(head, 'begin')) if 0 <= beginint < len(annotations): if mwetop == notop: @@ -397,6 +399,8 @@ def zullenheadclause(stree: SynTree) -> bool: if stree.tag == 'node': cat = gav(stree, 'cat') head = getchild(stree, 'hd') + if head is None: + return False headlemma = gav(head, 'lemma') headpt = gav(head, 'pt') result = cat in { @@ -456,6 +460,8 @@ def transformtree(stree: SynTree, annotations: List[Annotation], mwetop=notop, a return results elif cat in {'smain', 'sv1'}: head = getchild(stree, 'hd') + if head is None: + return [] lemma = gav(head, 'lemma') vc = getchild(stree, 'vc') # predm, if present, must be moved downwards here @@ -938,8 +944,9 @@ def lowerpredm(stree: SynTree) -> SynTree: for predmnodeid in predmnodeids: predmnode = find1(newstree, f'.//node[@id="{predmnodeid}"]') predmparent = predmnode.getparent() - predmparent.remove(predmnode) - lowestvcnode.append(predmnode) + if predmparent is not None: + predmparent.remove(predmnode) + lowestvcnode.append(predmnode) # print('lowerpredm: newstree') # ET.dump(newstree) return newstree @@ -966,7 +973,8 @@ def newgenvariants(stree: SynTree) -> List[SynTree]: vblsu = find1(newstree, f'.//node[@rel="su" and {vblnode}]') if vblsu is not None: parent = vblsu.getparent() - parent.remove(vblsu) + if parent is not None: + parent.remove(vblsu) # move predm down not needed already done in transformtree # newstree = lowerpredm(newstree) @@ -997,10 +1005,11 @@ def newgenvariants(stree: SynTree) -> List[SynTree]: newvcnode1 = nodecopy(vcnode) newvcnode2 = nodecopy(vcnode) parent = obj1node.getparent() - parent.remove(obj1node) - alternativesnode = mkalternativesnode( - [[obj1node], [newvcnode1], [newpobj1node, newvcnode2]]) - parent.append(alternativesnode) + if parent is not None: + parent.remove(obj1node) + alternativesnode = mkalternativesnode( + [[obj1node], [newvcnode1], [newpobj1node, newvcnode2]]) + parent.append(alternativesnode) vblppnodeids = globalresult.xpath(vblppnodeidxpath) for vblppnodeid in vblppnodeids: @@ -1008,6 +1017,8 @@ def newgenvariants(stree: SynTree) -> List[SynTree]: newpobj1node1 = nodecopy(pobj1node) newvcnode1 = nodecopy(vcnode) parent = ppnode.getparent() + if parent is None: + continue parent.remove(ppnode) newppnode1 = copy.copy(ppnode) for child in newppnode1: @@ -1395,7 +1406,7 @@ def relpronsubst(stree: SynTree) -> SynTree: if govprep is not None: govprep.attrib['vztype'] = 'init' govprep.attrib['lemma'] = adaptvzlemma_inv( - govprep.attrib['lemma']) + cast(str, govprep.attrib['lemma'])) # ET.dump(newstree) elif rhdframe.startswith('waar_adverb'): diff --git a/mwe_query/mwestats.py b/mwe_query/mwestats.py index 1344d4e..ba060e8 100644 --- a/mwe_query/mwestats.py +++ b/mwe_query/mwestats.py @@ -537,7 +537,7 @@ def displayfullstats(stats: MWEstats, outfile, header=''): rows: List[str] = [] for clemmas, cwords, utt in compliststats.data: rows.append(f'{clemmas}: {cwords}: {utt}'.strip()) - + rows.sort() for row in rows: @@ -570,7 +570,6 @@ def displayfullstats(stats: MWEstats, outfile, header=''): for row in rows: print(row, file=outfile) - allcompnodes = stats.compnodes modstats = stats.modstats displaystats('Modification', modstats, outfile) diff --git a/tests/update_outputs.py b/tests/update_outputs.py index 8911d35..3d312af 100755 --- a/tests/update_outputs.py +++ b/tests/update_outputs.py @@ -5,7 +5,7 @@ from alpino_query import parse_sentence # type: ignore import sys -from os import path +from os import listdir, path import glob import lxml.etree as ET @@ -15,7 +15,19 @@ # import this implementation sys.path.insert(0, path.join(testdir, "..")) from mwe_query import Mwe -from mwe_query.canonicalform import preprocess_MWE, transformtree +from mwe_query.canonicalform import ( + preprocess_MWE, + transformtree, + generatemwestructures, + generatequeries, + applyqueries, +) + +from mwe_query.mwestats import ( + displayfullstats, + getstats, + gettreebank, +) def datapath(dirname, filename): @@ -57,11 +69,51 @@ def update_generate(basename): def gettopnode(stree): for child in stree: - if child.tag == 'node': + if child.tag == "node": return child return None +def update_full_mwe_stats(treebank_name: str, mwe: str): + dotbfolder = datapath("mwetreebanks", treebank_name) + rawtreebankfilenames = listdir(dotbfolder) + selcond = lambda _: True + treebankfilenames = [ + path.join(dotbfolder, fn) + for fn in rawtreebankfilenames + if fn[-4:] == ".xml" and selcond(fn) + ] + treebank = gettreebank(treebankfilenames) + + mwestructures = generatemwestructures(mwe) + for i, mweparse in enumerate(mwestructures): + mwequery, nearmissquery, supersetquery = generatequeries(mwe) + queryresults = applyqueries( + treebank, mwe, mwequery, nearmissquery, supersetquery, verbose=False + ) + + fullmwestats = getstats(mwe, queryresults, treebank) + + filename = f"full_mwe_stats_{treebank_name}_{i}.txt" + outputfilename = datapath(path.join("mwetreebanks", "expected"), filename) + + with open(outputfilename, "w", encoding="utf8") as outfile: + + displayfullstats( + fullmwestats.mwestats, outfile, header="*****MWE statistics*****" + ) + displayfullstats( + fullmwestats.nearmissstats, + outfile, + header="*****Near-miss statistics*****", + ) + displayfullstats( + fullmwestats.diffstats, + outfile, + header="*****Near-miss - MWE statistics*****", + ) + + def update_transform(): mwes = read("transform", "mwes.txt").splitlines() @@ -82,8 +134,11 @@ def update_transform(): i += 1 -input_files = glob.glob(path.join(datadir, "generate", '*.txt')) +input_files = glob.glob(path.join(datadir, "generate", "*.txt")) for input in input_files: head, ext = path.splitext(path.basename(input)) update_generate(head) + +update_full_mwe_stats("dansontspringena", "iemand zal de dans ontspringen") +update_full_mwe_stats("hartbreken", "iemand zal iemands hart breken") update_transform()