From b97beb00dd2d1e492ae5351f2145961acbdc88b0 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Tue, 7 May 2024 17:47:27 +0400 Subject: [PATCH] update tests, lint --- .../tests/unit/io/database/test_write.py | 107 +++++++++--------- 1 file changed, 55 insertions(+), 52 deletions(-) diff --git a/py-polars/tests/unit/io/database/test_write.py b/py-polars/tests/unit/io/database/test_write.py index a2d7bb1ccb14c..ba7b2093d7cf5 100644 --- a/py-polars/tests/unit/io/database/test_write.py +++ b/py-polars/tests/unit/io/database/test_write.py @@ -1,7 +1,7 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pytest from sqlalchemy import create_engine @@ -17,23 +17,10 @@ @pytest.mark.write_disk() -@pytest.mark.parametrize( - "engine", - [ - "sqlalchemy", - pytest.param( - "adbc", - marks=pytest.mark.skipif( - sys.version_info < (3, 9) or sys.platform == "win32", - reason="adbc not available on Windows or <= Python 3.8", - ), - ), - ], -) class TestWriteDatabase: """Database write tests that share common pytest/parametrize options.""" - def test_write_database_create(self, engine: DbWriteEngine, tmp_path: Path) -> None: + def test_write_database_create(self, tmp_path: Path) -> None: """Test basic database table creation.""" df = pl.DataFrame( { @@ -43,18 +30,19 @@ def test_write_database_create(self, engine: DbWriteEngine, tmp_path: Path) -> N } ) tmp_path.mkdir(exist_ok=True) - test_db = str(tmp_path / f"test_{engine}.db") + test_db = str(tmp_path / "test_create.db") test_db_uri = f"sqlite:///{test_db}" - table_name_stub = "test_create" - for idx, conn in enumerate( - ( - test_db_uri, - create_engine(test_db_uri), - _open_adbc_connection(test_db_uri), - ) - ): - table_name = f"{table_name_stub}{idx}" + connections: list[tuple[Any, DbWriteEngine | None]] = [ + (test_db_uri, "sqlalchemy"), + (create_engine(test_db_uri), None), + ] + if sys.version_info >= (3, 9) and sys.platform != "win32": + connections.append((test_db_uri, "sqlalchemy")) + connections.append((_open_adbc_connection(test_db_uri), None)) + + for idx, (conn, engine) in enumerate(connections): + table_name = f"test_create{idx}" assert ( df.write_database( table_name=table_name, @@ -72,9 +60,7 @@ def test_write_database_create(self, engine: DbWriteEngine, tmp_path: Path) -> N if hasattr(conn, "close"): conn.close() - def test_write_database_append_replace( - self, engine: DbWriteEngine, tmp_path: Path - ) -> None: + def test_write_database_append_replace(self, tmp_path: Path) -> None: """Test append/replace ops against existing database table.""" df = pl.DataFrame( { @@ -83,20 +69,20 @@ def test_write_database_append_replace( "other": [5.5, 7.0, None], } ) - tmp_path.mkdir(exist_ok=True) - test_db = str(tmp_path / f"test_{engine}.db") + test_db = str(tmp_path / "test_append.db") test_db_uri = f"sqlite:///{test_db}" - table_name_stub = "test_create" - for idx, conn in enumerate( - ( - test_db_uri, - create_engine(test_db_uri), - _open_adbc_connection(test_db_uri), - ) - ): - table_name = f"{table_name_stub}{idx}" + connections: list[tuple[Any, DbWriteEngine | None]] = [ + (test_db_uri, "sqlalchemy"), + (create_engine(test_db_uri), None), + ] + if sys.version_info >= (3, 9) and sys.platform != "win32": + connections.append((test_db_uri, "sqlalchemy")) + connections.append((_open_adbc_connection(test_db_uri), None)) + + for idx, (conn, engine) in enumerate(connections): + table_name = f"test_create{idx}" assert ( df.write_database( table_name=table_name, @@ -146,23 +132,27 @@ def test_write_database_append_replace( if hasattr(conn, "close"): conn.close() - def test_write_database_create_quoted_tablename( - self, engine: DbWriteEngine, tmp_path: Path - ) -> None: + def test_write_database_create_quoted_tablename(self, tmp_path: Path) -> None: """Test parsing/handling of quoted database table names.""" - df = pl.DataFrame({"col x": [100, 200, 300], "col y": ["a", "b", "c"]}) - + df = pl.DataFrame( + { + "col x": [100, 200, 300], + "col y": ["a", "b", "c"], + } + ) tmp_path.mkdir(exist_ok=True) - test_db = str(tmp_path / f"test_{engine}.db") + test_db = str(tmp_path / "test_create_quoted.db") test_db_uri = f"sqlite:///{test_db}" - for idx, conn in enumerate( - ( - test_db_uri, - create_engine(test_db_uri), - _open_adbc_connection(test_db_uri), - ) - ): + connections: list[tuple[Any, DbWriteEngine | None]] = [ + (test_db_uri, "sqlalchemy"), + (create_engine(test_db_uri), None), + ] + if sys.version_info >= (3, 9) and sys.platform != "win32": + connections.append((test_db_uri, "sqlalchemy")) + connections.append((_open_adbc_connection(test_db_uri), None)) + + for idx, (conn, engine) in enumerate(connections): # table name has some special chars, so requires quoting, and # is expliocitly qualified with the sqlite 'main' schema qualified_table_name = f'main."test-append-{engine}{idx}"' @@ -192,6 +182,19 @@ def test_write_database_create_quoted_tablename( if hasattr(conn, "close"): conn.close() + @pytest.mark.parametrize( + "engine", + [ + "sqlalchemy", + pytest.param( + "adbc", + marks=pytest.mark.skipif( + sys.version_info < (3, 9) or sys.platform == "win32", + reason="adbc not available on Windows or <= Python 3.8", + ), + ), + ], + ) def test_write_database_errors(self, engine: DbWriteEngine, tmp_path: Path) -> None: """Confirm that expected errors are raised.""" df = pl.DataFrame({"colx": [1, 2, 3]})