Skip to content

Commit

Permalink
feat(python): More robust handling of async database calls (pola-rs…
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Mar 23, 2024
1 parent 5febd51 commit bc91c62
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 10 deletions.
21 changes: 12 additions & 9 deletions py-polars/polars/io/database/_executor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import re
import warnings
from collections.abc import Coroutine
from contextlib import suppress
from inspect import Parameter, isclass, signature
Expand Down Expand Up @@ -270,17 +269,21 @@ def _from_rows(

@staticmethod
def _run_async(co: Coroutine) -> Any: # type: ignore[type-arg]
"""Consolidate async event loop acquisition and coroutine/func execution."""
"""Run asynchronous code as if it was synchronous."""
import asyncio

try:
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(co)
import nest_asyncio

nest_asyncio.apply()
except ModuleNotFoundError as _err:
msg = (
"Executing using async drivers requires the `nest_asyncio` package."
"\n\nPlease run: pip install nest_asyncio"
)
raise ModuleNotFoundError(msg) from None

return asyncio.run(co)

@staticmethod
def _inject_type_overrides(
Expand Down
2 changes: 2 additions & 0 deletions py-polars/polars/meta/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def show_versions() -> None:
gevent: 24.2.1
hvplot: 0.9.2
matplotlib: 3.8.3
nest_asyncio: 1.6.0
numpy: 1.26.4
openpyxl: 3.1.2
pandas: 2.2.1
Expand Down Expand Up @@ -70,6 +71,7 @@ def _get_dependency_info() -> dict[str, str]:
"gevent",
"hvplot",
"matplotlib",
"nest_asyncio",
"numpy",
"openpyxl",
"pandas",
Expand Down
4 changes: 3 additions & 1 deletion py-polars/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Changelog = "https://github.com/pola-rs/polars/releases"
[project.optional-dependencies]
# NOTE: keep this list in sync with show_versions() and requirements-dev.txt
adbc = ["adbc_driver_manager", "adbc_driver_sqlite"]
async = ["nest_asyncio"]
cloudpickle = ["cloudpickle"]
connectorx = ["connectorx >= 0.3.2"]
deltalake = ["deltalake >= 0.14.0"]
Expand All @@ -60,7 +61,7 @@ timezone = ["backports.zoneinfo; python_version < '3.9'", "tzdata; platform_syst
xlsx2csv = ["xlsx2csv >= 0.8.0"]
xlsxwriter = ["xlsxwriter"]
all = [
"polars[adbc,cloudpickle,connectorx,deltalake,fastexcel,fsspec,gevent,numpy,pandas,plot,pyarrow,pydantic,pyiceberg,sqlalchemy,timezone,xlsx2csv,xlsxwriter]",
"polars[adbc,async,cloudpickle,connectorx,deltalake,fastexcel,fsspec,gevent,numpy,pandas,plot,pyarrow,pydantic,pyiceberg,sqlalchemy,timezone,xlsx2csv,xlsxwriter]",
]

[tool.maturin]
Expand Down Expand Up @@ -94,6 +95,7 @@ module = [
"kuzu",
"matplotlib.*",
"moto.server",
"nest_asyncio",
"openpyxl",
"polars.polars",
"pyarrow.*",
Expand Down
1 change: 1 addition & 0 deletions py-polars/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ hvplot>=0.9.1
matplotlib
# Other
gevent
nest_asyncio

# -------
# TOOLING
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/unit/io/test_database_read.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import os
import sqlite3
import sys
Expand Down Expand Up @@ -925,3 +926,21 @@ def test_read_database_async(tmp_sqlite_db: Path) -> None:
execute_options=execute_opts,
)
assert_frame_equal(expected_frame, df)


async def _nested_async_test(tmp_sqlite_db: Path) -> pl.DataFrame:
async_engine = create_async_engine(f"sqlite+aiosqlite:///{tmp_sqlite_db}")
return pl.read_database(
query="SELECT id, name FROM test_data ORDER BY id",
connection=async_engine.connect(),
)


def test_read_database_async_nested(tmp_sqlite_db: Path) -> None:
# this tests validates that we can handle nested async calls; without the
# internal nested asyncio/loop detection & handling provided by `nest_asyncio`
# this test would raise the RuntimeError: "This event loop is already running".

expected_frame = pl.DataFrame({"id": [1, 2], "name": ["misc", "other"]})
df = asyncio.run(_nested_async_test(tmp_sqlite_db))
assert_frame_equal(expected_frame, df)

0 comments on commit bc91c62

Please sign in to comment.