From dbb3418e64f5928b0018c6022bf9e46e3a950305 Mon Sep 17 00:00:00 2001 From: Fabian Gruenewald Date: Fri, 21 Jun 2024 11:49:01 +0200 Subject: [PATCH] use hashes consistently and add test when skipping filter --- polyply/src/generate_templates.py | 9 ++++-- polyply/tests/test_generate_templates.py | 37 +++++++++++++++++++++++- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/polyply/src/generate_templates.py b/polyply/src/generate_templates.py index 86dc5cd8..ca7df079 100644 --- a/polyply/src/generate_templates.py +++ b/polyply/src/generate_templates.py @@ -34,8 +34,13 @@ def _extract_template_graphs(meta_molecule, template_graphs={}, skip_filter=Fals if skip_filter: for node in meta_molecule.nodes: resname = meta_molecule.nodes[node]["resname"] - if resname not in template_graphs: - template_graphs[resname] = meta_molecule.nodes[node]["graph"] + graph = meta_molecule.nodes[node]["graph"] + graph_hash = nx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash(graph, node_attr='atomname') + if resname in template_graphs: + template_graphs[graph_hash] = graph + del template_graphs[resname] + elif resname not in template_graphs and graph_hash not in template_graphs: + template_graphs[graph_hash] = graph else: template_graphs = group_residues_by_hash(meta_molecule, template_graphs) return template_graphs diff --git a/polyply/tests/test_generate_templates.py b/polyply/tests/test_generate_templates.py index edd57073..87ce7830 100644 --- a/polyply/tests/test_generate_templates.py +++ b/polyply/tests/test_generate_templates.py @@ -31,7 +31,9 @@ _relabel_interaction_atoms, compute_volume, map_from_CoG, extract_block, GenerateTemplates, - find_interaction_involving) + find_interaction_involving, + _extract_template_graphs) +from .example_fixtures import example_meta_molecule class TestGenTemps: @@ -307,3 +309,36 @@ def test_compute_volume(lines, coords, volume): assert np.isclose(new_vol, volume, atol=0.000001) +@pytest.mark.parametrize('resnames, 
gen_template_graphs, skip_filter', ( + # two different residues, no template graphs, skip_filter=False + (['A', 'B', 'A'], [], False), + # two different residues, no template graphs, skip_filter=True + (['A', 'B', 'A'], [], True), + # two different residues, one template graph, skip_filter=True + (['A', 'B', 'A'], [1], True), + # two different residues, one template graph, skip_filter=False + (['A', 'B', 'A'], [1], False), +)) +def test_extract_template_graphs(example_meta_molecule, resnames, gen_template_graphs, skip_filter): + # set the residue names + for resname, node in zip(resnames, example_meta_molecule.nodes): + example_meta_molecule.nodes[node]['resname'] = resname + nx.set_node_attributes(example_meta_molecule.nodes[node]['graph'], resname, 'resname') + + # extract template graphs if needed + template_graphs = {} + for node in gen_template_graphs: + graph = example_meta_molecule.nodes[node]['graph'] + nx.set_node_attributes(graph, True, 'template') + template_graphs[example_meta_molecule.nodes[node]['resname']] = graph + + # perform the grouping + unique_graphs = _extract_template_graphs(example_meta_molecule, template_graphs, skip_filter) + + # check the outcome + assert len(unique_graphs) == 2 + + for graph in template_graphs.values(): + graph_hash = nx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash(graph, node_attr='atomname') + templated = list(nx.get_node_attributes(unique_graphs[graph_hash], 'template').values()) + assert all(templated)