Skip to content

Commit

Permalink
Merge pull request #920 from kbuma/bugfix/open_data_store_close
Browse files Browse the repository at this point in the history
Fix OpenDataStore connect/close handling and address pandas FutureWarnings
  • Loading branch information
Jason Munro authored Feb 14, 2024
2 parents 9924540 + 70a9541 commit ae0e636
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
13 changes: 11 additions & 2 deletions src/maggma/stores/open_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,10 @@ def get_merged_items(self, to_dt: pd.DataFrame, from_dt: pd.DataFrame) -> pd.Dat
merged = to_dt.merge(from_dt, on=self.key, how="left", suffixes=("", "_B"))
for column in from_dt.columns:
if column not in self.key:
merged[column].update(merged.pop(column + "_B"))
s = merged.pop(column + "_B")
s.name = column
merged.update(s)
merged.infer_objects(copy=False)
return pd.concat(
(merged[orig_columns], from_dt[~from_dt.set_index(self.key).index.isin(to_dt.set_index(self.key).index)]),
ignore_index=True,
Expand Down Expand Up @@ -351,7 +354,13 @@ def connect(self):
raise RuntimeError(f"Bucket not present on AWS: {self.bucket}")

# load index
super().update(self.retrieve_manifest())
self.set_index_data(self.retrieve_manifest())

def close(self):
    """Release the underlying S3 client connection, if one is open.

    Idempotent: calling it again (or on a never-connected store) is a
    no-op. The client reference is cleared so later operations can
    lazily re-establish a connection (see the auto-reconnect behavior
    exercised in the tests).
    """
    client = self._s3_client
    if client is None:
        return
    # Drop our reference before closing so a failure in close() still
    # leaves the store in the "disconnected" state.
    self._s3_client = None
    client.close()

def retrieve_manifest(self) -> pd.DataFrame:
"""Retrieves the contents of the index stored in S3.
Expand Down
22 changes: 21 additions & 1 deletion tests/stores/test_open_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from maggma.stores.open_data import OpenDataStore, PandasMemoryStore, S3IndexStore

pd.set_option("future.no_silent_downcasting", True)


# PandasMemoryStore tests
@pytest.fixture()
Expand All @@ -24,11 +26,13 @@ def memstore():
store.key: "mp-1",
store.last_updated_field: datetime.utcnow(),
"data": "asd",
"int_val": 1,
},
{
store.key: "mp-3",
store.last_updated_field: datetime.utcnow(),
"data": "sdf",
"int_val": 3,
},
]
)
Expand Down Expand Up @@ -170,25 +174,29 @@ def test_pdmems_update(memstore):
memstore.key: "mp-1",
memstore.last_updated_field: datetime.utcnow(),
"data": "boo",
"int_val": 1,
}
]
)
df2 = memstore.update(df)
assert len(memstore._data) == 2
assert memstore.query(criteria={"query": f"{memstore.key} == 'mp-1'"})["data"].iloc[0] == "boo"
assert memstore.query(criteria={"query": f"{memstore.key} == 'mp-1'"})["int_val"].iloc[0] == 1
assert df2.equals(df)
df = pd.DataFrame(
[
{
memstore.key: "mp-2",
memstore.last_updated_field: datetime.utcnow(),
"data": "boo",
"int_val": 2,
}
]
)
df2 = memstore.update(df)
assert len(memstore._data) == 3
assert memstore.query(criteria={"query": f"{memstore.key} == 'mp-2'"})["data"].iloc[0] == "boo"
assert memstore.query(criteria={"query": f"{memstore.key} == 'mp-2'"})["int_val"].iloc[0] == 2
assert df2.equals(df)


Expand Down Expand Up @@ -238,7 +246,7 @@ def test_s3is_connect_retrieve_manifest(s3indexstore):
assert s3is.count() == 0


def test_s3is_store_manifest(s3indexstore):
def test_s3is_store_manifest():
with mock_s3():
conn = boto3.resource("s3", region_name="us-east-1")
conn.create_bucket(Bucket="bucket2")
Expand All @@ -254,6 +262,18 @@ def test_s3is_store_manifest(s3indexstore):
assert not df.equals(s3is._data)


def test_s3is_close(s3indexstore):
    """close() must leave the index store usable.

    Read/write actions after close() transparently reconnect, while an
    explicit connect() re-reads the manifest from the bucket — so the
    in-memory update made here (never written back to the manifest)
    is expected to disappear after reconnecting.
    """
    s3indexstore.close()
    # Fixture seeds one document; query() after close() must still work.
    assert len(s3indexstore.query()) == 1  # actions auto-reconnect
    s3indexstore.update(pd.DataFrame([{"task_id": "mp-2", "last_updated": "now"}]))
    assert len(s3indexstore.query()) == 2
    s3indexstore.close()
    assert len(s3indexstore.query()) == 2  # actions auto-reconnect
    s3indexstore.close()
    s3indexstore.connect()
    # The update above was in-memory only, so reloading drops it.
    assert len(s3indexstore.query()) == 1  # explicit connect reloads manifest


@pytest.fixture()
def s3store():
with mock_s3():
Expand Down

0 comments on commit ae0e636

Please sign in to comment.