add more safeguards
akhileshh committed Sep 12, 2023
1 parent 068ad85 commit c069914
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions pychunkedgraph/graph/edits.py
@@ -20,6 +20,7 @@
 from .utils.serializers import serialize_uint64
 from ..logging.log_db import TimeIt
 from ..utils.general import in2d
+from ..debug.utils import get_l2children


 def _init_old_hierarchy(cg, l2ids: np.ndarray, parent_ts: datetime.datetime = None):
@@ -184,7 +185,9 @@ def _update_neighbor_cross_edges_single(
         layer_edges = cx_edges_d.get(layer, types.empty_2d)
         counterparts.extend(layer_edges[:, 1])

-    cp_cx_edges_d = cg.get_cross_chunk_edges(counterparts, time_stamp=parent_ts)
+    cp_cx_edges_d = cg.get_cross_chunk_edges(
+        counterparts, time_stamp=parent_ts, raw_only=True
+    )
     updated_counterparts = {}
     for counterpart, edges_d in cp_cx_edges_d.items():
         val_dict = {}
@@ -204,17 +207,22 @@ def _update_neighbor_cross_edges_single(


 def _update_neighbor_cross_edges(
-    cg, new_ids: List[int], new_old_id_d: dict, *, time_stamp, parent_ts
+    cg, new_ids: List[int], new_old_id_d: dict, old_new_id_d, *, time_stamp, parent_ts
 ) -> List:
-    newid_cx_edges_d = cg.get_cross_chunk_edges(new_ids, time_stamp=parent_ts)
+    node_map = {}
+    for k, v in old_new_id_d.items():
+        node_map[k] = next(iter(v))
+
     updated_counterparts = {}
+    newid_cx_edges_d = cg.get_cross_chunk_edges(new_ids, time_stamp=parent_ts)
     for new_id in new_ids:
         cx_edges_d = newid_cx_edges_d[new_id]
         temp_map = {
             old_id: new_id for old_id in _get_flipped_ids(new_old_id_d, [new_id])
         }
+        node_map.update(temp_map)
         result = _update_neighbor_cross_edges_single(
-            cg, new_id, cx_edges_d, temp_map, parent_ts=parent_ts
+            cg, new_id, cx_edges_d, node_map, parent_ts=parent_ts
         )
         updated_counterparts.update(result)

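Here old_new_id_d carries the reverse direction of new_old_id_d: each old node ID maps to the ID(s) that replaced it, and node_map is seeded with one representative per old ID via next(iter(v)) before temp_map overrides the entries for the node currently being processed. A minimal sketch of that seeding step, using made-up IDs rather than values from the repository:

old_new_id_d = {
    101: {201},       # old node 101 was replaced by 201
    102: {202, 203},  # old node 102 was split into 202 and 203
}

node_map = {}
for old_id, replacements in old_new_id_d.items():
    # next(iter(...)) picks an arbitrary representative when there are several
    node_map[old_id] = next(iter(replacements))

print(node_map)  # e.g. {101: 201, 102: 202}; set order decides 202 vs 203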
Expand Down Expand Up @@ -284,6 +292,7 @@ def add_edges(
cg,
new_l2_ids,
new_old_id_d,
old_new_id_d,
time_stamp=time_stamp,
parent_ts=parent_ts,
)
@@ -300,6 +309,9 @@
     )

     new_roots = create_parents.run()
+    for new_root in new_roots:
+        l2c = get_l2children(cg, new_root)
+        assert len(l2c) == np.unique(l2c).size, f"inconsistent result op {operation_id}"
     create_parents.create_new_entries()
     return new_roots, new_l2_ids, updated_entries + create_parents.new_entries

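This loop is one of the safeguards the commit message refers to: after the hierarchy is rebuilt, each new root's level-2 children are checked for duplicates, since the same L2 node appearing twice under one root would indicate an inconsistent edit result. get_l2children itself lives in pychunkedgraph.debug.utils and is not shown in this diff; the sketch below only illustrates the duplicate check, on made-up IDs:

import numpy as np

def has_duplicate_l2_children(l2_children) -> bool:
    # Inverse of the assert above: True if any level-2 ID appears more than
    # once among a root's L2 descendants.
    l2c = np.asarray(l2_children)
    return len(l2c) != np.unique(l2c).size

print(has_duplicate_l2_children([10, 11, 12]))      # False -> consistent result
print(has_duplicate_l2_children([10, 11, 11, 12]))  # True  -> would trip the assert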
@@ -318,13 +330,13 @@ def _process_l2_agglomeration(
     chunk_edges = chunk_edges[~in2d(chunk_edges, removed_edges)]

     cross_edges = agg.cross_edges.get_pairs()
+    # we must avoid the cache to read roots to get segment state before edit began
     parents = cg.get_parents(cross_edges[:, 0], time_stamp=parent_ts, raw_only=True)
     err = f"got cross edges from more than one l2 node; op {operation_id}"
     assert np.unique(parents).size == 1, err
     root = cg.get_root(parents[0], time_stamp=parent_ts, raw_only=True)

     # inactive edges must be filtered out
-    # we must avoid the cache to read roots to get segment state before edit began
     neighbor_roots = cg.get_roots(
         cross_edges[:, 1], raw_only=True, time_stamp=parent_ts
     )
@@ -416,6 +428,7 @@ def remove_edges(
         cg,
         new_l2_ids,
         new_old_id_d,
+        old_new_id_d,
         time_stamp=time_stamp,
         parent_ts=parent_ts,
     )
@@ -431,6 +444,9 @@
         parent_ts=parent_ts,
     )
     new_roots = create_parents.run()
+    for new_root in new_roots:
+        l2c = get_l2children(cg, new_root)
+        assert len(l2c) == np.unique(l2c).size, f"inconsistent result op {operation_id}"
     create_parents.create_new_entries()
     return new_roots, new_l2_ids, updated_entries + create_parents.new_entries

@@ -478,6 +494,7 @@ def _update_id_lineage(
         layer: int,
         parent_layer: int,
     ):
+        # update newly created children; mask others
        mask = np.in1d(children, self._new_ids_d[layer])
        for child_id in children[mask]:
            child_old_ids = self._new_old_id_d[child_id]
@@ -530,7 +547,7 @@ def _update_cross_edge_cache(self, parent, children):
         cx_edges_d = self.cg.get_cross_chunk_edges(
             children, time_stamp=self._last_successful_ts
         )
-        cx_edges_d = concatenate_cross_edge_dicts(cx_edges_d.values(), unique=True)
+        cx_edges_d = concatenate_cross_edge_dicts(cx_edges_d.values())

         parent_layer = self.cg.get_chunk_layer(parent)
         edge_nodes = np.unique(np.concatenate([*cx_edges_d.values(), types.empty_2d]))
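The unique=True argument is dropped here, so duplicate edge rows contributed by different children are presumably kept rather than deduplicated when the children's cross-edge dicts are merged. Elsewhere in this file cross-chunk edges are handled as per-layer Nx2 arrays (cx_edges_d.get(layer, ...)), so a toy merge under that assumption, not the library's actual helper, could look like:

from collections import defaultdict
import numpy as np

def concat_cross_edge_dicts_sketch(dicts):
    # Toy stand-in: merge {layer: Nx2 array} dicts by stacking rows per layer,
    # keeping any duplicate rows (i.e. no unique=True behaviour).
    merged = defaultdict(list)
    for d in dicts:
        for layer, edges in d.items():
            merged[layer].append(np.asarray(edges).reshape(-1, 2))
    return {layer: np.concatenate(rows) for layer, rows in merged.items()}

child_a = {3: np.array([[1, 7]]), 4: np.array([[1, 9]])}
child_b = {3: np.array([[2, 7]])}
merged = concat_cross_edge_dicts_sketch([child_a, child_b])
# layer 3 -> [[1, 7], [2, 7]]; layer 4 -> [[1, 9]]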
Expand Down Expand Up @@ -566,19 +583,9 @@ def _create_new_parents(self, layer: int):
layer_node_ids = self._get_layer_node_ids(new_ids, layer)
components, graph_ids = self._get_connected_components(layer_node_ids, layer)
new_parent_ids = []
all_old_ids = []
for v in self._new_old_id_d.values():
all_old_ids.extend(v)
all_old_ids = np.array(all_old_ids, dtype=basetypes.NODE_ID)

for cc_indices in components:
parent_layer = layer + 1 # must be reset for each connected component
cc_ids = graph_ids[cc_indices]
mask = np.isin(cc_ids, all_old_ids)
old_ids = cc_ids[mask]
new_ids = _get_flipped_ids(self._old_new_id_d, cc_ids[mask])
err = f"got old ids {old_ids} -> {new_ids}; op {self._operation_id}"
assert np.all(~mask), err
if len(cc_ids) == 1:
# skip connection
parent_layer = self.cg.meta.layer_count
@@ -610,6 +617,7 @@ def _create_new_parents(self, layer: int):
             self.cg,
             new_parent_ids,
             self._new_old_id_d,
+            self._old_new_id_d,
             time_stamp=self._time_stamp,
             parent_ts=self._last_successful_ts,
         )
