Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
bdpedigo committed Mar 11, 2024
1 parent ca3502c commit 4c424a6
Showing 1 changed file with 33 additions and 47 deletions.
80 changes: 33 additions & 47 deletions pychunkedgraph/app/segmentation/common.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,35 @@
# pylint: disable=invalid-name, missing-docstring

import json
import time
import os
import time
from datetime import datetime
from functools import reduce

import numpy as np
from pytz import UTC
import pandas as pd

from flask import current_app, g, jsonify, make_response, request
from pytz import UTC

from pychunkedgraph import __version__
from pychunkedgraph.app import app_utils
from pychunkedgraph.graph import (
attributes,
cutting,
exceptions as cg_exceptions,
segmenthistory,
)
from pychunkedgraph.graph import (
edges as cg_edges,
)
from pychunkedgraph.graph import segmenthistory
from pychunkedgraph.graph.utils import basetypes
from pychunkedgraph.graph import (
exceptions as cg_exceptions,
)
from pychunkedgraph.graph.analysis import pathing
from pychunkedgraph.graph.attributes import OperationLogs
from pychunkedgraph.meshing import mesh_analysis
from pychunkedgraph.graph.misc import get_contact_sites
from pychunkedgraph.graph.operation import GraphEditOperation

from pychunkedgraph.graph.utils import basetypes
from pychunkedgraph.meshing import mesh_analysis

__api_versions__ = [0, 1]
__segmentation_url_prefix__ = os.environ.get("SEGMENTATION_URL_PREFIX", "segmentation")
Expand Down Expand Up @@ -72,6 +75,15 @@ def _parse_timestamp(
)


def _get_bounds_from_request(request):
if "bounds" in request.args:
bounds = request.args["bounds"]
bounding_box = np.array([b.split("-") for b in bounds.split("_")], dtype=int).T
else:
bounding_box = None
return bounding_box


# -------------------
# ------ Applications
# -------------------
Expand All @@ -93,9 +105,7 @@ def handle_info(table_id):
combined_info["verify_mesh"] = cg.meta.custom_data.get("mesh", {}).get(
"verify", False
)
mesh_dir = cg.meta.custom_data.get("mesh", {}).get(
"dir", None
)
mesh_dir = cg.meta.custom_data.get("mesh", {}).get("dir", None)
if mesh_dir is not None:
combined_info["mesh_dir"] = mesh_dir
elif combined_info.get("mesh_dir", None) is not None:
Expand Down Expand Up @@ -216,7 +226,7 @@ def publish_edit(
table_id: str, user_id: str, result: GraphEditOperation.Result, is_priority=True
):
import pickle
from os import getenv

from messagingclient import MessagingClient

attributes = {
Expand Down Expand Up @@ -454,7 +464,7 @@ def handle_rollback(table_id):
continue
try:
ret = cg.undo_operation(user_id=target_user_id, operation_id=operation_id)
except cg_exceptions.LockingError as e:
except cg_exceptions.LockingError:
raise cg_exceptions.InternalServerError(
"Could not acquire root lock for undo operation."
)
Expand Down Expand Up @@ -506,14 +516,14 @@ def all_user_operations(
user_id = entry[OperationLogs.UserID]

should_check = (
not OperationLogs.Status in entry
OperationLogs.Status not in entry
or entry[OperationLogs.Status] == OperationLogs.StatusCodes.SUCCESS.value
)

split_valid = (
include_partial_splits
or (OperationLogs.AddedEdge in entry)
or (not OperationLogs.RootID in entry)
or (OperationLogs.RootID not in entry)
or (len(entry[OperationLogs.RootID]) > 1)
)
if not split_valid:
Expand Down Expand Up @@ -589,15 +599,11 @@ def handle_leaves(table_id, root_id):
user_id = str(g.auth_user.get("id", current_app.user_id))

stop_layer = int(request.args.get("stop_layer", 1))
bounding_box = None
if "bounds" in request.args:
bounds = request.args["bounds"]
bounding_box = np.array([b.split("-") for b in bounds.split("_")], dtype=int).T

bounding_box = _get_bounds_from_request(request)

cg = app_utils.get_cg(table_id)
if stop_layer > 1:
from pychunkedgraph.graph.types import empty_1d

subgraph = cg.get_subgraph_nodes(
int(root_id),
bbox=bounding_box,
Expand All @@ -621,11 +627,7 @@ def handle_leaves_many(table_id):
current_app.table_id = table_id
user_id = str(g.auth_user.get("id", current_app.user_id))

if "bounds" in request.args:
bounds = request.args["bounds"]
bounding_box = np.array([b.split("-") for b in bounds.split("_")], dtype=int).T
else:
bounding_box = None
bounding_box = _get_bounds_from_request(request)

node_ids = np.array(json.loads(request.data)["node_ids"], dtype=np.uint64)
stop_layer = int(request.args.get("stop_layer", 1))
Expand All @@ -652,11 +654,7 @@ def handle_leaves_from_leave(table_id, atomic_id):
current_app.table_id = table_id
user_id = str(g.auth_user.get("id", current_app.user_id))

if "bounds" in request.args:
bounds = request.args["bounds"]
bounding_box = np.array([b.split("-") for b in bounds.split("_")], dtype=int).T
else:
bounding_box = None
bounding_box = _get_bounds_from_request(request)

# Call ChunkedGraph
cg = app_utils.get_cg(table_id)
Expand All @@ -676,11 +674,7 @@ def handle_subgraph(table_id, root_id):
current_app.table_id = table_id
user_id = str(g.auth_user.get("id", current_app.user_id))

if "bounds" in request.args:
bounds = request.args["bounds"]
bounding_box = np.array([b.split("-") for b in bounds.split("_")], dtype=int).T
else:
bounding_box = None
bounding_box = _get_bounds_from_request(request)

# Call ChunkedGraph
cg = app_utils.get_cg(table_id)
Expand Down Expand Up @@ -820,6 +814,7 @@ def merge_log(table_id, root_id):

def handle_lineage_graph(table_id, root_id=None):
from networkx import node_link_data

from pychunkedgraph.graph.lineage import lineage_graph

current_app.table_id = table_id
Expand Down Expand Up @@ -890,11 +885,7 @@ def handle_contact_sites(table_id, root_id):

timestamp = _parse_timestamp("timestamp", time.time(), return_datetime=True)

if "bounds" in request.args:
bounds = request.args["bounds"]
bounding_box = np.array([b.split("-") for b in bounds.split("_")], dtype=int).T
else:
bounding_box = None
bounding_box = _get_bounds_from_request(request)

# Call ChunkedGraph
cg = app_utils.get_cg(table_id)
Expand Down Expand Up @@ -1042,11 +1033,7 @@ def handle_get_layer2_graph(table_id, node_id):
current_app.table_id = table_id
user_id = str(g.auth_user.get("id", current_app.user_id))

if "bounds" in request.args:
bounds = request.args["bounds"]
bounding_box = np.array([b.split("-") for b in bounds.split("_")], dtype=int).T
else:
bounding_box = None
bounding_box = _get_bounds_from_request(request)

cg = app_utils.get_cg(table_id)
print("Finding edge graph...")
Expand Down Expand Up @@ -1095,7 +1082,6 @@ def handle_root_timestamps(table_id, is_binary):


def operation_details(table_id):
from pychunkedgraph.graph import attributes
from pychunkedgraph.export.operation_logs import parse_attr

current_app.table_id = table_id
Expand Down

0 comments on commit 4c424a6

Please sign in to comment.