feat: granular subgraph logs #455

Merged · 21 commits · Aug 9, 2023
14 changes: 2 additions & 12 deletions pychunkedgraph/graph/chunkedgraph.py
@@ -687,28 +687,18 @@ def get_l2_agglomerations(

         l2id_children_d = self.get_children(level2_ids)
         sv_parent_d = {}
-        supervoxels = []
         for l2id in l2id_children_d:
             svs = l2id_children_d[l2id]
             sv_parent_d.update(dict(zip(svs.tolist(), [l2id] * len(svs))))
-            supervoxels.append(svs)
-
-        supervoxels = np.concatenate(supervoxels)
-
-        def f(x):
-            return sv_parent_d.get(x, x)
-
-        get_sv_parents = np.vectorize(f, otypes=[np.uint64])
         in_edges, out_edges, cross_edges = edge_utils.categorize_edges_v2(
             self.meta,
-            supervoxels,
             all_chunk_edges,
             l2id_children_d,
-            get_sv_parents,
+            sv_parent_d
         )

         agglomeration_d = get_agglomerations(
-            l2id_children_d, in_edges, out_edges, cross_edges, get_sv_parents
+            l2id_children_d, in_edges, out_edges, cross_edges, sv_parent_d
         )
         return (
             agglomeration_d,
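The change above swaps a per-element Python lookup for a native one: `np.vectorize` calls `f` once per array element, while `fastremap.remap` does the same dictionary lookup in C, with `preserve_missing_labels=True` reproducing the old `sv_parent_d.get(x, x)` fallback. A minimal before/after sketch, with made-up IDs:

```python
import numpy as np
import fastremap

# Hypothetical supervoxel -> L2 parent mapping.
sv_parent_d = {1: 100, 2: 100}
node_ids = np.array([1, 2, 3], dtype=np.uint64)

# Old approach: np.vectorize invokes a Python function per element.
def f(x):
    return sv_parent_d.get(x, x)

get_sv_parents = np.vectorize(f, otypes=[np.uint64])
old = get_sv_parents(node_ids)

# New approach: the same lookup done in native code;
# preserve_missing_labels=True keeps IDs absent from the dict unchanged.
new = fastremap.remap(node_ids, sv_parent_d, preserve_missing_labels=True)

assert (old == new).all()  # both yield [100, 100, 3]
```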
29 changes: 9 additions & 20 deletions pychunkedgraph/graph/client/bigtable/client.py
@@ -832,8 +832,6 @@ def _read(
         # calculate this properly (range_read.request.SerializeToString()), but this estimate is
         # good enough for now

-        from pychunkedgraph.logging.log_db import TimeIt
-
         n_subrequests = max(
             1, int(np.ceil(len(row_set.row_keys) / self._max_row_key_count))
         )
@@ -849,23 +847,14 @@

         # Don't forget the original RowSet's row_ranges
         row_sets[0].row_ranges = row_set.row_ranges
-
-        with TimeIt(
-            "chunked_reads",
-            f"{self._table.table_id}_bt_profile",
-            operation_id=-1,
-            n_rows=len(row_set.row_keys),
-            n_requests=n_subrequests,
+        responses = mu.multithread_func(
+            self._execute_read_thread,
+            params=((self._table, r, row_filter) for r in row_sets),
+            debug=n_threads == 1,
             n_threads=n_threads,
-        ):
-            responses = mu.multithread_func(
-                self._execute_read_thread,
-                params=((self._table, r, row_filter) for r in row_sets),
-                debug=n_threads == 1,
-                n_threads=n_threads,
-            )
+        )

-            combined_response = {}
-            for resp in responses:
-                combined_response.update(resp)
-            return combined_response
+        combined_response = {}
+        for resp in responses:
+            combined_response.update(resp)
+        return combined_response
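For context, `_read` splits a large RowSet into capped sub-requests and fans them out with `mu.multithread_func`; the per-read `TimeIt` profiling wrapper around that fan-out is what this PR removes. A toy illustration of the sub-request arithmetic, with made-up numbers:

```python
import numpy as np

# 2500 row keys with a cap of 1000 keys per request -> 3 sub-requests.
row_keys = list(range(2500))
max_row_key_count = 1000
n_subrequests = max(1, int(np.ceil(len(row_keys) / max_row_key_count)))

# Each thread then handles one chunk of at most max_row_key_count keys.
chunks = [
    row_keys[i * max_row_key_count : (i + 1) * max_row_key_count]
    for i in range(n_subrequests)
]
assert [len(c) for c in chunks] == [1000, 1000, 500]
```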
26 changes: 16 additions & 10 deletions pychunkedgraph/graph/edges/utils.py
@@ -1,13 +1,15 @@
+# pylint: disable=invalid-name, missing-docstring, c-extension-no-member
+
 """
 helper functions for edge stuff
 """

 from typing import Dict
 from typing import Tuple
 from typing import Iterable
-from typing import Callable
 from typing import Optional

+import fastremap
 import numpy as np

 from . import Edges
@@ -90,8 +92,6 @@ def categorize_edges(
     in_mask = mask1 & mask2
     out_mask = mask1 & ~mask2

-    print("np.sum(in_mask)", np.sum(in_mask))
-
     in_edges = edges[in_mask]
     all_out_edges = edges[out_mask]  # out_edges + cross_edges

@@ -104,20 +104,26 @@

 def categorize_edges_v2(
     meta: ChunkedGraphMeta,
-    supervoxels: np.ndarray,
     edges: Edges,
     l2id_children_d: Dict,
-    get_sv_parents: Callable,
+    sv_parent_d: Dict,
 ) -> Tuple[Edges, Edges, Edges]:
     """Faster version of categorize_edges(), avoids looping over L2 IDs."""
-    node_ids1 = get_sv_parents(edges.node_ids1)
-    node_ids2 = get_sv_parents(edges.node_ids2)
-
+    node_ids1 = fastremap.remap(
+        edges.node_ids1, sv_parent_d, preserve_missing_labels=True
+    )
+    node_ids2 = fastremap.remap(
+        edges.node_ids2, sv_parent_d, preserve_missing_labels=True
+    )
+
     layer_mask1 = chunk_utils.get_chunk_layers(meta, node_ids1) > 1
-    in_edges = edges[node_ids1 == node_ids2]
-    all_out_ = edges[layer_mask1 & (node_ids1 != node_ids2)]
+    nodes_mask = node_ids1 == node_ids2
+
+    in_edges = edges[nodes_mask]
+    all_out_ = edges[layer_mask1 & ~nodes_mask]
+
     cx_layers = get_cross_chunk_edges_layer(meta, all_out_.get_pairs())

     cx_mask = cx_layers > 1
     out_edges = all_out_[~cx_mask]
     cross_edges = all_out_[cx_mask]
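`categorize_edges_v2` avoids the per-L2 loop entirely: both endpoint arrays are remapped to their L2 parents once, and a single equality mask separates in-chunk edges from outgoing ones. A toy sketch of just that mask step (IDs made up; the layer and cross-chunk handling shown above is omitted):

```python
import numpy as np

# Edge endpoints after remapping supervoxels to their L2 parents.
node_ids1 = np.array([100, 100, 200, 300], dtype=np.uint64)
node_ids2 = np.array([100, 200, 200, 400], dtype=np.uint64)

nodes_mask = node_ids1 == node_ids2          # endpoints share an L2 parent
in_rows = np.flatnonzero(nodes_mask)         # in-chunk edge rows -> [0, 2]
out_or_cross_rows = np.flatnonzero(~nodes_mask)  # rows split later by layer -> [1, 3]
```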
8 changes: 2 additions & 6 deletions pychunkedgraph/graph/edits.py
@@ -233,7 +233,6 @@ def add_edges(
         operation_id=operation_id,
         time_stamp=time_stamp,
         parent_ts=parent_ts,
-        edit_type="merge.apply.add_edges",
     )

     new_roots = create_parents.run()
@@ -335,7 +334,6 @@ def remove_edges(
         operation_id=operation_id,
         time_stamp=time_stamp,
         parent_ts=parent_ts,
-        edit_type="split.apply.remove_edges",
     )
     new_roots = create_parents.run()
     new_entries = create_parents.create_new_entries()
@@ -354,7 +352,6 @@ def __init__(
         old_new_id_d: Dict[np.uint64, Iterable[np.uint64]] = None,
         old_hierarchy_d: Dict[np.uint64, Dict[int, np.uint64]] = None,
         parent_ts: datetime.datetime = None,
-        edit_type: str = None,
     ):
         self.cg = cg
         self._new_l2_ids = new_l2_ids
@@ -366,7 +363,6 @@
         self._operation_id = operation_id
         self._time_stamp = time_stamp
         self._last_successful_ts = parent_ts
-        self._edit_type = edit_type

     def _update_id_lineage(
         self,
@@ -410,7 +406,7 @@ def _get_connected_components(
         not_cached = _node_ids[~np.in1d(_node_ids, cached)]

         with TimeIt(
-            f"{self._edit_type}.get_cross_chunk_edges.{layer}",
+            f"get_cross_chunk_edges.{layer}",
             self.cg.graph_id,
             self._operation_id,
         ):
@@ -517,7 +513,7 @@ def run(self) -> Iterable:
             if len(self._new_ids_d[layer]) == 0:
                 continue
             with TimeIt(
-                f"{self._edit_type}.create_new_parents_layer.{layer}",
+                f"create_new_parents_layer.{layer}",
                 self.cg.graph_id,
                 self._operation_id,
            ):
24 changes: 11 additions & 13 deletions pychunkedgraph/graph/misc.py
@@ -1,13 +1,13 @@
-# TODO categorize these
+# pylint: disable=invalid-name, missing-docstring, c-extension-no-member, import-outside-toplevel

-import numpy as np
 import datetime
 import collections
 from typing import Dict
-from typing import Callable
 from typing import Optional
 from typing import Sequence

+import fastremap
+import numpy as np
 from multiwrapper import multiprocessing_utils as mu

 from . import ChunkedGraph
@@ -136,8 +136,6 @@ def get_delta_roots(
     cg: ChunkedGraph,
     time_stamp_start: datetime.datetime,
     time_stamp_end: Optional[datetime.datetime] = None,
-    min_seg_id: int = 1,
-    n_threads: int = 1,
 ) -> Sequence[np.uint64]:
     # Create filters: time and id range
     start_id = np.uint64(cg.get_chunk_id(layer=cg.meta.layer_count) + 1)
@@ -229,21 +227,21 @@ def get_contact_sites(
 def get_agglomerations(
     l2id_children_d: Dict,
     in_edges: Edges,
-    out_edges: Edges,
-    cross_edges: Edges,
-    get_sv_parents: Callable,
+    ot_edges: Edges,
+    cx_edges: Edges,
+    sv_parent_d: Dict,
 ) -> Dict[np.uint64, Agglomeration]:
     l2id_agglomeration_d = {}
-    _in = get_sv_parents(in_edges.node_ids1)
-    _out = get_sv_parents(out_edges.node_ids1)
-    _cross = get_sv_parents(cross_edges.node_ids1)
+    _in = fastremap.remap(in_edges.node_ids1, sv_parent_d, preserve_missing_labels=True)
+    _ot = fastremap.remap(ot_edges.node_ids1, sv_parent_d, preserve_missing_labels=True)
+    _cx = fastremap.remap(cx_edges.node_ids1, sv_parent_d, preserve_missing_labels=True)
     for l2id in l2id_children_d:
         l2id_agglomeration_d[l2id] = Agglomeration(
             l2id,
             l2id_children_d[l2id],
             in_edges[_in == l2id],
-            out_edges[_out == l2id],
-            cross_edges[_cross == l2id],
+            ot_edges[_ot == l2id],
+            cx_edges[_cx == l2id],
         )
     return l2id_agglomeration_d

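The same pattern applies in `get_agglomerations`: each edge array is remapped once, outside the loop, so every L2 ID pays only for a cheap boolean mask instead of a fresh per-L2 remap. A toy version of that hoisting, with made-up values:

```python
import numpy as np
import fastremap

sv_parent_d = {1: 100, 2: 100, 3: 200}   # supervoxel -> L2 parent
in_node_ids1 = np.array([1, 2, 3], dtype=np.uint64)

# Remap once, outside the loop...
_in = fastremap.remap(in_node_ids1, sv_parent_d, preserve_missing_labels=True)

# ...then slice per L2 ID with boolean masks.
for l2id in (100, 200):
    rows = np.flatnonzero(_in == l2id)
    print(l2id, rows)  # 100 -> [0 1], 200 -> [2]
```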
19 changes: 10 additions & 9 deletions pychunkedgraph/graph/operation.py
@@ -423,6 +423,7 @@ def execute(
             privileged_mode=self.privileged_mode,
         ) as lock:
             self.cg.cache = CacheService(self.cg)
+            self.cg.meta.custom_data["operation_id"] = operation_id
             timestamp = self.cg.client.get_consolidated_lock_timestamp(
                 lock.locked_root_ids,
                 np.array([lock.operation_id] * len(lock.locked_root_ids)),
@@ -603,15 +604,15 @@ def _apply(
         if len(root_ids) < 2 and not self.allow_same_segment_merge:
             raise PreconditionError("Supervoxels must belong to different objects.")
         bbox = get_bbox(self.source_coords, self.sink_coords, self.bbox_offset)
-        with TimeIt("merge.apply.subgraph", self.cg.graph_id, operation_id):
+        with TimeIt("subgraph", self.cg.graph_id, operation_id):
             edges = self.cg.get_subgraph(
                 root_ids,
                 bbox=bbox,
                 bbox_is_coordinate=True,
                 edges_only=True,
             )

-        with TimeIt("merge.apply.preprocess", self.cg.graph_id, operation_id):
+        with TimeIt("preprocess", self.cg.graph_id, operation_id):
             inactive_edges = edits.merge_preprocess(
                 self.cg,
                 subgraph_edges=edges,
@@ -626,7 +627,7 @@
                 time_stamp=timestamp,
                 parent_ts=self.parent_ts,
             )
-        with TimeIt("merge.apply.add_edges", self.cg.graph_id, operation_id):
+        with TimeIt("add_edges", self.cg.graph_id, operation_id):
             new_roots, new_l2_ids, new_entries = edits.add_edges(
                 self.cg,
                 atomic_edges=atomic_edges,
@@ -743,13 +744,13 @@
         ):
             raise PreconditionError("Supervoxels must belong to the same object.")

-        with TimeIt("split.apply.subgraph", self.cg.graph_id, operation_id):
+        with TimeIt("subgraph", self.cg.graph_id, operation_id):
             l2id_agglomeration_d, _ = self.cg.get_l2_agglomerations(
                 self.cg.get_parents(
                     self.removed_edges.ravel(), time_stamp=self.parent_ts
-                )
+                ),
             )
-        with TimeIt("split.apply.remove_edges", self.cg.graph_id, operation_id):
+        with TimeIt("remove_edges", self.cg.graph_id, operation_id):
             return edits.remove_edges(
                 self.cg,
                 operation_id=operation_id,
@@ -890,7 +891,7 @@ def _apply(
             self.sink_coords,
             self.cg.meta.split_bounding_offset,
         )
-        with TimeIt("split.apply.get_subgraph", self.cg.graph_id, operation_id):
+        with TimeIt("get_subgraph", self.cg.graph_id, operation_id):
             l2id_agglomeration_d, edges = self.cg.get_subgraph(
                 root_ids.pop(), bbox=bbox, bbox_is_coordinate=True
             )
@@ -905,7 +906,7 @@
         if len(edges) == 0:
             raise PreconditionError("No local edges found.")

-        with TimeIt("split.apply.multicut", self.cg.graph_id, operation_id):
+        with TimeIt("multicut", self.cg.graph_id, operation_id):
             self.removed_edges = run_multicut(
                 edges,
                 self.source_ids,
@@ -916,7 +917,7 @@
         if not self.removed_edges.size:
             raise PostconditionError("Mincut could not find any edges to remove.")

-        with TimeIt("split.apply.remove_edges", self.cg.graph_id, operation_id):
+        with TimeIt("remove_edges", self.cg.graph_id, operation_id):
             return edits.remove_edges(
                 self.cg,
                 operation_id=operation_id,
15 changes: 10 additions & 5 deletions pychunkedgraph/logging/log_db.py
@@ -102,13 +102,17 @@ def get_log_db(graph_id: str) -> LogDB:


 class TimeIt:
-    def __init__(self, name: str, graph_id: str, operation_id, **kwargs):
-        self._name = name
+    names = []
+    operation_id = -1
+
+    def __init__(self, name: str, graph_id: str, operation_id=-1, **kwargs):
+        self.names.append(name)
         self._start = None
         self._graph_id = graph_id
-        self._operation_id = int(operation_id)
         self._ts = datetime.utcnow()
         self._kwargs = kwargs
+        if operation_id != -1:
+            self.operation_id = operation_id

     def __enter__(self):
         self._start = time.time()
@@ -121,11 +125,12 @@ def __exit__(self, *args):
         try:
             log_db = get_log_db(self._graph_id)
             log_db.log_code_block(
-                name=self._name,
-                operation_id=self._operation_id,
+                name=".".join(self.names),
+                operation_id=self.operation_id,
                 timestamp=self._ts,
                 time_ms=time_ms,
                 **self._kwargs,
             )
         except GoogleAPIError:
             ...
+        self.names.pop()
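This class-level `names` stack is the heart of the PR: nested `TimeIt` blocks now compose dotted labels such as `remove_edges.get_cross_chunk_edges.2`, which is why the hard-coded `merge.apply.` and `split.apply.` prefixes could be dropped from `edits.py` and `operation.py`. A stripped-down sketch of just the stack behavior (logging omitted; the label values are illustrative):

```python
# Sketch of how the names stack composes labels; the real __exit__
# also writes the elapsed time to the log DB.
class TimeItSketch:
    names = []  # class-level, so nested instances share one stack

    def __init__(self, name):
        self.names.append(name)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        print(".".join(self.names))  # full dotted label at this nesting depth
        self.names.pop()


with TimeItSketch("remove_edges"):
    with TimeItSketch("get_cross_chunk_edges.2"):
        pass
# prints "remove_edges.get_cross_chunk_edges.2", then "remove_edges"
```

Because the stack lives on the class, inner timers in other modules pick up the caller's prefix without threading a name through every signature; the tradeoff is that the state is shared process-wide, so concurrently running timers in one process would interleave their names.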