Skip to content

Commit

Permalink
use parent timestamps to lift cx edges
Browse files Browse the repository at this point in the history
  • Loading branch information
akhileshh committed Sep 22, 2024
1 parent 77947f1 commit 98c91d0
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 68 deletions.
72 changes: 15 additions & 57 deletions pychunkedgraph/ingest/upgrade/atomic_layer.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,19 @@
# pylint: disable=invalid-name, missing-docstring, c-extension-no-member

from datetime import timedelta

import fastremap
import numpy as np
from pychunkedgraph.graph import ChunkedGraph
from pychunkedgraph.graph.attributes import Connectivity
from pychunkedgraph.graph.attributes import Hierarchy
from pychunkedgraph.graph.utils import serializers

from .utils import exists_as_parent


def get_parent_timestamps(cg, supervoxels, start_time=None, end_time=None) -> set:
"""
Timestamps of when the given supervoxels were edited, in the given time range.
"""
response = cg.client.read_nodes(
node_ids=supervoxels,
start_time=start_time,
end_time=end_time,
end_time_inclusive=False,
)
result = set()
for v in response.values():
for cell in v[Hierarchy.Parent]:
valid = cell.timestamp >= start_time or cell.timestamp < end_time
assert valid, f"{cell.timestamp}, {start_time}"
result.add(cell.timestamp)
return result
from .utils import exists_as_parent, get_parent_timestamps


def get_edit_timestamps(cg: ChunkedGraph, edges_d, start_ts, end_ts) -> list:
"""
Timestamps of when post-side supervoxels were involved in an edit.
Post-side - supervoxels in the neighbor chunk.
This is required because we need to update edges from both sides.
"""
atomic_cx_edges = np.concatenate(list(edges_d.values()))
timestamps = get_parent_timestamps(
cg, atomic_cx_edges[:, 1], start_time=start_ts, end_time=end_ts
)
timestamps.add(start_ts)
return sorted(timestamps)


def update_cross_edges(cg: ChunkedGraph, node, cx_edges_d, node_ts, end_ts) -> list:
def update_cross_edges(
cg: ChunkedGraph, node, cx_edges_d, node_ts, timestamps, earliest_ts
) -> list:
"""
Helper function to update a single L2 ID.
Returns a list of mutations with given timestamps.
Expand All @@ -58,10 +27,9 @@ def update_cross_edges(cg: ChunkedGraph, node, cx_edges_d, node_ts, end_ts) -> l
assert not exists_as_parent(cg, node, edges[:, 0])
return rows

timestamps = [node_ts]
if node_ts != end_ts:
timestamps = get_edit_timestamps(cg, cx_edges_d, node_ts, end_ts)
for ts in timestamps:
if ts < earliest_ts:
ts = earliest_ts
val_dict = {}
svs = edges[:, 1]
parents = cg.get_parents(svs, time_stamp=ts)
Expand All @@ -80,31 +48,21 @@ def update_cross_edges(cg: ChunkedGraph, node, cx_edges_d, node_ts, end_ts) -> l


def update_nodes(cg: ChunkedGraph, nodes) -> list:
# get start_ts when node becomes valid
nodes_ts = cg.get_node_timestamps(nodes, return_numpy=False, normalize=True)
earliest_ts = cg.get_earliest_timestamp()
timestamps_d = get_parent_timestamps(cg, nodes)
cx_edges_d = cg.get_atomic_cross_edges(nodes)
children_d = cg.get_children(nodes)

rows = []
for node, start_ts in zip(nodes, nodes_ts):
for node, node_ts in zip(nodes, nodes_ts):
if cg.get_parent(node) is None:
# invalid id caused by failed ingest task
continue
node_cx_edges_d = cx_edges_d.get(node, {})
if not node_cx_edges_d:
_cx_edges_d = cx_edges_d.get(node, {})
if not _cx_edges_d:
continue

# get end_ts when node becomes invalid (bigtable resolution is in ms)
start = start_ts + timedelta(milliseconds=1)
_timestamps = get_parent_timestamps(cg, children_d[node], start_time=start)
try:
end_ts = sorted(_timestamps)[0]
except IndexError:
# start_ts == end_ts means there has been no edit involving this node
# meaning only one timestamp to update cross edges, start_ts
end_ts = start_ts
# for each timestamp until end_ts, update cross chunk edges of node
_rows = update_cross_edges(cg, node, node_cx_edges_d, start_ts, end_ts)
_rows = update_cross_edges(
cg, node, _cx_edges_d, node_ts, timestamps_d[node], earliest_ts
)
rows.extend(_rows)
return rows

Expand Down
20 changes: 9 additions & 11 deletions pychunkedgraph/ingest/upgrade/parent_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pychunkedgraph.graph.types import empty_2d
from pychunkedgraph.utils.general import chunked

from .utils import exists_as_parent
from .utils import exists_as_parent, get_parent_timestamps


CHILDREN = {}
Expand Down Expand Up @@ -50,7 +50,7 @@ def _get_cx_edges_at_timestamp(node, response, ts):


def _populate_cx_edges_with_timestamps(
cg: ChunkedGraph, layer: int, nodes: list, nodes_ts: list
cg: ChunkedGraph, layer: int, nodes: list, earliest_ts
):
"""
Collect timestamps of edits from children, since we use the same timestamp
Expand All @@ -61,15 +61,13 @@ def _populate_cx_edges_with_timestamps(
attrs = [Connectivity.CrossChunkEdge[l] for l in range(layer, cg.meta.layer_count)]
all_children = np.concatenate(list(CHILDREN.values()))
response = cg.client.read_nodes(node_ids=all_children, properties=attrs)
for node, node_ts in zip(nodes, nodes_ts):
timestamps = set([node_ts])
for child in CHILDREN[node]:
if child not in response:
continue
for cells in response[child].values():
timestamps.update([c.timestamp for c in cells if c.timestamp > node_ts])
timestamps_d = get_parent_timestamps(cg, nodes)
for node in nodes:
CX_EDGES[node] = {}
timestamps = timestamps_d[node]
for ts in sorted(timestamps):
if ts < earliest_ts:
ts = earliest_ts
CX_EDGES[node][ts] = _get_cx_edges_at_timestamp(node, response, ts)


Expand Down Expand Up @@ -142,19 +140,19 @@ def update_chunk(
start = time.time()
x, y, z = chunk_coords
chunk_id = cg.get_chunk_id(layer=layer, x=x, y=y, z=z)
earliest_ts = cg.get_earliest_timestamp()
_populate_nodes_and_children(cg, chunk_id, nodes=nodes)
if not CHILDREN:
return
nodes = list(CHILDREN.keys())
random.shuffle(nodes)
nodes_ts = cg.get_node_timestamps(nodes, return_numpy=False, normalize=True)
_populate_cx_edges_with_timestamps(cg, layer, nodes, nodes_ts)
_populate_cx_edges_with_timestamps(cg, layer, nodes, earliest_ts)

task_size = int(math.ceil(len(nodes) / mp.cpu_count() / 2))
chunked_nodes = chunked(nodes, task_size)
chunked_nodes_ts = chunked(nodes_ts, task_size)
cg_info = cg.get_serialized_info()
earliest_ts = cg.get_earliest_timestamp()

multi_args = []
for chunk, ts_chunk in zip(chunked_nodes, chunked_nodes_ts):
Expand Down
50 changes: 50 additions & 0 deletions pychunkedgraph/ingest/upgrade/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# pylint: disable=invalid-name, missing-docstring

from collections import defaultdict
from datetime import timedelta

import numpy as np
from pychunkedgraph.graph import ChunkedGraph
from pychunkedgraph.graph.attributes import Hierarchy

Expand All @@ -11,3 +17,47 @@ def exists_as_parent(cg: ChunkedGraph, parent, nodes) -> bool:
for cells in response.values():
parents.update([cell.value for cell in cells])
return parent in parents


def get_edit_timestamps(cg: ChunkedGraph, edges_d, start_ts, end_ts) -> list:
"""
Timestamps of when post-side nodes were involved in an edit.
Post-side - nodes in the neighbor chunk.
This is required because we need to update edges from both sides.
"""
cx_edges = np.concatenate(list(edges_d.values()))
timestamps = get_parent_timestamps(
cg, cx_edges[:, 1], start_time=start_ts, end_time=end_ts
)
timestamps.add(start_ts)
return sorted(timestamps)


def get_end_ts(cg: ChunkedGraph, children, start_ts):
# get end_ts when node becomes invalid (bigtable resolution is in ms)
start = start_ts + timedelta(milliseconds=1)
_timestamps = get_parent_timestamps(cg, children, start_time=start)
try:
end_ts = sorted(_timestamps)[0]
except IndexError:
# start_ts == end_ts means there has been no edit involving this node
# meaning only one timestamp to update cross edges, start_ts
end_ts = start_ts
return end_ts


def get_parent_timestamps(cg: ChunkedGraph, nodes) -> dict[int, set]:
"""
Timestamps of when the given nodes were edited.
"""
response = cg.client.read_nodes(
node_ids=nodes,
properties=[Hierarchy.Parent],
end_time_inclusive=False,
)

result = defaultdict(set)
for k, v in response.items():
for cell in v[Hierarchy.Parent]:
result[k].add(cell.timestamp)
return result

0 comments on commit 98c91d0

Please sign in to comment.