Skip to content

Commit

Permalink
Merge pull request #170 from mesoscope/feature/firebase-upload-gradients
Browse files Browse the repository at this point in the history
Feature/firebase upload gradients
  • Loading branch information
rugeli authored Aug 15, 2023
2 parents 8ab1881 + 15a12f8 commit e9211d9
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 12 deletions.
92 changes: 83 additions & 9 deletions cellpack/autopack/DBRecipeHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ def as_dict(self):
data["regions"] = self.regions
return data

@staticmethod
def get_gradient_reference(downloaded_data, db):
if "gradient" in downloaded_data and db.is_reference(
downloaded_data["gradient"]
):
gradient_key = downloaded_data["gradient"]
downloaded_data["gradient"], _ = db.get_doc_by_ref(gradient_key)

@staticmethod
def get_reference_data(key_or_dict, db):
"""
Expand All @@ -66,12 +74,14 @@ def get_reference_data(key_or_dict, db):
if DataDoc.is_key(key_or_dict) and db.is_reference(key_or_dict):
key = key_or_dict
downloaded_data, _ = db.get_doc_by_ref(key)
CompositionDoc.get_gradient_reference(downloaded_data, db)
return downloaded_data, None
elif key_or_dict and isinstance(key_or_dict, dict):
object_dict = key_or_dict
if "object" in object_dict and db.is_reference(object_dict["object"]):
key = object_dict["object"]
downloaded_data, _ = db.get_doc_by_ref(key)
CompositionDoc.get_gradient_reference(downloaded_data, db)
return downloaded_data, key
return {}, None

Expand All @@ -96,19 +106,40 @@ def resolve_db_regions(self, db_data, db):
):
self.resolve_db_regions(downloaded_data, db)

@staticmethod
def gradient_list_to_dict(prep_recipe_data):
"""
Convert gradient list to dict for resolve_local_regions
"""
if "gradients" in prep_recipe_data and isinstance(
prep_recipe_data["gradients"], list
):
gradient_dict = {}
for gradient in prep_recipe_data["gradients"]:
gradient_dict[gradient["name"]] = gradient
prep_recipe_data["gradients"] = gradient_dict

def resolve_local_regions(self, local_data, recipe_data, db):
"""
Recursively resolves the regions of a composition from local data.
Restructure the local data to match the db data.
"""
unpack_recipe_data = DBRecipeHandler.prep_data_for_db(recipe_data)
prep_recipe_data = ObjectDoc.convert_representation(unpack_recipe_data, db)
# `gradients` is a list, convert it to dict for easy access and replace
CompositionDoc.gradient_list_to_dict(prep_recipe_data)
if "object" in local_data and local_data["object"] is not None:
if DataDoc.is_key(local_data["object"]):
key_name = local_data["object"]
else:
key_name = local_data["object"]["name"]
local_data["object"] = prep_recipe_data["objects"][key_name]
if "gradient" in local_data["object"] and isinstance(
local_data["object"]["gradient"], str
):
local_data["object"]["gradient"] = prep_recipe_data["gradients"][
local_data["object"]["gradient"]
]
for region_name in local_data["regions"]:
for index, key_or_dict in enumerate(local_data["regions"][region_name]):
if not DataDoc.is_key(key_or_dict):
Expand All @@ -121,6 +152,12 @@ def resolve_local_regions(self, local_data, recipe_data, db):
local_data["regions"][region_name][index][
"object"
] = prep_recipe_data["objects"][obj_item["name"]]
# replace gradient reference with gradient data
obj_data = local_data["regions"][region_name][index]["object"]
if "gradient" in obj_data and isinstance(obj_data["gradient"], str):
local_data["regions"][region_name][index]["object"][
"gradient"
] = prep_recipe_data["gradients"][obj_data["gradient"]]
else:
comp_name = local_data["regions"][region_name][index]
prep_comp_data = prep_recipe_data["composition"][comp_name]
Expand Down Expand Up @@ -209,14 +246,9 @@ def should_write(self, db, recipe_data):
if db_docs and len(db_docs) >= 1:
for doc in db_docs:
db_data = db.doc_to_dict(doc)
shallow_match = True
for item in CompositionDoc.SHALLOW_MATCH:
if db_data[item] != local_data[item]:
print(db_data[item], local_data[item])
shallow_match = False
break
if not shallow_match:
continue
if local_data["regions"] is None and db_data["regions"] is None:
# found a match, so shouldn't write
return False, db.doc_id(doc)
Expand Down Expand Up @@ -296,11 +328,29 @@ def should_write(self, db):
return None, None


class GradientDoc(DataDoc):
    """Document wrapper for a single gradient's settings."""

    def __init__(self, settings):
        super().__init__()
        # raw gradient settings dict, as read from the recipe
        self.settings = settings

    def should_write(self, db, grad_name):
        """
        Check whether a gradient matching the local settings already exists.

        Returns (doc, doc_id) of a matching stored doc, or (None, None)
        when no match exists and the gradient should be uploaded.
        """
        docs = db.get_doc_by_name("gradients", grad_name)
        if docs and len(docs) >= 1:
            # normalize the local settings once; loop-invariant
            local_data = DBRecipeHandler.prep_data_for_db(self.settings)
            for doc in docs:
                db_data = db.doc_to_dict(doc)
                # Bug fix: compare the stored doc against the *local* settings.
                # Previously the stored doc was diffed against a prepped copy of
                # itself, so any doc with the same name matched regardless of
                # its actual content.
                difference = DeepDiff(db_data, local_data, ignore_order=True)
                if not difference:
                    return doc, db.doc_id(doc)
        return None, None


class DBRecipeHandler(object):
def __init__(self, db_handler):
    # handle to the backing database wrapper (e.g. FirebaseHandler)
    self.db = db_handler
    # name -> db path caches, filled during upload so later writes
    # (e.g. compositions referencing objects/gradients) resolve to paths
    self.objects_to_path_map = {}
    self.comp_to_path_map = {}
    self.grad_to_path_map = {}

@staticmethod
def is_nested_list(item):
Expand Down Expand Up @@ -355,13 +405,35 @@ def upload_data(self, collection, data, id=None):
doc = self.db.set_doc(collection, id, modified_data)
return id, self.db.create_path(collection, id)

def upload_gradients(self, gradients):
    """
    Upload each gradient to the db (skipping ones that already exist)
    and cache its db path in ``grad_to_path_map``.
    """
    for gradient in gradients:
        name = gradient["name"]
        doc = GradientDoc(settings=gradient)
        _, existing_id = doc.should_write(self.db, name)
        if existing_id:
            # identical gradient already stored; just remember its path
            print(f"gradients/{name} is already in firestore")
            path = self.db.create_path("gradients", existing_id)
        else:
            _, path = self.upload_data("gradients", doc.settings)
        self.grad_to_path_map[name] = path

def upload_objects(self, objects):
    """
    Upload each object to the db, or reuse an existing matching doc.

    A gradient name inside an object is swapped for the db path cached in
    ``grad_to_path_map`` (populated by ``upload_gradients``) before the
    existence check/upload, so stored objects reference gradients by path.
    The caller's ``objects`` dict is only mutated to add ``name``; the
    gradient swap happens on a deep copy.
    """
    for obj_name in objects:
        objects[obj_name]["name"] = obj_name
        # modify a copy of objects to avoid key error when resolving local regions
        modify_objects = copy.deepcopy(objects)
        # replace gradient name with path to check if gradient exists in db
        if "gradient" in modify_objects[obj_name]:
            grad_name = modify_objects[obj_name]["gradient"]
            # NOTE(review): raises KeyError if the gradient was not uploaded
            # first — confirm upload_gradients always runs before this
            modify_objects[obj_name]["gradient"] = self.grad_to_path_map[grad_name]
        # bug fix: removed an earlier dead assignment of object_doc built from
        # the unmodified settings; it was unconditionally overwritten here
        object_doc = ObjectDoc(name=obj_name, settings=modify_objects[obj_name])
        _, doc_id = object_doc.should_write(self.db)
        if doc_id:
            print(f"objects/{object_doc.name} is already in firestore")
            obj_path = self.db.create_path("objects", doc_id)
            self.objects_to_path_map[obj_name] = obj_path
        else:
            _, obj_path = self.upload_data("objects", object_doc.as_dict())
            self.objects_to_path_map[obj_name] = obj_path
Expand Down Expand Up @@ -416,18 +488,20 @@ def get_recipe_id(self, recipe_data):
"""
recipe_name = recipe_data["name"]
recipe_version = recipe_data["version"]
key = f"{recipe_name}_v{recipe_version}"
key = f"{recipe_name}_v-{recipe_version}"
return key

def upload_collections(self, recipe_meta_data, recipe_data):
"""
Separate collections from recipe data and upload them to db
"""
recipe_to_save = copy.deepcopy(recipe_meta_data)
gradients = recipe_data.get("gradients")
objects = recipe_data["objects"]
compositions = recipe_data["composition"]
# TODO: test gradients recipes
# gradients = recipe_data.get("gradients")
# save gradients to db
if gradients:
self.upload_gradients(gradients)
# save objects to db
self.upload_objects(objects)
# save comps to db
Expand Down
2 changes: 2 additions & 0 deletions cellpack/autopack/FirebaseHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def update_elements_in_array(doc_ref, index, new_item_ref, remove_item):

@staticmethod
def is_reference(path):
if not isinstance(path, str):
return False
if path is None:
return False
if path.startswith("firebase:"):
Expand Down
3 changes: 2 additions & 1 deletion cellpack/autopack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,8 @@ def is_remote_path(file_path):
@param file_path: str
"""
for ele in DATABASE_NAME:
return ele in file_path
if ele in file_path:
return True


def convert_db_shortname_to_url(file_location):
Expand Down
1 change: 1 addition & 0 deletions cellpack/autopack/loaders/migrate_v1_to_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def convert(old_recipe):
new_recipe["name"] = old_recipe["recipe"]["name"]
new_recipe["bounding_box"] = old_recipe["options"]["boundingBox"]
objects_dict = {}
# TODO: check if composition structure is correct
composition = {"space": {"regions": {}}}
if "cytoplasme" in old_recipe:
outer_most_region_array = []
Expand Down
32 changes: 30 additions & 2 deletions cellpack/tests/test_db_recipe_handler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from cellpack.autopack.DBRecipeHandler import DBRecipeHandler
from cellpack.tests.mocks.mock_db import MockDB
from unittest.mock import MagicMock, patch

mock_db = MockDB({})

Expand Down Expand Up @@ -61,10 +62,28 @@ def test_upload_objects():
data = {"test": {"test_key": "test_value"}}
object_doc = DBRecipeHandler(mock_db)
object_doc.upload_objects(data)

assert object_doc.objects_to_path_map == {"test": "firebase:objects/test_id"}


def test_upload_objects_with_gradient():
    # an object whose gradient is referenced by name
    data = {"test": {"test_key": "test_value", "gradient": "test_grad_name"}}
    handler = DBRecipeHandler(mock_db)
    handler.grad_to_path_map = {"test_grad_name": "firebase:gradients/test_id"}

    with patch(
        "cellpack.autopack.DBRecipeHandler.ObjectDoc", return_value=MagicMock()
    ) as mock_object_doc:
        mock_object_doc.return_value.should_write.return_value = (
            None,
            "firebase:gradients/test_id",
        )
        handler.upload_objects(data)
        mock_object_doc.assert_called()
        settings_passed = mock_object_doc.call_args.kwargs["settings"]
        # the caller's dict keeps the gradient name; only the copy gets the path
        assert data["test"]["gradient"] == "test_grad_name"
        assert settings_passed["gradient"] == "firebase:gradients/test_id"


def test_upload_compositions():
composition = {
"space": {"regions": {"interior": ["A"]}},
Expand Down Expand Up @@ -95,6 +114,15 @@ def test_upload_compositions():
}


def test_upload_gradients():
    gradients = [{"name": "test_grad_name", "test_key": "test_value"}]
    handler = DBRecipeHandler(mock_db)
    handler.upload_gradients(gradients)
    expected = {"test_grad_name": "firebase:gradients/test_id"}
    assert handler.grad_to_path_map == expected


def test_get_recipe_id():
recipe_data = {
"name": "test",
Expand All @@ -103,7 +131,7 @@ def test_get_recipe_id():
"composition": {},
}
recipe_doc = DBRecipeHandler(mock_db)
assert recipe_doc.get_recipe_id(recipe_data) == "test_v1.0.0"
assert recipe_doc.get_recipe_id(recipe_data) == "test_v-1.0.0"


def test_upload_collections():
Expand Down
21 changes: 21 additions & 0 deletions cellpack/tests/test_gradient_doc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from cellpack.autopack.DBRecipeHandler import GradientDoc
from cellpack.tests.mocks.mock_db import MockDB

mock_db = MockDB({})


def test_should_write_with_no_existing_doc():
    grad_doc = GradientDoc({"name": "test_grad_name", "test_key": "test_value"})
    doc, doc_id = grad_doc.should_write(mock_db, "test_grad_name")
    # nothing stored yet, so both doc and id come back empty
    assert (doc, doc_id) == (None, None)


def test_should_write_with_existing_doc():
    settings = {"name": "test_grad_name", "test_key": "test_value"}
    # seed the mock db with an identical stored gradient
    mock_db.data = dict(settings)
    grad_doc = GradientDoc(settings)

    doc, doc_id = grad_doc.should_write(mock_db, "test_grad_name")
    assert doc is not None
    assert doc_id is not None

0 comments on commit e9211d9

Please sign in to comment.