Skip to content

Commit

Permalink
Merge pull request #72 from seung-lab/validate_decorator
Browse files Browse the repository at this point in the history
Validate datastack decorator
  • Loading branch information
fcollman authored Aug 8, 2022
2 parents 827a2cc + 53fc12e commit 9f80c85
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 74 deletions.
74 changes: 22 additions & 52 deletions materializationengine/blueprints/client/api.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,25 @@
import logging
import os
import time

import pyarrow as pa
from cachetools import LRUCache, TTLCache, cached
from cloudfiles import compression
from dynamicannotationdb.models import AnalysisTable, AnalysisVersion
from emannotationschemas import get_schema
from emannotationschemas.models import (
Base,
create_table_dict,
make_annotation_model,
make_flat_model,
make_segmentation_model,
sqlalchemy_models,
)
from flask import Response, abort, current_app, request, stream_with_context
from flask_accepts import accepts, responds
from flask import Response, abort, current_app, request
from flask_accepts import accepts
from flask_restx import Namespace, Resource, inputs, reqparse
from materializationengine.blueprints.client.query import _execute_query, specific_query
from materializationengine.blueprints.client.query import specific_query
from materializationengine.blueprints.client.schemas import (
ComplexQuerySchema,
CreateTableSchema,
GetDeleteAnnotationSchema,
Metadata,
PostPutAnnotationSchema,
SegmentationDataSchema,
SegmentationTableSchema,
SimpleQuerySchema,
)
from materializationengine.blueprints.reset_auth import reset_auth
from materializationengine.database import (
create_session,
dynamic_annotation_cache,
sqlalchemy_cache,
)
from materializationengine.info_client import (
get_aligned_volumes,
get_datastack_info,
get_datastacks,
)
from materializationengine.database import dynamic_annotation_cache, sqlalchemy_cache
from materializationengine.info_client import get_aligned_volumes, get_datastack_info, get_relevant_datastack_info
from materializationengine.schemas import AnalysisTableSchema, AnalysisVersionSchema
from middle_auth_client import (
auth_required,
auth_requires_admin,
auth_requires_permission,
)
from sqlalchemy.engine.url import make_url

__version__ = "4.0.22"
from middle_auth_client import auth_requires_permission
from materializationengine.blueprints.client.datastack import validate_datastack
__version__ = "4.0.20"


authorizations = {
Expand Down Expand Up @@ -128,15 +99,6 @@ def check_aligned_volume(aligned_volume):
abort(400, f"aligned volume: {aligned_volume} not valid")


@cached(cache=TTLCache(maxsize=64, ttl=600))
def get_relevant_datastack_info(datastack_name):
    """Return ``(aligned_volume_name, pcg_table_name)`` for a datastack.

    The datastack record is fetched with ``get_datastack_info``. The PCG
    table name is the last ``/``-separated piece of its
    ``segmentation_source`` entry, and the aligned volume name is read from
    ``["aligned_volume"]["name"]``. Results are memoised by the ``cached``
    decorator in a ``TTLCache(maxsize=64, ttl=600)``.
    """
    datastack_record = get_datastack_info(datastack_name=datastack_name)
    # Only the trailing path component of the segmentation source is kept.
    pcg_table = datastack_record["segmentation_source"].rsplit("/", 1)[-1]
    volume_name = datastack_record["aligned_volume"]["name"]
    return volume_name, pcg_table


@cached(cache=LRUCache(maxsize=64))
def get_analysis_version_and_table(
datastack_name: str, table_name: str, version: int, Session
Expand Down Expand Up @@ -241,6 +203,7 @@ def get_flat_model(datastack_name: str, table_name: str, version: int, Session):
class DatastackVersions(Resource):
@reset_auth
@auth_requires_permission("view", table_arg="datastack_name")
@validate_datastack
@client_bp.doc("datastack_versions", security="apikey")
def get(self, datastack_name: str):
"""get available versions
Expand Down Expand Up @@ -271,6 +234,7 @@ def get(self, datastack_name: str):
class DatastackVersion(Resource):
@reset_auth
@auth_requires_permission("view", table_arg="datastack_name")
@validate_datastack
@client_bp.doc("version metadata", security="apikey")
def get(self, datastack_name: str, version: int):
"""get version metadata
Expand Down Expand Up @@ -303,6 +267,7 @@ def get(self, datastack_name: str, version: int):
class DatastackMetadata(Resource):
@reset_auth
@auth_requires_permission("view", table_arg="datastack_name")
@validate_datastack
@client_bp.doc("all valid version metadata", security="apikey")
def get(self, datastack_name: str):
"""get materialized metadata for all valid versions
Expand Down Expand Up @@ -331,6 +296,7 @@ def get(self, datastack_name: str):
class FrozenTableVersions(Resource):
@reset_auth
@auth_requires_permission("view", table_arg="datastack_name")
@validate_datastack
@client_bp.doc("get_frozen_tables", security="apikey")
def get(self, datastack_name: str, version: int):
"""get frozen tables
Expand Down Expand Up @@ -373,6 +339,7 @@ def get(self, datastack_name: str, version: int):
class FrozenTableMetadata(Resource):
@reset_auth
@auth_requires_permission("view", table_arg="datastack_name")
@validate_datastack
@client_bp.doc("get_frozen_table_metadata", security="apikey")
def get(self, datastack_name: str, version: int, table_name: str):
"""get frozen table metadata
Expand Down Expand Up @@ -410,6 +377,7 @@ def get(self, datastack_name: str, version: int, table_name: str):
class FrozenTableCount(Resource):
@reset_auth
@auth_requires_permission("view", table_arg="datastack_name")
@validate_datastack
@client_bp.doc("simple_query", security="apikey")
def get(self, datastack_name: str, version: int, table_name: str):
"""get annotation count in table
Expand Down Expand Up @@ -438,6 +406,7 @@ def get(self, datastack_name: str, version: int, table_name: str):
class LiveTableQuery(Resource):
@reset_auth
@auth_requires_permission("admin_view", table_arg="datastack_name")
@validate_datastack
@client_bp.doc("live_simple_query", security="apikey")
@accepts("SimpleQuerySchema", schema=SimpleQuerySchema, api=client_bp)
def post(self, datastack_name: str, table_name: str):
Expand Down Expand Up @@ -583,6 +552,7 @@ def _format_filter(filter, table_in, seg_table):
class FrozenTableQuery(Resource):
@reset_auth
@auth_requires_permission("view", table_arg="datastack_name")
@validate_datastack
@client_bp.doc("simple_query", security="apikey")
@accepts("SimpleQuerySchema", schema=SimpleQuerySchema, api=client_bp)
def post(self, datastack_name: str, version: int, table_name: str):
Expand Down Expand Up @@ -637,13 +607,11 @@ def post(self, datastack_name: str, version: int, table_name: str):
time_d["get Model"] = time.time() - now
now = time.time()

Session = sqlalchemy_cache.get("{}__mat{}".format(datastack_name, version))
Session = sqlalchemy_cache.get(f"{datastack_name}__mat{version}")
time_d["get Session"] = time.time() - now
now = time.time()

engine = sqlalchemy_cache.get_engine(
"{}__mat{}".format(datastack_name, version)
)
engine = sqlalchemy_cache.get_engine(f"{datastack_name}__mat{version}")
time_d["get engine"] = time.time() - now
now = time.time()
max_limit = current_app.config.get("QUERY_LIMIT_SIZE", 200000)
Expand Down Expand Up @@ -709,6 +677,7 @@ def post(self, datastack_name: str, version: int, table_name: str):
class FrozenQuery(Resource):
@reset_auth
@auth_requires_permission("view", table_arg="datastack_name")
@validate_datastack
@client_bp.doc("complex_query", security="apikey")
@accepts("ComplexQuerySchema", schema=ComplexQuerySchema, api=client_bp)
def post(self, datastack_name: str, version: int):
Expand Down Expand Up @@ -766,7 +735,7 @@ def post(self, datastack_name: str, version: int):
Model = get_flat_model(datastack_name, table_name, version, Session)
model_dict[table_name] = Model

db_name = "{}__mat{}".format(datastack_name, version)
db_name = f"{datastack_name}__mat{version}"
Session = sqlalchemy_cache.get(db_name)
engine = sqlalchemy_cache.get_engine(db_name)
max_limit = current_app.config.get("QUERY_LIMIT_SIZE", 200000)
Expand All @@ -775,7 +744,8 @@ def post(self, datastack_name: str, version: int):
limit = data.get("limit", max_limit)
if limit > max_limit:
limit = max_limit
logging.debug("query {}".format(data))
logging.debug(f"query {data}")

df = specific_query(
Session,
engine,
Expand Down
107 changes: 107 additions & 0 deletions materializationengine/blueprints/client/datastack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import inspect
import logging
from functools import wraps

from dynamicannotationdb.models import (
AnalysisTable,
AnalysisVersion,
)
from materializationengine.database import sqlalchemy_cache

from materializationengine.info_client import get_relevant_datastack_info


def parse_args(f, *args, **kwargs):
    """Map a call's arguments onto the parameter names of ``f``.

    The positional and keyword arguments are bound against ``f``'s
    signature, any parameters left unspecified are filled in with their
    declared defaults, and the result is returned as a plain ``dict`` keyed
    by parameter name.

    Raises:
        TypeError: propagated from ``Signature.bind`` when the arguments do
            not fit ``f``'s signature.
    """
    bound_call = inspect.signature(f).bind(*args, **kwargs)
    bound_call.apply_defaults()
    return {name: value for name, value in bound_call.arguments.items()}


def validate_datastack(f):
"""Decorator that validates and remaps the datastack arguments of ``f``.

The wrapper binds the incoming call to ``f``'s signature and reads the
``table_name``, ``datastack_name`` and ``version`` arguments (each may be
absent). It then looks up the ``AnalysisVersion`` rows for the datastack
and, for every row that has a ``parent_version``, the matching valid
parent row.

* If the version query fails or returns nothing, or no parent lookup
  returns a row, ``f`` is called with the original arguments.
* Otherwise ``datastack_name`` is replaced by ``str()`` of the first
  parent row before ``f`` is called.
* When both a table and a version were supplied and the version is among
  the parents' version numbers, the table is checked against the valid
  ``AnalysisTable`` rows of that version.

Returns:
    Whatever ``f`` returns, or the tuple ``(None, 404)`` when no
    ``AnalysisVersion`` row matches the target version and datastack.

Raises:
    ValueError: if ``table_name`` is not among the valid tables found
        for the target version.

NOTE(review): indentation is not preserved in this listing, so the exact
nesting of the final remap step cannot be confirmed from here; verify it
against the repository before changing control flow.
"""
@wraps(f)
def wrapper(*args, **kwargs):
# get args and kwargs as a dict
arguments = parse_args(f, *args, **kwargs)

# Each of these is None when the wrapped callable has no such argument.
target_table = arguments.get("table_name")
target_datastack = arguments.get("datastack_name")
target_version = arguments.get("version")
# pcg_table_name is unpacked but not used again in this function.
aligned_volume_name, pcg_table_name = get_relevant_datastack_info(
target_datastack
)
session = sqlalchemy_cache.get(aligned_volume_name)
version_query = session.query(AnalysisVersion).filter(
AnalysisVersion.datastack == target_datastack
)
# Narrow to a single version only when one was supplied (and is truthy).
if target_version:
version_query = version_query.filter(
AnalysisVersion.version == target_version
)
# Best effort: a failing query is logged and rolled back, then treated the
# same as "no versions found" so the wrapped callable still runs.
try:
versions = version_query.all()
except Exception as e:
logging.error(e)
session.rollback()
versions = None

# nothing to do here
if not versions:
return f(*args, **kwargs)

# One list of rows per version that declares a parent; a list is empty when
# the parent row is missing or not flagged valid.
parent_version_info = []
for version in versions:

if version.parent_version:
parent_version_info.append(
session.query(AnalysisVersion)
.filter(AnalysisVersion.id == version.parent_version)
.filter(AnalysisVersion.valid == True)
.all()
)

# no need to map to a parent version
if not any(parent_version_info):
return f(*args, **kwargs)
else:
# NOTE(review): any() only guarantees that some lookup returned rows; if the
# first lookup is empty this indexing raises IndexError - confirm.
# NOTE(review): the remapped name is str() of an AnalysisVersion row, so it
# depends on how that model renders as a string - verify in dynamicannotationdb.
parent_version = str(parent_version_info[0][0])
# validate all parent versions
# valid_versions collects the version numbers of the parent rows.
valid_versions = []
for version_info in parent_version_info:
# NOTE(review): raises IndexError for any lookup that returned no rows.
info = version_info[0]
# The lookup above already filters on valid == True.
if info.valid:
valid_versions.append(info.version)

if target_table and target_version:
# confirm target version is valid
if target_version in valid_versions:
analysis_version = (
session.query(AnalysisVersion)
.filter(AnalysisVersion.version == target_version)
.filter(AnalysisVersion.datastack == target_datastack)
.first()
)
# Returned directly to the caller without invoking the wrapped callable.
if analysis_version is None:
return None, 404
response = (
session.query(AnalysisTable)
.filter(AnalysisTable.analysisversion_id == analysis_version.id)
.filter(AnalysisTable.valid == True)
.all()
)
# check if target table is valid
valid_tables = [r.table_name for r in response]
if target_table not in valid_tables:
raise ValueError(
f"{target_table} not valid for version {target_version}"
)
# remap datastack name to point to parent version
# A keyword datastack_name is rewritten in place and falls through to the
# final call below; otherwise the first positional argument is replaced.
if kwargs.get("datastack_name"):
kwargs["datastack_name"] = parent_version
else:
# NOTE(review): this overwrites args[0]; for a method whose first positional
# argument is self, that would replace self rather than datastack_name - confirm.
args_list = list(args)
args_list[0] = parent_version
new_args = tuple(args_list)
return f(*new_args, **kwargs)
return f(*args, **kwargs)

return wrapper
27 changes: 18 additions & 9 deletions materializationengine/info_client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from materializationengine.errors import (
AlignedVolumeNotFoundException,
DataStackNotFoundException,
)
from materializationengine.utils import get_config_param
from flask import current_app
import requests
import os
from caveclient.infoservice import InfoServiceClient
from caveclient.auth import AuthClient

import cachetools.func
import requests
from cachetools import LRUCache, TTLCache, cached
from caveclient.auth import AuthClient
from caveclient.infoservice import InfoServiceClient
from flask import current_app

from materializationengine.errors import (AlignedVolumeNotFoundException,
DataStackNotFoundException)
from materializationengine.utils import get_config_param


@cachetools.func.ttl_cache(maxsize=2, ttl=5 * 60)
Expand Down Expand Up @@ -70,3 +71,11 @@ def get_datastack_info(datastack_name):
raise DataStackNotFoundException(
f"datastack {datastack_name} info not returned"
)

@cached(cache=TTLCache(maxsize=64, ttl=600))
def get_relevant_datastack_info(datastack_name):
    """Look up the aligned volume and PCG table names for a datastack.

    Calls ``get_datastack_info`` for ``datastack_name`` and returns the tuple
    ``(aligned_volume_name, pcg_table_name)``. The PCG table name is whatever
    follows the final ``/`` in the record's ``segmentation_source``. The
    aligned volume name comes from ``["aligned_volume"]["name"]``. Calls are
    memoised through ``cached`` with a ``TTLCache(maxsize=64, ttl=600)``.
    """
    info = get_datastack_info(datastack_name=datastack_name)
    # Keep only the trailing path component of the segmentation source.
    *_, table_name = info["segmentation_source"].split("/")
    volume = info["aligned_volume"]["name"]
    return volume, table_name
14 changes: 5 additions & 9 deletions materializationengine/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
from materializationengine.blueprints.reset_auth import reset_auth
from materializationengine.celery_init import celery
from materializationengine.database import sqlalchemy_cache
from materializationengine.info_client import get_datastack_info, get_datastacks
from materializationengine.info_client import (
get_datastack_info,
get_datastacks,
get_relevant_datastack_info,
)
from materializationengine.schemas import AnalysisTableSchema, AnalysisVersionSchema

__version__ = "4.0.22"
Expand Down Expand Up @@ -92,14 +96,6 @@ def make_df_with_links_to_id(objects, schema, url, col, **urlkwargs):
return df


def get_relevant_datastack_info(datastack_name):
    """Return the aligned volume name and PCG table name of a datastack.

    The datastack record is obtained from ``get_datastack_info``. The result
    is the tuple ``(aligned_volume_name, pcg_table_name)``, where the PCG
    table name is the last ``/``-separated element of the record's
    ``segmentation_source`` and the aligned volume name is taken from
    ``["aligned_volume"]["name"]``.
    """
    record = get_datastack_info(datastack_name=datastack_name)
    path_parts = record["segmentation_source"].split("/")
    table = path_parts[-1]
    return record["aligned_volume"]["name"], table


@views_bp.route("/datastack/<datastack_name>")
@auth_requires_permission("view", table_arg="datastack_name")
def datastack_view(datastack_name):
Expand Down
2 changes: 1 addition & 1 deletion requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ gcsfs>=0.8.0
pyarrow==3.0.0
flask_cors
numpy>=1.20
dynamicannotationdb>=5.0.7
dynamicannotationdb>=5.1.0
emannotationschemas>=5.0.2
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ dill==0.3.4
# pathos
dracopy==1.0.1
# via cloud-volume
dynamicannotationdb==5.0.7
dynamicannotationdb==5.1.0
# via -r requirements.in
emannotationschemas==5.0.2
# via
Expand Down Expand Up @@ -514,9 +514,9 @@ zipp==3.7.0
# via
# importlib-metadata
# importlib-resources
zope-event==4.5.0
zope.event==4.5.0
# via gevent
zope-interface==5.4.0
zope.interface==5.4.0
# via gevent
zstandard==0.17.0
# via cloud-files
Expand Down

0 comments on commit 9f80c85

Please sign in to comment.