Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Documenting issues with recovering changes to level2 graph from edit logs #509

Open
bdpedigo opened this issue Oct 2, 2024 · 0 comments

Comments

@bdpedigo
Copy link
Contributor

bdpedigo commented Oct 2, 2024

I think this issue is already known, but wanted to clearly document it and figure out what the best path for doing this in the future will be.

My ultimate goal is to efficiently find how edits affect the level2 graph. This is currently not easy from the information provided in the edit logs for a couple of reasons.

First, for merges, only a single supervoxel edge that was added is stored.

import datetime
from collections import namedtuple
from typing import Optional

import numpy as np
import pandas as pd
import pytz
from requests import HTTPError

from caveclient import CAVEclient

client = CAVEclient("minnie65_phase3_v1")

root_id = 864691135639556411

change_log = client.chunkedgraph.get_tabular_change_log(root_id)[root_id]

change_log = pd.DataFrame(change_log).set_index("operation_id")

merge_id = change_log.index[0]
split_id = change_log.index[1]

details = client.chunkedgraph.get_operation_details([merge_id, split_id])

merge_details = details[str(merge_id)]
added_supervoxel_edges = details[str(merge_id)]["added_edges"]
print(added_supervoxel_edges)
[[89385966883691061, 89385966883698825]]

For this operation, there is one supervoxel edge that is stored as "added".

We can map these modified nodes to their level2 ids at the time of the operation.

supervoxels_affected = np.unique(
    np.concatenate([list(edge) for edge in added_supervoxel_edges])
)

merge_row = change_log.loc[merge_id]
ts = merge_row["timestamp"]

timestamp = datetime.datetime.fromtimestamp(ts / 1000, pytz.UTC)
delta = datetime.timedelta(microseconds=1)

pre_l2_ids = client.chunkedgraph.get_roots(
    supervoxels_affected, stop_layer=2, timestamp=timestamp - delta
)
post_l2_ids = client.chunkedgraph.get_roots(
    supervoxels_affected, stop_layer=2, timestamp=timestamp + delta
)

print("Supervoxel IDs:", supervoxels_affected)
print("L2 IDs pre operation:", pre_l2_ids)
print("L2 IDs post operation:", post_l2_ids)
Supervoxel IDs: [89385966883691061 89385966883698825]
L2 IDs pre operation: [161443560921497602 161443560921497601]
L2 IDs post operation: [161443560921498196 161443560921498196]

Now, can compare this to some brute force code I wrote to explicitly compare the level2 graph before and after an edit. It uses the bounds argument to make this much much faster than pulling the whole level2 graph.

NetworkDelta = namedtuple(
    "NetworkDelta",
    [
        "removed_nodes",
        "added_nodes",
        "removed_edges",
        "added_edges",
        "metadata",
    ],
)


def _get_changed_edges(
    before_edges: pd.DataFrame, after_edges: pd.DataFrame
) -> tuple[pd.DataFrame, pd.DataFrame]:
    before_edges.drop_duplicates()
    before_edges["is_before"] = True
    after_edges.drop_duplicates()
    after_edges["is_before"] = False
    delta_edges = pd.concat([before_edges, after_edges]).drop_duplicates(
        ["source", "target"], keep=False
    )
    removed_edges = delta_edges.query("is_before").drop(columns=["is_before"])
    added_edges = delta_edges.query("~is_before").drop(columns=["is_before"])
    return removed_edges, added_edges


def _make_bbox(
    bbox_halfwidth: int, point_in_nm: np.ndarray, seg_resolution: np.ndarray
) -> np.ndarray:
    x_center, y_center, z_center = point_in_nm

    x_start = x_center - bbox_halfwidth
    x_stop = x_center + bbox_halfwidth
    y_start = y_center - bbox_halfwidth
    y_stop = y_center + bbox_halfwidth
    z_start = z_center - bbox_halfwidth
    z_stop = z_center + bbox_halfwidth

    start_point_cg = np.array([x_start, y_start, z_start]) / seg_resolution
    stop_point_cg = np.array([x_stop, y_stop, z_stop]) / seg_resolution

    bbox_cg = np.array([start_point_cg, stop_point_cg], dtype=int)
    return bbox_cg


def _get_level2_nodes_edges(
    root_id: int, client: CAVEclient, bounds: Optional[np.ndarray] = None
):
    try:
        edgelist = client.chunkedgraph.level2_chunk_graph(root_id, bounds=bounds)
        nodelist = set()
        for edge in edgelist:
            for node in edge:
                nodelist.add(node)
        nodelist = list(nodelist)
    except HTTPError:
        # REF: https://github.com/seung-lab/PyChunkedGraph/issues/404
        nodelist = client.chunkedgraph.get_leaves(root_id, stop_layer=2)
        if len(nodelist) != 1:
            raise HTTPError(
                f"HTTPError: level 2 chunk graph not found for root_id: {root_id}"
            )
        else:
            edgelist = np.empty((0, 2), dtype=int)

    nodes = pd.DataFrame(index=nodelist)

    if len(edgelist) == 0:
        edges = pd.DataFrame(columns=["source", "target"])
    else:
        edges = pd.DataFrame(edgelist, columns=["source", "target"])

    edges = edges.drop_duplicates(keep="first")

    return nodes, edges


def _get_all_nodes_edges(
    root_ids: list[int], client: CAVEclient, bounds: Optional[np.ndarray] = None
):
    all_nodes = []
    all_edges = []
    for root_id in root_ids:
        nodes, edges = _get_level2_nodes_edges(root_id, client, bounds=bounds)
        all_nodes.append(nodes)
        all_edges.append(edges)
    all_nodes = pd.concat(all_nodes, axis=0)
    all_edges = pd.concat(all_edges, axis=0, ignore_index=True)
    return all_nodes, all_edges


def get_level2_edits(
    operataion_ids: list[int],
    client: CAVEclient,
    bounds_halfwidth: int = 20_000,
    metadata: bool = True,
) -> dict[int, NetworkDelta]:
    seg_resolution = client.chunkedgraph.base_resolution

    def _get_info_for_operation(operation_id):
        row = change_log.loc[operation_id]

        before_root_ids = row["before_root_ids"]
        # after_root_ids = row["roots"]
        details = client.chunkedgraph.get_operation_details([operation_id])[
            str(operation_id)
        ]
        after_root_ids = details["roots"]

        point_in_cg = np.array(details["sink_coords"][0])

        point_in_nm = point_in_cg * seg_resolution

        if bounds_halfwidth is None:
            bbox_cg = None
        else:
            bbox_cg = _make_bbox(bounds_halfwidth, point_in_nm, seg_resolution).T

        # grabbing the union of before/after nodes/edges
        # NOTE: this is where all the compute time comes from
        all_before_nodes, all_before_edges = _get_all_nodes_edges(
            before_root_ids, client, bounds=bbox_cg
        )
        all_after_nodes, all_after_edges = _get_all_nodes_edges(
            after_root_ids, client, bounds=bbox_cg
        )

        # finding the nodes that were added or removed, simple set logic
        added_nodes_index = all_after_nodes.index.difference(all_before_nodes.index)
        added_nodes = all_after_nodes.loc[added_nodes_index]
        removed_nodes_index = all_before_nodes.index.difference(all_after_nodes.index)
        removed_nodes = all_before_nodes.loc[removed_nodes_index]

        # finding the edges that were added or removed, simple set logic again
        removed_edges, added_edges = _get_changed_edges(
            all_before_edges, all_after_edges
        )

        # keep track of what changed
        if metadata:
            metadata_dict = {
                **row.to_dict(),
                "operation_id": operation_id,
                "root_id": root_id,
                "n_added_nodes": len(added_nodes),
                "n_removed_nodes": len(removed_nodes),
                "n_modified_nodes": len(added_nodes) + len(removed_nodes),
                "n_added_edges": len(added_edges),
                "n_removed_edges": len(removed_edges),
                "n_modified_edges": len(added_edges) + len(removed_edges),
            }
        else:
            metadata_dict = {}

        return NetworkDelta(
            removed_nodes,
            added_nodes,
            removed_edges,
            added_edges,
            metadata=metadata_dict,
        )

    networkdeltas_by_operation = {}
    for operation_id in operataion_ids:
        networkdeltas_by_operation[operation_id] = _get_info_for_operation(operation_id)

    return networkdeltas_by_operation


networkdeltas = get_level2_edits([merge_id, split_id], client)

print(networkdeltas[merge_id].added_edges.values)
[[161302823433142303 161373192177320783]
 [161302892152619015 161373260896797382]
 [161373123457843248 161373192177320783]
 [161373123457843248 161443492202021454]
 [161373123457843501 161443492202021454]
 [161373192110211983 161373192177320783]
 [161373192110212021 161373192177320783]
 [161373192177320783 161373260896796678]
 [161373192177320783 161373260896797382]
 [161373192177320783 161443560921498196]
 [161373260829688874 161373260896797382]
 [161373260829688894 161373260896797382]
 [161373260896796678 161443629640974899]
 [161373260896797382 161373329616273410]
 [161373260896797382 161373329616273411]
 [161373260896797382 161443629640974899]
 [161443423482544167 161443492202021454]
 [161443423482544381 161443492202021454]
 [161443492134912980 161443492202021454]
 [161443492202021454 161443560921498196]
 [161443492202021454 161513860946198992]
 [161443560854389701 161443560921498196]
 [161443560854389734 161443560921498196]
 [161443560921498196 161443629640974899]
 [161443560921498196 161513929665675782]
 [161443629573866365 161443629640974899]
 [161443629640974735 161513998385152439]
 [161443629640974899 161443698360451073]
 [161443629640974899 161443698360451079]
 [161443629640974899 161513998385152439]
 [161513792226721823 161513860946198992]
 [161513860879090603 161513860946198992]
 [161513860946198992 161513929665675782]
 [161513860946198992 161584229690376193]
 [161513860946198992 161584229690376221]
 [161513929598567373 161513929665675782]
 [161513929598567394 161513929665675782]
 [161513929665675645 161513998385152439]
 [161513929665675782 161513998385152439]
 [161513929665675782 161584298409852929]
 [161513929665675782 161584298409852931]
 [161513929665675782 161584298409852935]
 [161513998318044173 161513998385152439]
 [161513998385152439 161514067104628737]
 [161513998385152439 161514067104628738]
 [161513998385152439 161514067104628891]
 [161513998385152439 161514067104628910]
 [161513998385152439 161514067104628966]
 [161513998385152439 161514067104629021]
 [161513998385152439 161514067104629042]
 [161513998385152439 161514067104629376]
 [161513998385152439 161584367129329665]
 [161513998385152439 161584367129329674]
 [161513998385152439 161584367129329944]]

In other words, there are many more edges (in the level2 graph) that were added than would be suggested by just mapping the added supervoxel edge to its level2 IDs before and after.

A second issue (and this one is perhaps more just me not knowing a good path for mapping these changes) - even though splits have multiple supervoxel edges written down:

split_details = details[str(split_id)]
removed_edges = split_details["removed_edges"]
print(np.array(removed_edges))
[[89323707037880032 89323707037886762]
 [89323707037882956 89323707037886762]
 [89323707037880032 89323707037886762]
 [89323707037882956 89323707037886762]]

Most of the time, they are just going to be the within-chunk edges

split_row = change_log.loc[split_id]
x = split_row["timestamp"]

timestamp = datetime.datetime.fromtimestamp(x / 1000, pytz.UTC)
delta = datetime.timedelta(microseconds=1)

nodes_removed = np.unique(np.concatenate([list(edge) for edge in removed_edges]))

pre_l2_ids = client.chunkedgraph.get_roots(
    nodes_removed, stop_layer=2, timestamp=timestamp - delta
)
post_l2_ids = client.chunkedgraph.get_roots(
    nodes_removed, stop_layer=2, timestamp=timestamp + delta
)

pre_map = dict(zip(nodes_removed, pre_l2_ids))

removed_edges = np.array(removed_edges)
removed_l2_edges = np.vectorize(lambda x: pre_map[x])(removed_edges)
print("Level2 edges removed:", removed_l2_edges)
[[161381301075575086 161381301075575086]
 [161381301075575086 161381301075575086]
 [161381301075575086 161381301075575086]
 [161381301075575086 161381301075575086]]

i.e. every one of those supervoxel edges was just within the chunk, so everything that got removed doesn't really help us with mapping changes at the level2 graph. For comparison, this is the result I'm after:

print(networkdeltas[split_id].removed_edges.values)
[[161310932331397482 161381301075575086]
 [161381301075575086 161381369795051814]
 [161381301075575086 161381369795051931]
 [161381301075575086 161381369795052002]]

So, in summary, the first issue is that even at the supervoxel-edge level, there's missing information about which supervoxels get changed. The second issue is that even when more edges are written down in the log (splits), that's not enough information to reconstruct the changes for the level2 graph. This is because one level2 node was replaced by two nodes in the split, each of which inherited some of the edges from the original. I am wondering if anyone can think of a way to recover what those edges are without my hacky solution of going via the level2 graph pre- and post-edit, or if that is really the only path?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant