Skip to content

Commit

Permalink
update tests, lint
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed May 7, 2024
1 parent a75500c commit b97beb0
Showing 1 changed file with 55 additions and 52 deletions.
107 changes: 55 additions & 52 deletions py-polars/tests/unit/io/database/test_write.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import sys
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import pytest
from sqlalchemy import create_engine
Expand All @@ -17,23 +17,10 @@


@pytest.mark.write_disk()
@pytest.mark.parametrize(
"engine",
[
"sqlalchemy",
pytest.param(
"adbc",
marks=pytest.mark.skipif(
sys.version_info < (3, 9) or sys.platform == "win32",
reason="adbc not available on Windows or <= Python 3.8",
),
),
],
)
class TestWriteDatabase:
"""Database write tests that share common pytest/parametrize options."""

def test_write_database_create(self, engine: DbWriteEngine, tmp_path: Path) -> None:
def test_write_database_create(self, tmp_path: Path) -> None:
"""Test basic database table creation."""
df = pl.DataFrame(
{
Expand All @@ -43,18 +30,19 @@ def test_write_database_create(self, engine: DbWriteEngine, tmp_path: Path) -> N
}
)
tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
test_db = str(tmp_path / "test_create.db")
test_db_uri = f"sqlite:///{test_db}"
table_name_stub = "test_create"

for idx, conn in enumerate(
(
test_db_uri,
create_engine(test_db_uri),
_open_adbc_connection(test_db_uri),
)
):
table_name = f"{table_name_stub}{idx}"
connections: list[tuple[Any, DbWriteEngine | None]] = [
(test_db_uri, "sqlalchemy"),
(create_engine(test_db_uri), None),
]
if sys.version_info >= (3, 9) and sys.platform != "win32":
connections.append((test_db_uri, "sqlalchemy"))
connections.append((_open_adbc_connection(test_db_uri), None))

for idx, (conn, engine) in enumerate(connections):
table_name = f"test_create{idx}"
assert (
df.write_database(
table_name=table_name,
Expand All @@ -72,9 +60,7 @@ def test_write_database_create(self, engine: DbWriteEngine, tmp_path: Path) -> N
if hasattr(conn, "close"):
conn.close()

def test_write_database_append_replace(
self, engine: DbWriteEngine, tmp_path: Path
) -> None:
def test_write_database_append_replace(self, tmp_path: Path) -> None:
"""Test append/replace ops against existing database table."""
df = pl.DataFrame(
{
Expand All @@ -83,20 +69,20 @@ def test_write_database_append_replace(
"other": [5.5, 7.0, None],
}
)

tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
test_db = str(tmp_path / "test_append.db")
test_db_uri = f"sqlite:///{test_db}"
table_name_stub = "test_create"

for idx, conn in enumerate(
(
test_db_uri,
create_engine(test_db_uri),
_open_adbc_connection(test_db_uri),
)
):
table_name = f"{table_name_stub}{idx}"
connections: list[tuple[Any, DbWriteEngine | None]] = [
(test_db_uri, "sqlalchemy"),
(create_engine(test_db_uri), None),
]
if sys.version_info >= (3, 9) and sys.platform != "win32":
connections.append((test_db_uri, "sqlalchemy"))
connections.append((_open_adbc_connection(test_db_uri), None))

for idx, (conn, engine) in enumerate(connections):
table_name = f"test_create{idx}"
assert (
df.write_database(
table_name=table_name,
Expand Down Expand Up @@ -146,23 +132,27 @@ def test_write_database_append_replace(
if hasattr(conn, "close"):
conn.close()

def test_write_database_create_quoted_tablename(
self, engine: DbWriteEngine, tmp_path: Path
) -> None:
def test_write_database_create_quoted_tablename(self, tmp_path: Path) -> None:
"""Test parsing/handling of quoted database table names."""
df = pl.DataFrame({"col x": [100, 200, 300], "col y": ["a", "b", "c"]})

df = pl.DataFrame(
{
"col x": [100, 200, 300],
"col y": ["a", "b", "c"],
}
)
tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
test_db = str(tmp_path / "test_create_quoted.db")
test_db_uri = f"sqlite:///{test_db}"

for idx, conn in enumerate(
(
test_db_uri,
create_engine(test_db_uri),
_open_adbc_connection(test_db_uri),
)
):
connections: list[tuple[Any, DbWriteEngine | None]] = [
(test_db_uri, "sqlalchemy"),
(create_engine(test_db_uri), None),
]
if sys.version_info >= (3, 9) and sys.platform != "win32":
connections.append((test_db_uri, "sqlalchemy"))
connections.append((_open_adbc_connection(test_db_uri), None))

for idx, (conn, engine) in enumerate(connections):
# table name has some special chars, so requires quoting, and
# is expliocitly qualified with the sqlite 'main' schema
qualified_table_name = f'main."test-append-{engine}{idx}"'
Expand Down Expand Up @@ -192,6 +182,19 @@ def test_write_database_create_quoted_tablename(
if hasattr(conn, "close"):
conn.close()

@pytest.mark.parametrize(
"engine",
[
"sqlalchemy",
pytest.param(
"adbc",
marks=pytest.mark.skipif(
sys.version_info < (3, 9) or sys.platform == "win32",
reason="adbc not available on Windows or <= Python 3.8",
),
),
],
)
def test_write_database_errors(self, engine: DbWriteEngine, tmp_path: Path) -> None:
"""Confirm that expected errors are raised."""
df = pl.DataFrame({"colx": [1, 2, 3]})
Expand Down

0 comments on commit b97beb0

Please sign in to comment.