From 70a9541ef02e59c13ba77699c8c550bc24cc820d Mon Sep 17 00:00:00 2001 From: kberket Date: Wed, 14 Feb 2024 16:05:29 -0500 Subject: [PATCH] fix open data store connect and close and address future warnings for pandas --- src/maggma/stores/open_data.py | 13 +++++++++++-- tests/stores/test_open_data.py | 22 +++++++++++++++++++++- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/maggma/stores/open_data.py b/src/maggma/stores/open_data.py index d5c4de827..904b3bb65 100644 --- a/src/maggma/stores/open_data.py +++ b/src/maggma/stores/open_data.py @@ -252,7 +252,10 @@ def get_merged_items(self, to_dt: pd.DataFrame, from_dt: pd.DataFrame) -> pd.Dat merged = to_dt.merge(from_dt, on=self.key, how="left", suffixes=("", "_B")) for column in from_dt.columns: if column not in self.key: - merged[column].update(merged.pop(column + "_B")) + s = merged.pop(column + "_B") + s.name = column + merged.update(s) + merged = merged.infer_objects(copy=False) return pd.concat( (merged[orig_columns], from_dt[~from_dt.set_index(self.key).index.isin(to_dt.set_index(self.key).index)]), ignore_index=True, @@ -351,7 +354,13 @@ def connect(self): raise RuntimeError(f"Bucket not present on AWS: {self.bucket}") # load index - super().update(self.retrieve_manifest()) + self.set_index_data(self.retrieve_manifest()) + + def close(self): + """Closes any connections.""" + if self._s3_client is not None: + self._s3_client.close() + self._s3_client = None def retrieve_manifest(self) -> pd.DataFrame: """Retrieves the contents of the index stored in S3. 
diff --git a/tests/stores/test_open_data.py b/tests/stores/test_open_data.py index a369a64b2..78ee6759e 100644 --- a/tests/stores/test_open_data.py +++ b/tests/stores/test_open_data.py @@ -12,6 +12,8 @@ from maggma.stores.open_data import OpenDataStore, PandasMemoryStore, S3IndexStore +pd.set_option("future.no_silent_downcasting", True) + # PandasMemoryStore tests @pytest.fixture() @@ -24,11 +26,13 @@ def memstore(): store.key: "mp-1", store.last_updated_field: datetime.utcnow(), "data": "asd", + "int_val": 1, }, { store.key: "mp-3", store.last_updated_field: datetime.utcnow(), "data": "sdf", + "int_val": 3, }, ] ) @@ -170,12 +174,14 @@ def test_pdmems_update(memstore): memstore.key: "mp-1", memstore.last_updated_field: datetime.utcnow(), "data": "boo", + "int_val": 1, } ] ) df2 = memstore.update(df) assert len(memstore._data) == 2 assert memstore.query(criteria={"query": f"{memstore.key} == 'mp-1'"})["data"].iloc[0] == "boo" + assert memstore.query(criteria={"query": f"{memstore.key} == 'mp-1'"})["int_val"].iloc[0] == 1 assert df2.equals(df) df = pd.DataFrame( [ @@ -183,12 +189,14 @@ def test_pdmems_update(memstore): memstore.key: "mp-2", memstore.last_updated_field: datetime.utcnow(), "data": "boo", + "int_val": 2, } ] ) df2 = memstore.update(df) assert len(memstore._data) == 3 assert memstore.query(criteria={"query": f"{memstore.key} == 'mp-2'"})["data"].iloc[0] == "boo" + assert memstore.query(criteria={"query": f"{memstore.key} == 'mp-2'"})["int_val"].iloc[0] == 2 assert df2.equals(df) @@ -238,7 +246,7 @@ def test_s3is_connect_retrieve_manifest(s3indexstore): assert s3is.count() == 0 -def test_s3is_store_manifest(s3indexstore): +def test_s3is_store_manifest(): with mock_s3(): conn = boto3.resource("s3", region_name="us-east-1") conn.create_bucket(Bucket="bucket2") @@ -254,6 +262,18 @@ def test_s3is_store_manifest(s3indexstore): assert not df.equals(s3is._data) +def test_s3is_close(s3indexstore): + s3indexstore.close() + assert len(s3indexstore.query()) == 1 # 
actions auto-reconnect + s3indexstore.update(pd.DataFrame([{"task_id": "mp-2", "last_updated": "now"}])) + assert len(s3indexstore.query()) == 2 + s3indexstore.close() + assert len(s3indexstore.query()) == 2 # actions auto-reconnect + s3indexstore.close() + s3indexstore.connect() + assert len(s3indexstore.query()) == 1 # explicit connect reloads manifest + + @pytest.fixture() def s3store(): with mock_s3():