Skip to content

Commit

Permalink
Introduce collection nodes and fix local graphs API
Browse files Browse the repository at this point in the history
- Add collection nodes and edges to graph view

- Fix local graph by properly dealing with bi-directional relationships

- Add collection-level graphs

- Fix local graph tests when using collections
  • Loading branch information
ml-evs committed Jul 26, 2023
1 parent 4133d18 commit 012d150
Show file tree
Hide file tree
Showing 10 changed files with 207 additions and 48 deletions.
1 change: 0 additions & 1 deletion pydatalab/pydatalab/models/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def add_missing_collection_relationships(cls, values):
if len([d for d in values.get("relationships", []) if d.type == "collections"]) != len(
values.get("collections", [])
):
breakpoint()
raise RuntimeError("Relationships and collections mismatch")

return values
6 changes: 3 additions & 3 deletions pydatalab/pydatalab/routes/v0_1/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
collection = Blueprint("collections", __name__)


@collection.route("/collections/")
@collection.route("/collections")
def get_collections():

collections = flask_mongo.db.collections.aggregate(
Expand Down Expand Up @@ -87,7 +87,7 @@ def get_collection(collection_id):
)


@collection.route("/collections/", methods=["PUT"])
@collection.route("/collections", methods=["PUT"])
def create_collection():
request_json = request.get_json() # noqa: F821 pylint: disable=undefined-variable
data = request_json.get("data", {})
Expand Down Expand Up @@ -301,7 +301,7 @@ def delete_collection(collection_id: str):
)


@collection.route("/search-collections/", methods=["GET"])
@collection.route("/search-collections", methods=["GET"])
def search_collections():
query = request.args.get("query", type=str)
nresults = request.args.get("nresults", default=100, type=int)
Expand Down
121 changes: 95 additions & 26 deletions pydatalab/pydatalab/routes/v0_1/graphs.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,36 @@
from typing import Callable, Dict, Optional
from typing import Callable, Dict, Optional, Set

from flask import jsonify
from flask import jsonify, request

from pydatalab.mongo import flask_mongo
from pydatalab.routes.utils import get_default_permissions


def get_graph_cy_format(item_id: Optional[str] = None):
def get_graph_cy_format(item_id: Optional[str] = None, collection_id: Optional[str] = None):

collection_id = request.args.get("collection_id", type=str)

if item_id is None:
if collection_id is not None:
collection_immutable_id = flask_mongo.db.collections.find_one(
{"collection_id": collection_id}, projection={"_id": 1}
)
if not collection_immutable_id:
raise RuntimeError("No collection {collection_id=} found.")
collection_immutable_id = collection_immutable_id["_id"]
query = {
"$and": [
{"relationships.immutable_id": collection_immutable_id},
{"relationships.type": "collections"},
]
}
else:
query = {}
all_documents = flask_mongo.db.items.find(
get_default_permissions(user_only=False),
{**query, **get_default_permissions(user_only=False)},
projection={"item_id": 1, "name": 1, "type": 1, "relationships": 1},
)
node_ids = {document["item_id"] for document in all_documents}
node_ids: Set[str] = {document["item_id"] for document in all_documents}
all_documents.rewind()

else:
Expand All @@ -27,36 +44,73 @@ def get_graph_cy_format(item_id: Optional[str] = None):
)
)

node_ids = {document["item_id"] for document in all_documents}
node_ids = {document["item_id"] for document in all_documents} | {
relationship["item_id"]
for document in all_documents
for relationship in document.get("relationships", [])
}
if len(node_ids) > 1:
or_query = [{"item_id": id} for id in node_ids if id != item_id]
# query.extend([{"relationships.item_id": id} for id in node_ids if id != item_id])
next_shell = flask_mongo.db.items.find(
{
"$or": [
*[{"item_id": id} for id in node_ids if id != item_id],
*[{"relationships.item_id": id} for id in node_ids if id != item_id],
],
"$or": or_query,
**get_default_permissions(user_only=False),
},
projection={"item_id": 1, "name": 1, "type": 1, "relationships": 1},
)

node_ids = node_ids | {document["item_id"] for document in next_shell}
all_documents.extend(next_shell)
node_ids = node_ids | {document["item_id"] for document in all_documents}

nodes = []
edges = []

# Collect the elements that have already been added to the graph, to avoid duplication
drawn_elements = set()
node_collections = set()
for document in all_documents:

nodes.append(
{
"data": {
"id": document["item_id"],
"name": document["name"],
"type": document["type"],
"special": document["item_id"] == item_id,
}
}
)
for relationship in document.get("relationships", []):
# only considering child-parent relationships
if relationship.get("type") == "collections" and not collection_id:
collection_data = flask_mongo.db.collections.find_one(
{
"_id": relationship["immutable_id"],
**get_default_permissions(user_only=False),
},
projection={"collection_id": 1, "title": 1, "type": 1},
)
if collection_data:
if relationship["immutable_id"] not in node_collections:
_id = f'Collection: {collection_data["collection_id"]}'
if _id not in drawn_elements:
nodes.append(
{
"data": {
"id": _id,
"name": collection_data["title"],
"type": collection_data["type"],
"shape": "triangle",
}
}
)
node_collections.add(relationship["immutable_id"])
drawn_elements.add(_id)

source = f'Collection: {collection_data["collection_id"]}'
target = document.get("item_id")
edges.append(
{
"data": {
"id": f"{source}->{target}",
"source": source,
"target": target,
"value": 1,
}
}
)
continue

if not document.get("relationships"):
continue
Expand All @@ -70,13 +124,28 @@ def get_graph_cy_format(item_id: Optional[str] = None):
source = relationship["item_id"]
if source not in node_ids:
continue
edges.append(
edge_id = f"{source}->{target}"
if edge_id not in drawn_elements:
drawn_elements.add(edge_id)
edges.append(
{
"data": {
"id": edge_id,
"source": source,
"target": target,
"value": 1,
}
}
)

if document["item_id"] not in drawn_elements:
drawn_elements.add(document["item_id"])
nodes.append(
{
"data": {
"id": f"{source}->{target}",
"source": source,
"target": target,
"value": 1,
"id": document["item_id"],
"name": document["name"],
"type": document["type"],
}
}
)
Expand Down
34 changes: 30 additions & 4 deletions pydatalab/tests/routers/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,33 @@ def test_simple_graph(client):
assert len(graph["edges"]) == 3

graph = client.get("/item-graph/child_1").json
# These values are currently incorrect: really want to traverse the graph but need to
# resolve relationships first
assert len(graph["nodes"]) == 1
assert len(graph["edges"]) == 0
assert len(graph["nodes"]) == 2
assert len(graph["edges"]) == 1

graph = client.get("/item-graph/parent").json
assert len(graph["nodes"]) == 4
assert len(graph["edges"]) == 3

collection_id = "testcoll"
collection_json = {
"data": {
"collection_id": collection_id,
"title": "Test title",
"starting_members": [
{"item_id": "parent"},
{"item_id": "child_1"},
{"item_id": "child_2"},
],
}
}
response = client.put("/collections", json=collection_json)
assert response.status_code == 201
assert response.json["status"] == "success"

graph = client.get(f"/item-graph?collection_id={collection_id}").json
assert len(graph["nodes"]) == 3
assert len(graph["edges"]) == 2

graph = client.get("/item-graph/parent").json
assert len(graph["nodes"]) == 5
assert len(graph["edges"]) == 8
10 changes: 5 additions & 5 deletions pydatalab/tests/routers/test_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,19 +394,19 @@ def test_create_cell(client, default_cell):
def test_create_collections(client, default_collection):

# Check no collections initially
response = client.get("/collections/")
response = client.get("/collections")
assert len(response.json["data"]) == 0, response.json
assert response.status_code == 200

# Create an empty collection
response = client.put("/collections/", json={"data": json.loads(default_collection.json())})
response = client.put("/collections", json={"data": json.loads(default_collection.json())})
assert response.status_code == 201, response.json
assert response.json["status"] == "success"
assert response.json["data"]["collection_id"] == "test_collection"
assert response.json["data"]["title"] == "My Test Collection"
assert response.json["data"]["num_items"] == 0

response = client.get("/collections/")
response = client.get("/collections")
assert response.status_code == 200
assert len(response.json["data"]) == 1
assert response.json["data"][0]["collection_id"] == "test_collection"
Expand All @@ -426,7 +426,7 @@ def test_create_collections(client, default_collection):
]
}
)
response = client.put("/collections/", json={"data": data})
response = client.put("/collections", json={"data": data})
assert response.status_code == 201, response.json
assert response.json["status"] == "success"
assert response.json["data"]["collection_id"] == "test_collection_2"
Expand Down Expand Up @@ -461,7 +461,7 @@ def test_create_collections(client, default_collection):
assert len(response.json["item_data"]["collections"]) == 0

# remake it for the next test
response = client.put("/collections/", json={"data": data})
response = client.put("/collections", json={"data": data})
assert response.status_code == 201, response.json
assert response.json["status"] == "success"
assert response.json["data"]["collection_id"] == "test_collection_2"
Expand Down
6 changes: 3 additions & 3 deletions webapp/src/components/CollectionInformation.vue
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
v-model="CollectionDescription"
></TinyMceInline>

<!-- <CollectionRelationshipVisualization :collection_id="collection_id" /> -->
<CollectionRelationshipVisualization :collection_id="collection_id" />
<FancyCollectionSampleTable :collection_id="collection_id" />
</div>
</template>
Expand All @@ -30,7 +30,7 @@ import { createComputedSetterForCollectionField } from "@/field_utils.js";
import FancyCollectionSampleTable from "@/components/FancyCollectionSampleTable";
import TinyMceInline from "@/components/TinyMceInline";
import Creators from "@/components/Creators";
// import CollectionRelationshipVisualization from "@/components/CollectionRelationshipVisualization";
import CollectionRelationshipVisualization from "@/components/CollectionRelationshipVisualization";
export default {
props: {
Expand All @@ -47,7 +47,7 @@ export default {
TinyMceInline,
FancyCollectionSampleTable,
Creators,
// CollectionRelationshipVisualization,
CollectionRelationshipVisualization,
},
};
</script>
58 changes: 58 additions & 0 deletions webapp/src/components/CollectionRelationshipVisualization.vue
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
<template>
<label class="mr-2">Relationships</label>
<div>
<ItemGraph
:graphData="graphData"
style="height: 150px"
:defaultGraphStyle="elk - layered - down"
/>
</div>
</template>

<script>
// import FormattedItemName from "@/components/FormattedItemName"
import ItemGraph from "@/components/ItemGraph";
import { getItemGraph } from "@/server_fetch_utils.js";
export default {
computed: {
graphData() {
return this.$store.state.itemGraphData;
},
},
props: {
collection_id: String,
},
async mounted() {
await getItemGraph({ item_id: null, collection_id: this.collection_id });
},
components: {
ItemGraph,
},
};
</script>

<style scoped>
.nav-link {
cursor: pointer;
}
.contents-item {
cursor: pointer;
}
.contents-blocktype {
font-style: italic;
color: gray;
margin-right: 1rem;
}
.contents-blocktitle {
color: #004175;
}
#contents-ol {
margin-bottom: 0rem;
padding-left: 1rem;
}
</style>
6 changes: 5 additions & 1 deletion webapp/src/components/ItemGraph.vue
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,14 @@ const layoutOptions = {
export default {
props: {
graphData: Object,
defaultGraphStyle: {
type: String,
default: "elk-stress",
},
},
data() {
return {
graphStyle: "elk-stress",
graphStyle: this.defaultGraphStyle,
};
},
methods: {
Expand Down
2 changes: 1 addition & 1 deletion webapp/src/components/RelationshipVisualization.vue
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ export default {
item_id: String,
},
async mounted() {
await getItemGraph(this.item_id);
await getItemGraph({ item_id: this.item_id });
},
components: {
// FormattedItemName
Expand Down
Loading

0 comments on commit 012d150

Please sign in to comment.