diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index aae0b5d..072fd6a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,12 +18,12 @@ repos: - id: pretty-format-json args: ['--autofix', '--no-sort-keys'] - repo: https://github.com/ambv/black - rev: 24.4.2 + rev: 24.8.0 hooks: - id: black language_version: python3.11 - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.10.0' + rev: 'v1.11.1' hooks: - id: mypy name: mypy @@ -36,12 +36,12 @@ repos: exclude: tests/ args: [--select, "D101,D102,D103,D105,D106"] - repo: https://github.com/PyCQA/bandit - rev: '1.7.8' + rev: '1.7.9' hooks: - id: bandit args: [--skip, "B101,B303,B110,B311"] - repo: https://github.com/PyCQA/flake8 - rev: '7.0.0' + rev: '7.1.1' hooks: - id: flake8 - repo: https://github.com/myint/autoflake diff --git a/apps/graph/main.py b/apps/graph/main.py new file mode 100644 index 0000000..18c414f --- /dev/null +++ b/apps/graph/main.py @@ -0,0 +1,26 @@ +import os + +from dotenv import load_dotenv +from fastapi import FastAPI, Request +from fastapi.responses import HTMLResponse +from fastapi.staticfiles import StaticFiles +from fastapi.templating import Jinja2Templates + +load_dotenv() + +app = FastAPI() + +templates = Jinja2Templates(directory="templates") +app.mount("/static", StaticFiles(directory="static"), name="static") + +EE_API_URL = os.getenv("EE_API_URL") + + +@app.get("/", response_class=HTMLResponse) +async def get_index(request: Request) -> HTMLResponse: + """ + Render the index page. 
+ """ + return templates.TemplateResponse( + "index.html", {"request": request, "ee_api_url": EE_API_URL} + ) diff --git a/apps/graph/requirements.txt b/apps/graph/requirements.txt new file mode 100644 index 0000000..f6a8f08 --- /dev/null +++ b/apps/graph/requirements.txt @@ -0,0 +1,3 @@ +fastapi[all]==0.105.0 +uvicorn==0.20.0 +python-dotenv==0.21.1 diff --git a/apps/graph/static/script.js b/apps/graph/static/script.js new file mode 100644 index 0000000..d9577e4 --- /dev/null +++ b/apps/graph/static/script.js @@ -0,0 +1,200 @@ +// script.js + +async function loadRecommendations() { + const response = await fetch(`${eeApiUrl}/recommendation/list`); + const recommendations = await response.json(); + const recommendationList = document.getElementById('recommendation-list'); + recommendationList.innerHTML = ''; + recommendations.forEach(rec => { + const div = document.createElement('div'); + div.className = 'recommendation-item'; + div.innerHTML = ` +
${rec.recommendation_id}: ${rec.recommendation_name}
+
Version: ${rec.recommendation_version}
+
Package Version: ${rec.recommendation_package_version}
+ `; + div.onclick = () => loadGraph(rec.recommendation_id); + recommendationList.appendChild(div); + }); +} + +async function loadGraph(recommendationId) { + const response = await fetch(`${eeApiUrl}/recommendation/${recommendationId}/execution_graph`); + const data = await response.json(); + const graphData = data.recommendation_execution_graph; + + // Extract unique node types + const nodeTypes = [...new Set(graphData.nodes.map(node => node.data.type))]; + const nodeCategories = [...new Set(graphData.nodes.map(node => node.data.category))]; + + // Generate colors for each category + const nodeColors = { + "BASE": "#ff0000", + "POPULATION": "#00ff00", + "INTERVENTION": "#9999ff", + "POPULATION_INTERVENTION": "#ff00ff", + }; + const nodeShapes = { + "Symbol": "round-rectangle", + "&": "rhomboid", + "|": "diamond", + "Not": "triangle", + "NoDataPreservingAnd": "rhomboid", + "NoDataPreservingOr": "diamond", + "NonSimplifiableAnd": "rhomboid", + "NonSimplifiableOr": "diamond", + "LeftDependentToggle": "octagon", + } + + // Initialize Cytoscape + var cy = cytoscape({ + container: document.getElementById('cy'), + elements: [...graphData.nodes, ...graphData.edges], + style: [ + { + selector: 'node', + style: { + 'label': function(ele) { + if (ele.data('type') === 'Symbol') { + if (ele.data('category') == 'BASE') { + return ele.data('class') + } + var label; + label = ele.data('concept')["concept_name"]; + var value = ele.data('value'); + var dosage = ele.data('dose'); + var timing = ele.data('timing'); + var route = ele.data('route'); + + if (value) { + label += " " + value; + } + if (dosage) { + label += "\n" + dosage; + } + if (timing) { + label += "\n" + timing; + } + if (route) { + label += "\n[" + route + "]"; + } + return label; + + } + if (ele.data("is_sink")) { + return ele.data('category') + " [SINK]" + } + return ele.data('class') + }, + 'background-color': function(ele) { + return nodeColors[ele.data('category')] || '#666'; // Assign color based on 'category', with 
a default + }, + 'shape': function(ele) { + return nodeShapes[ele.data('type')] || 'star'; // Assign shape based on 'type', with a default + }, + 'text-valign': 'center', + 'color': '#000000', + 'width': function(ele) { + return ele.data('type') === 'Symbol' ? '120px': '40px'; + }, + 'height': function(ele) { + return ele.data('type') === 'Symbol' ? '80px': '40px'; + }, + 'font-size': '10px', + 'text-wrap': 'wrap', + 'text-max-width': '120px' // Adjust width as needed + } + }, + { + selector: 'edge', + style: { + 'width': 2, + 'target-arrow-shape': 'triangle', // Set arrow shape to triangle + 'curve-style': 'bezier' // Makes the edge curved for better visibility of direction + } + } + ], + layout: { + name: 'klay', // Use 'klay' layout for better visualization + nodeDimensionsIncludeLabels: true, + fit: true, + padding: 20, + animate: true, + animationDuration: 500, + klay: { + spacing: 20, + direction: 'DOWN', + } + } + }); + + // Add event listener for node click + cy.on('tap', 'node', function(evt) { + hideTippys(cy); + const node = evt.target; + if (!node.tippy) { + node.tippy = createTippy(node); + } + node.tippy.show(); + }); + + // Hide popper when clicking on the canvas + cy.on('tap', function(evt) { + if (evt.target === cy) { + hideTippys(cy); + } + }); +} + +function createTippy(node) { + let content = ''; + + function formatData(data, prefix = '') { + for (let key in data) { + if (data.hasOwnProperty(key)) { + if (Array.isArray(data[key])) { + content += `${prefix}${key}:
`; + data[key].forEach((item, index) => { + content += `${prefix}${key}[${index}]:
`; + formatData(item, prefix + '   '); + }); + } else if (typeof data[key] === 'object' && data[key] !== null) { + content += `${prefix}${key}:
`; + formatData(data[key], prefix + '   '); + } else { + content += `${prefix}${key}: ${data[key]}
`; + } + } + } + } + + formatData(node.data()); + + let ref = node.popperRef(); // used only for positioning + let dummyDomEle = document.createElement('div'); + document.body.appendChild(dummyDomEle); // Ensure dummyDomEle has a parent + + return tippy(dummyDomEle, { + content: () => { + let div = document.createElement('div'); + div.innerHTML = content; + return div; + }, + placement: 'top', + hideOnClick: true, + interactive: true, + trigger: 'manual', + allowHTML: true, + getReferenceClientRect: () => ref.getBoundingClientRect() + }); +} + +function hideTippys(cy) { + cy.elements().forEach(ele => { + if (ele.tippy) { + ele.tippy.hide(); + } + }); +} + +loadRecommendations(); diff --git a/apps/graph/templates/index.html b/apps/graph/templates/index.html new file mode 100644 index 0000000..8b8d792 --- /dev/null +++ b/apps/graph/templates/index.html @@ -0,0 +1,61 @@ + + + + + CELIDA Execution Graphs + + + + + + + + + + + + + +
+
+
+ + + + diff --git a/apps/viz-backend/app/database.py b/apps/viz-backend/app/database.py index 4ef123f..036c9c8 100644 --- a/apps/viz-backend/app/database.py +++ b/apps/viz-backend/app/database.py @@ -1,11 +1,17 @@ -from settings import get_config +from urllib.parse import quote + +from settings import config from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker +connection_dict = config.omop.model_dump() +connection_dict["user"] = quote(connection_dict["user"]) +connection_dict["password"] = quote(connection_dict["password"]) + connection_string = ( "postgresql+psycopg://{user}:{password}@{host}:{port}/{database}".format( - **get_config().omop.model_dump() + **connection_dict ) ) @@ -13,7 +19,7 @@ connection_string, pool_pre_ping=True, connect_args={ - "options": "-csearch_path={}".format(get_config().omop.db_schema), + "options": "-csearch_path={}".format(config.omop.data_schema), }, ) diff --git a/apps/viz-backend/app/main.py b/apps/viz-backend/app/main.py index 901035a..12b3bb9 100644 --- a/apps/viz-backend/app/main.py +++ b/apps/viz-backend/app/main.py @@ -1,9 +1,12 @@ +import json +import re from typing import List from database import SessionLocal -from fastapi import Depends, FastAPI +from fastapi import Depends, FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware -from models import Interval, RecommendationRun +from models import Interval, Recommendation, RecommendationRun +from settings import config from sqlalchemy import text from sqlalchemy.orm import Session @@ -19,6 +22,20 @@ ) +# Ensure schema name is a valid identifier +def is_valid_identifier(identifier: str) -> bool: + """ + Check if a string is a valid identifier. 
+ """ + return re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", identifier) is not None + + +result_schema = config.omop.result_schema + +if not is_valid_identifier(result_schema): + raise ValueError("Invalid schema name") + + def get_db() -> Session: """ Get a database session. @@ -30,6 +47,49 @@ def get_db() -> Session: db.close() +@app.get("/recommendation/list", response_model=List[Recommendation]) +def get_recommendations(db: Session = Depends(get_db)) -> List[Recommendation]: + """ + Get all recommendations. + """ + + result = db.execute( + text( + f""" + SELECT recommendation_id, recommendation_name, recommendation_title, recommendation_url, + recommendation_version, recommendation_package_version, create_datetime + FROM {result_schema}.recommendation + """ # nosec: result_schema is checked above (is_valid_identifier) + ) + ) + return result.fetchall() + + +@app.get("/recommendation/{recommendation_id}/execution_graph") +def get_execution_graph(recommendation_id: int, db: Session = Depends(get_db)) -> dict: + """ + Get the execution graph for a specific recommendation by ID. 
+ """ + result = db.execute( + text( + f""" + SELECT recommendation_execution_graph + FROM {result_schema}.recommendation + WHERE recommendation_id = :recommendation_id + """ # nosec: result_schema is checked above (is_valid_identifier) + ), + {"recommendation_id": recommendation_id}, + ).fetchone() + + if not result: + raise HTTPException(status_code=404, detail="Recommendation not found") + + # Decode the bytes to a string and parse it as JSON + execution_graph = json.loads(result.recommendation_execution_graph.decode("utf-8")) + + return {"recommendation_execution_graph": execution_graph} + + @app.get("/execution_runs", response_model=List[RecommendationRun]) def get_execution_runs(db: Session = Depends(get_db)) -> dict: """ diff --git a/apps/viz-backend/app/models.py b/apps/viz-backend/app/models.py index 216bcb3..319e512 100644 --- a/apps/viz-backend/app/models.py +++ b/apps/viz-backend/app/models.py @@ -3,6 +3,20 @@ from pydantic import BaseModel +class Recommendation(BaseModel): + """ + Represents a single recommendation. + """ + + recommendation_id: int + recommendation_name: str + recommendation_title: str + recommendation_url: str + recommendation_version: str | None = None + recommendation_package_version: str | None = None + create_datetime: datetime + + class RecommendationRun(BaseModel): """ Represents a single recommendation run. 
diff --git a/apps/viz-backend/app/settings.py b/apps/viz-backend/app/settings.py index e1ee1fd..d8e3a00 100644 --- a/apps/viz-backend/app/settings.py +++ b/apps/viz-backend/app/settings.py @@ -12,7 +12,8 @@ class OMOPSettings(BaseModel): user: str password: str database: str - db_schema: str = Field(alias="schema") + data_schema: str = Field(alias="data_schema", default="cds_cdm") + result_schema: str = Field(alias="result_schema", default="celida") model_config = ConfigDict(populate_by_name=True) diff --git a/execution_engine/execution_engine.py b/execution_engine/execution_engine.py index 55e357a..6f505ae 100644 --- a/execution_engine/execution_engine.py +++ b/execution_engine/execution_engine.py @@ -1,10 +1,11 @@ import hashlib +import json import logging from datetime import datetime import pandas as pd import sqlalchemy -from sqlalchemy import and_, insert, select +from sqlalchemy import and_, insert, select, update from execution_engine import __version__ from execution_engine.builder import ExecutionEngineBuilder @@ -229,15 +230,27 @@ def register_recommendation(self, recommendation: cohort.Recommendation) -> None result = con.execute(query) recommendation.id = result.fetchone().recommendation_id - con.commit() + for pi_pair in recommendation.population_intervention_pairs(): + self.register_population_intervention_pair( + pi_pair, recommendation_id=recommendation.id + ) - for pi_pair in recommendation.population_intervention_pairs(): - self.register_population_intervention_pair( - pi_pair, recommendation_id=recommendation.id - ) + for criterion in pi_pair.flatten(): + self.register_criterion(criterion) + + with self._db.begin() as con: + # update recommendation with execution graph (now that the criterion & pi pair IDs are known) + rec_graph: bytes = json.dumps( + recommendation.execution_graph().to_cytoscape_dict(), sort_keys=True + ).encode() + + update_query = ( + update(recommendation_table) + .where(recommendation_table.recommendation_id == 
recommendation.id) + .values(recommendation_execution_graph=rec_graph) + ) - for criterion in pi_pair.flatten(): - self.register_criterion(criterion) + con.execute(update_query) def register_population_intervention_pair( self, pi_pair: PopulationInterventionPair, recommendation_id: int @@ -272,8 +285,6 @@ def register_population_intervention_pair( result = con.execute(query) pi_pair.id = result.fetchone().pi_pair_id - con.commit() - def register_criterion(self, criterion: Criterion) -> None: """ Registers the Criterion in the result database. diff --git a/execution_engine/execution_graph/graph.py b/execution_engine/execution_graph/graph.py index 1ed1c9f..7c3d5f7 100644 --- a/execution_engine/execution_graph/graph.py +++ b/execution_engine/execution_graph/graph.py @@ -51,6 +51,12 @@ def is_sink_of_category( return True + def is_sink(self, expr: logic.Expr) -> bool: + """ + Check if a node is a sink node of the graph. + """ + return self.out_degree(expr) == 0 + @classmethod def from_expression( cls, expr: logic.Expr, base_criterion: Criterion @@ -149,6 +155,88 @@ def sink_node(self, category: CohortCategory | None = None) -> logic.Expr: return sink_nodes[0] + def to_cytoscape_dict(self) -> dict: + """ + Convert the graph to a dictionary that can be used by Cytoscape.js. 
+ """ + nodes = [] + edges = [] + + for node in self.nodes(): + # Ensure all node attributes are serializable + + node_data = { + "data": { + "id": id(node), + "label": str(node), + "class": ( + node.criterion.__class__.__name__ + if isinstance(node, logic.Symbol) + else node.__class__.__name__ + ), + "type": ( + node._repr_join_str + if hasattr(node, "_repr_join_str") + and node._repr_join_str is not None + else node.__class__.__name__ + ), + "category": self.nodes[node][ + "category" + ].value, # Assuming 'value' is serializable + "store_result": str( + self.nodes[node]["store_result"] + ), # Convert to string if necessary + "is_sink": self.is_sink(node), + "bind_params": self.nodes[node]["bind_params"], + } + } + + if isinstance(node, logic.Symbol): + + node_data["data"]["criterion_id"] = node.criterion._id + + def criterion_attr(attr: str) -> str | None: + if ( + hasattr(node.criterion, attr) + and getattr(node.criterion, attr) is not None + ): + return str(getattr(node.criterion, attr)) + return None + + if node.criterion.concept is not None: + node_data["data"].update( + { + "concept": ( + node.criterion.concept.model_dump() + if node.criterion.concept is not None + else None + ), + "value": criterion_attr("value"), + "timing": criterion_attr("timing"), + "dose": criterion_attr("dose"), + "route": criterion_attr("route"), + } + ) + + if self.nodes[node]["category"] == CohortCategory.BASE: + node_data["data"]["base_criterion"] = str( + node.criterion + ) # Ensure this is serializable + + nodes.append(node_data) + + for edge in self.edges(): + edges.append( + { + "data": { + "source": id(edge[0]), + "target": id(edge[1]), + } + } + ) + + return {"nodes": nodes, "edges": edges} + def plot(self) -> None: """ Plot the graph. 
diff --git a/execution_engine/omop/cohort/population_intervention_pair.py b/execution_engine/omop/cohort/population_intervention_pair.py index 626dad8..06b6454 100644 --- a/execution_engine/omop/cohort/population_intervention_pair.py +++ b/execution_engine/omop/cohort/population_intervention_pair.py @@ -110,7 +110,9 @@ def execution_graph(self) -> ExecutionGraph: ) pi_graph = ExecutionGraph.from_expression(pi, self._base_criterion) - assert self._id is not None, "Population/intervention pair id not set" + if self._id is None: + raise ValueError("Population/intervention pair ID not set") + # todo: should we supply self instead of self._id? pi_graph.set_sink_nodes_store(bind_params=dict(pi_pair_id=self._id)) diff --git a/execution_engine/omop/concepts.py b/execution_engine/omop/concepts.py index 0a03060..cd33fbc 100644 --- a/execution_engine/omop/concepts.py +++ b/execution_engine/omop/concepts.py @@ -34,6 +34,9 @@ def __str__(self) -> str: """ Returns a string representation of the concept. 
""" + if self.vocabulary_id == "UCUM": + return str(self.concept_code) + return str(self.concept_name) def is_custom(self) -> bool: diff --git a/execution_engine/omop/criterion/concept.py b/execution_engine/omop/criterion/concept.py index 0fd8cda..9eb51e3 100644 --- a/execution_engine/omop/criterion/concept.py +++ b/execution_engine/omop/criterion/concept.py @@ -35,6 +35,10 @@ class ConceptCriterion(Criterion, ABC): """ + _concept: Concept + _value = None + _timing = None + def __init__( self, category: CohortCategory, @@ -71,6 +75,16 @@ def concept(self) -> Concept: """Get the concept associated with this Criterion""" return self._concept + @property + def value(self) -> Value | None: + """Get the value associated with this Criterion""" + return self._value + + @property + def timing(self) -> Timing | None: + """Get the timing associated with this Criterion""" + return self._timing + def _sql_filter_concept( self, query: Select, override_concept_id: int | None = None ) -> Select: diff --git a/execution_engine/omop/criterion/drug_exposure.py b/execution_engine/omop/criterion/drug_exposure.py index f5d082e..3cdc2be 100644 --- a/execution_engine/omop/criterion/drug_exposure.py +++ b/execution_engine/omop/criterion/drug_exposure.py @@ -69,6 +69,16 @@ def concept(self) -> Concept: """Get the concept of the ingredient associated with this DrugExposure""" return self._ingredient_concept + @property + def dose(self) -> Dosage: + """Get the dose associated with this DrugExposure""" + return self._dose + + @property + def route(self) -> Concept | None: + """Get the route associated with this DrugExposure""" + return self._route + def is_weight_related(self) -> bool: """ Check if the criterion is weight related. 
diff --git a/execution_engine/omop/criterion/procedure_occurrence.py b/execution_engine/omop/criterion/procedure_occurrence.py index fd45410..1d80ae7 100644 --- a/execution_engine/omop/criterion/procedure_occurrence.py +++ b/execution_engine/omop/criterion/procedure_occurrence.py @@ -145,6 +145,8 @@ def description(self) -> str: Get a human-readable description of the criterion. """ + assert self._concept is not None, "Concept must be set" + parts = [f"concept={self._concept.concept_name}"] if self._timing is not None: parts.append(f"dose={str(self._timing)}") @@ -155,6 +157,8 @@ def dict(self) -> dict[str, Any]: """ Return a dictionary representation of the criterion. """ + assert self._concept is not None, "Concept must be set" + return { "exclude": self._exclude, "category": self._category.value, diff --git a/execution_engine/omop/criterion/visit_occurrence.py b/execution_engine/omop/criterion/visit_occurrence.py index c0ce249..c9525b2 100644 --- a/execution_engine/omop/criterion/visit_occurrence.py +++ b/execution_engine/omop/criterion/visit_occurrence.py @@ -4,6 +4,7 @@ from sqlalchemy.sql import Select from execution_engine.constants import CohortCategory, OMOPConcepts +from execution_engine.omop.concepts import Concept from execution_engine.omop.criterion.abstract import column_interval_type from execution_engine.util.interval import IntervalType @@ -26,6 +27,15 @@ def __init__(self) -> None: self._category = CohortCategory.BASE self._set_omop_variables_from_domain("visit") + self._concept = Concept( + concept_id=OMOPConcepts.VISIT_TYPE_STILL_PATIENT.value, + concept_name="Still patient", + concept_code="30", + domain_id="Visit", + vocabulary_id="UB04 Pt dis status", + concept_class_id="UB04 Pt dis status", + ) + def _sql_header( self, distinct_person: bool = True, person_id: int | None = None ) -> Select: diff --git a/execution_engine/omop/db/celida/tables.py b/execution_engine/omop/db/celida/tables.py index 6dd112a..092bfc8 100644 --- 
a/execution_engine/omop/db/celida/tables.py +++ b/execution_engine/omop/db/celida/tables.py @@ -50,6 +50,7 @@ class Recommendation(Base): # noqa: D101 String(64), index=True, unique=True ) recommendation_json = mapped_column(LargeBinary) + recommendation_execution_graph = mapped_column(LargeBinary) create_datetime: Mapped[datetime] diff --git a/execution_engine/omop/sqlclient.py b/execution_engine/omop/sqlclient.py index 82c90fc..f099e06 100644 --- a/execution_engine/omop/sqlclient.py +++ b/execution_engine/omop/sqlclient.py @@ -261,7 +261,7 @@ def log_query(self, query: Select | Insert, params: dict | None = None) -> None: """ Log the given query against the OMOP CDM database. """ - self._query_logger.info(self.compile_query(query, params)) + self._query_logger.info(self.compile_query(query, params) + "\n") def get_concept_info(self, concept_id: int) -> Concept: """Get the concept info for the given concept ID."""