Merge pull request #897 from kbuma/bugfix/opendata_store_pickle
fixing OpenDataStore to pickle correctly
Jason Munro authored Dec 15, 2023
2 parents 98aba27 + e561f49 commit 57f488e
Showing 2 changed files with 71 additions and 3 deletions.
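Editor's note (illustration, not part of the diff): boto3 resources hold live network handles and cannot be pickled, so an S3-backed store only survives a pickle round trip if everything needed to rebuild it lives in plain instance attributes, with the live connection dropped and re-created on demand. A minimal sketch of that pattern, using a hypothetical ToyStore rather than maggma code:

# Illustrative sketch only (hypothetical ToyStore, not maggma code).
import pickle


class ToyStore:
    def __init__(self, bucket: str, prefix: str = ""):
        # Every constructor argument is kept as a plain attribute so the
        # default pickle machinery (which serializes __dict__) can rebuild it.
        self.bucket = bucket
        self.prefix = prefix
        self._client = None  # live connection; rebuilt lazily, never pickled

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_client"] = None  # drop the unpicklable handle
        return state

    def __eq__(self, other):
        if not isinstance(other, ToyStore):
            return False
        return (self.bucket, self.prefix) == (other.bucket, other.prefix)

    def __hash__(self):
        return hash((self.bucket, self.prefix))


restored = pickle.loads(pickle.dumps(ToyStore("my-bucket", "data/")))
assert restored == ToyStore("my-bucket", "data/")
assert hash(restored) == hash(ToyStore("my-bucket", "data/"))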
60 changes: 58 additions & 2 deletions src/maggma/stores/open_data.py
@@ -116,7 +116,7 @@ def connect(self, force_reset: bool = False):
        self._load_index(force_reset=force_reset)

    def __hash__(self):
        return hash((self.collection_name, self.bucket, self.prefix))
        return hash((self.collection_name, self.bucket, self.prefix, self.endpoint_url, self.manifest_key))

    def __eq__(self, other: object) -> bool:
        """
@@ -126,7 +126,7 @@ def __eq__(self, other: object) -> bool:
        if not isinstance(other, S3IndexStore):
            return False

        fields = ["collection_name", "bucket", "prefix", "last_updated_field"]
        fields = ["collection_name", "bucket", "prefix", "endpoint_url", "manifest_key", "last_updated_field"]
        return all(getattr(self, f) == getattr(other, f) for f in fields)
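Editor's note (illustration, not part of the diff): before this change, two S3IndexStore instances that differed only in endpoint_url or manifest_key compared equal, so a pickled copy could pass an equality check while pointing at a different endpoint. A hedged sketch of why the extra fields matter, with hypothetical values:

# Hypothetical identity dicts for two stores differing only in endpoint_url.
store_a = {"collection_name": "coll", "bucket": "b", "prefix": "p",
           "endpoint_url": "http://minio-a:9000", "manifest_key": "manifest.jsonl"}
store_b = {**store_a, "endpoint_url": "http://minio-b:9000"}

old_fields = ["collection_name", "bucket", "prefix"]
new_fields = [*old_fields, "endpoint_url", "manifest_key"]

# Under the old field list the two stores are indistinguishable...
assert all(store_a[f] == store_b[f] for f in old_fields)
# ...while the expanded list correctly tells them apart.
assert not all(store_a[f] == store_b[f] for f in new_fields)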


@@ -145,6 +145,12 @@ class OpenDataStore(S3Store):
    def __init__(
        self,
        index: S3IndexStore,
        bucket: str,
        compress: bool = True,
        endpoint_url: Optional[str] = None,
        sub_dir: Optional[str] = None,
        key: str = "fs_id",
        searchable_fields: Optional[List[str]] = None,
        object_file_extension: str = ".json.gz",
        access_as_public_bucket: bool = False,
        **kwargs,
@@ -153,18 +159,36 @@
        Args:
            index (S3IndexStore): The store that will be used as the index, i.e. for queries pertaining to this store.
            bucket: name of the bucket.
            compress: compress files inserted into the store.
            endpoint_url: this allows interfacing with a MinIO service.
            sub_dir: subdirectory of the S3 bucket to store the data.
            key: main key to index on.
            searchable_fields: fields to keep in the index store.
            object_file_extension (str, optional): The extension used for the data stored in S3. Defaults to ".json.gz".
            access_as_public_bucket (bool, optional): If True, the S3 bucket will be accessed without signing, i.e. as if it were a public bucket.
                This is useful for end users. Defaults to False.
        """
        self.index = index
        self.bucket = bucket
        self.compress = compress
        self.endpoint_url = endpoint_url
        self.sub_dir = sub_dir.strip("/") + "/" if sub_dir else ""
        self.key = key
        self.searchable_fields = searchable_fields if searchable_fields is not None else []
        self.object_file_extension = object_file_extension
        self.access_as_public_bucket = access_as_public_bucket
        if access_as_public_bucket:
            kwargs["s3_resource_kwargs"] = kwargs["s3_resource_kwargs"] if "s3_resource_kwargs" in kwargs else {}
            kwargs["s3_resource_kwargs"]["config"] = Config(signature_version=UNSIGNED)

        kwargs["index"] = index
        kwargs["bucket"] = bucket
        kwargs["compress"] = compress
        kwargs["endpoint_url"] = endpoint_url
        kwargs["sub_dir"] = sub_dir
        kwargs["key"] = key
        kwargs["searchable_fields"] = searchable_fields
        kwargs["unpack_data"] = True
        super().__init__(**kwargs)
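Editor's note on the access_as_public_bucket branch above: Config(signature_version=UNSIGNED) tells botocore to skip request signing, so the bucket can be read anonymously, without AWS credentials. The same idiom stands alone as follows (bucket name hypothetical):

# Anonymous, unsigned access to a public S3 bucket (hypothetical bucket name).
import boto3
from botocore import UNSIGNED
from botocore.client import Config

s3 = boto3.resource("s3", config=Config(signature_version=UNSIGNED))
for obj in s3.Bucket("some-public-bucket").objects.limit(5):
    print(obj.key)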

@@ -256,3 +280,35 @@ def rebuild_index_from_data(self, docs: List[Dict]) -> List[Dict]:
            all_index_docs.append(index_doc)
        self.index.store_manifest(all_index_docs)
        return all_index_docs
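Editor's note: the tail of rebuild_index_from_data above shows the shape of the operation — each data document is projected down to an index document, the projections are collected, and the result is persisted via store_manifest. A hedged sketch of that projection step, with hypothetical field names:

# Hypothetical projection of data docs to index docs (field names assumed).
docs = [{"fs_id": "1", "formula": "Fe2O3", "payload": "..."},
        {"fs_id": "2", "formula": "SiO2", "payload": "..."}]
index_fields = ["fs_id", "formula"]

all_index_docs = [{f: d[f] for f in index_fields if f in d} for d in docs]
assert all("payload" not in idx for idx in all_index_docs)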

    def __hash__(self):
        return hash(
            (
                self.bucket,
                self.compress,
                self.endpoint_url,
                self.key,
                self.sub_dir,
            )
        )

    def __eq__(self, other: object) -> bool:
        """
        Check equality for OpenDataStore.

        other: other OpenDataStore to compare with.
        """
        if not isinstance(other, OpenDataStore):
            return False

        fields = [
            "index",
            "bucket",
            "compress",
            "endpoint_url",
            "key",
            "searchable_fields",
            "sub_dir",
            "last_updated_field",
        ]
        return all(getattr(self, f) == getattr(other, f) for f in fields)
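Editor's note: __hash__ above deliberately uses a subset of the __eq__ fields — index, searchable_fields, and last_updated_field are left out (a list, for one, is unhashable). Python's contract only demands that equal objects hash equally, and hashing on a subset of the equality fields guarantees exactly that:

# Hashing on a subset of the equality fields preserves the hash/eq contract:
# if two objects agree on every eq field, they agree on the hash fields too.
eq_fields = {"index", "bucket", "compress", "endpoint_url", "key",
             "searchable_fields", "sub_dir", "last_updated_field"}
hash_fields = {"bucket", "compress", "endpoint_url", "key", "sub_dir"}
assert hash_fields <= eq_fields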
14 changes: 13 additions & 1 deletion tests/stores/test_open_data.py
@@ -1,6 +1,7 @@
import time
from datetime import datetime

import pickle
import boto3
import pytest
import orjson
@@ -288,7 +289,7 @@ def raise_exception_other(data):

def test_eq(memstore, s3store):
    assert s3store == s3store
    assert memstore != s3store
    assert s3store != memstore


def test_count_subdir(s3store_w_subdir):
@@ -397,3 +398,14 @@ def test_no_bucket():
    store = OpenDataStore(index=index, bucket="bucket2")
    with pytest.raises(RuntimeError, match=r".*Bucket not present.*"):
        store.connect()


def test_pickle(s3store_w_subdir):
    sobj = pickle.dumps(s3store_w_subdir.index)
    dobj = pickle.loads(sobj)
    assert hash(dobj) == hash(s3store_w_subdir.index)
    assert dobj == s3store_w_subdir.index
    sobj = pickle.dumps(s3store_w_subdir)
    dobj = pickle.loads(sobj)
    assert hash(dobj) == hash(s3store_w_subdir)
    assert dobj == s3store_w_subdir
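Editor's note: the new test round-trips both the index store and the full store through pickle and asserts that hash and equality survive, exercising every method touched above; it can be run in isolation with pytest tests/stores/test_open_data.py -k test_pickle.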
