diff --git a/tests/drivers/asynch/test_cursor.py b/tests/drivers/asynch/test_cursor.py index 7d6d45c..91b83cb 100644 --- a/tests/drivers/asynch/test_cursor.py +++ b/tests/drivers/asynch/test_cursor.py @@ -2,11 +2,9 @@ from sqlalchemy.util.concurrency import greenlet_spawn from tests.testcase import AsynchSessionTestCase -from tests.util import run_async class CursorTestCase(AsynchSessionTestCase): - @run_async async def test_execute_without_context(self): raw = await self.session.bind.raw_connection() cur = await greenlet_spawn(lambda: raw.cursor()) @@ -20,7 +18,6 @@ async def test_execute_without_context(self): raw.close() - @run_async async def test_execute_with_context(self): rv = await self.session.execute( text('SELECT * FROM system.numbers LIMIT 1') @@ -28,7 +25,6 @@ async def test_execute_with_context(self): self.assertEqual(len(rv.fetchall()), 1) - @run_async async def test_check_iter_cursor(self): rv = await self.session.execute( text('SELECT number FROM system.numbers LIMIT 5') @@ -36,11 +32,10 @@ async def test_check_iter_cursor(self): self.assertListEqual(list(rv), [(x,) for x in range(5)]) - @run_async async def test_execute_with_stream(self): - conn = await self.get_connection() - async with conn.stream( - text("SELECT * FROM system.numbers LIMIT 10") + async with self.connection.stream( + text("SELECT * FROM system.numbers LIMIT 10"), + execution_options={'max_block_size': 1} ) as result: idx = 0 async for r in result: diff --git a/tests/drivers/asynch/test_insert.py b/tests/drivers/asynch/test_insert.py index e69737f..0b89703 100644 --- a/tests/drivers/asynch/test_insert.py +++ b/tests/drivers/asynch/test_insert.py @@ -4,11 +4,9 @@ from asynch.errors import TypeMismatchError from tests.testcase import AsynchSessionTestCase -from tests.util import run_async class NativeInsertTestCase(AsynchSessionTestCase): - @run_async async def test_rowcount_return1(self): metadata = self.metadata() table = Table( @@ -37,7 +35,6 @@ async def test_rowcount_return1(self): ) self.assertEqual(rv.rowcount, -1) - @run_async async def test_types_check(self): metadata = self.metadata() table = Table( diff --git a/tests/drivers/asynch/test_select.py b/tests/drivers/asynch/test_select.py index 86fb415..b1536c6 100644 --- a/tests/drivers/asynch/test_select.py +++ b/tests/drivers/asynch/test_select.py @@ -3,12 +3,9 @@ from clickhouse_sqlalchemy import engines, types, Table from tests.session import mocked_engine from tests.testcase import AsynchSessionTestCase -from tests.util import run_async class SanityTestCase(AsynchSessionTestCase): - - @run_async async def test_sanity(self): with mocked_engine(self.session) as engine: metadata = self.metadata() diff --git a/tests/drivers/test_clickhouse_dialect.py b/tests/drivers/test_clickhouse_dialect.py index 122b0ec..f7c439e 100644 --- a/tests/drivers/test_clickhouse_dialect.py +++ b/tests/drivers/test_clickhouse_dialect.py @@ -86,7 +86,7 @@ def test_get_view_names_with_schema(self): db_views = self.dialect.get_view_names(self.connection, test_database) self.assertNotIn(self.table.name, db_views) - def test_reflecttable(self): + def test_reflect_table(self): self.table.create(self.session.bind) meta = self.metadata() insp = inspect(self.session.bind) @@ -95,7 +95,7 @@ def test_reflecttable(self): self.assertEqual(self.table.name, reflected_table.name) - def test_reflecttable_with_schema(self): + def test_reflect_table_with_schema(self): # Imitates calling sequence for clients like Superset that look # across schemas. meta = self.metadata() @@ -146,9 +146,8 @@ class ClickHouseAsynchDialectTestCase(BaseAsynchTestCase): session = asynch_session - @run_async - async def setUp(self): - super().setUp() + def setUp(self): + super(ClickHouseAsynchDialectTestCase, self).setUp() self.test_metadata = self.metadata() self.table = Table( 'test_exists_table', @@ -156,12 +155,11 @@ async def setUp(self): Column('x', types.Int32, primary_key=True), engines.Memory() ) - await self.run_sync(self.test_metadata.drop_all) + run_async(self.connection.run_sync)(self.test_metadata.drop_all) - @run_async - async def tearDown(self): - await self.run_sync(self.test_metadata.drop_all) - super().tearDown() + def tearDown(self): + run_async(self.connection.run_sync)(self.test_metadata.drop_all) + super(ClickHouseAsynchDialectTestCase, self).tearDown() async def run_inspector_method(self, method, *args, **kwargs): def _run(conn): @@ -170,7 +168,6 @@ def _run(conn): return await self.run_sync(_run) - @run_async async def test_has_table(self): self.assertFalse( await self.run_inspector_method('has_table', self.table.name) @@ -182,7 +179,6 @@ async def test_has_table(self): await self.run_inspector_method('has_table', self.table.name) ) - @run_async async def test_has_table_with_schema(self): self.assertFalse( await self.run_inspector_method( @@ -199,7 +195,6 @@ async def test_has_table_with_schema(self): ) ) - @run_async async def test_get_table_names(self): await self.run_sync(self.test_metadata.create_all) @@ -207,7 +202,6 @@ async def test_get_table_names(self): self.assertIn(self.table.name, db_tables) - @run_async async def test_get_table_names_with_schema(self): await self.run_sync(self.test_metadata.create_all) @@ -218,7 +212,6 @@ async def test_get_table_names_with_schema(self): self.assertIn('columns', db_tables) - @run_async async def test_get_view_names(self): await self.run_sync(self.test_metadata.create_all) @@ -226,7 +219,6 @@ async def test_get_view_names(self): self.assertNotIn(self.table.name, db_views) - @run_async async def test_get_view_names_with_schema(self): await self.run_sync(self.test_metadata.create_all) @@ -237,8 +229,7 @@ async def test_get_view_names_with_schema(self): self.assertNotIn(self.table.name, db_views) - @run_async - async def test_reflecttable(self): + async def test_reflect_table(self): await self.run_sync(self.test_metadata.create_all) meta = self.metadata() @@ -247,8 +238,7 @@ async def test_reflecttable(self): self.assertEqual(self.table.name, reflected_table.name) - @run_async - async def test_reflecttable_with_schema(self): + async def test_reflect_table_with_schema(self): # Imitates calling sequence for clients like Superset that look # across schemas. meta = self.metadata() @@ -260,17 +250,15 @@ async def test_reflecttable_with_schema(self): if self.server_version >= (18, 16, 0): self.assertIsNone(reflected_table.engine) - @run_async async def test_get_schema_names(self): schemas = await self.run_inspector_method('get_schema_names') self.assertIn(test_database, schemas) - def test_columns_compilation(self): + async def test_columns_compilation(self): # should not raise UnsupportedCompilationError col = Column('x', types.Nullable(types.Int32)) self.assertEqual(str(col.type), 'Nullable(Int32)') - @run_async @require_server_version(19, 16, 2, is_async=True) async def test_empty_set_expr(self): numbers = Table( @@ -310,7 +298,8 @@ def test_server_version_http(self): def test_server_version_native(self): return self._test_server_version_any(system_native_uri) - @run_async + +class CachedServerVersionAsyncTestCase(BaseAsynchTestCase): async def test_server_version_asynch(self): engine_session = make_session(create_async_engine( system_asynch_uri, diff --git a/tests/testcase.py b/tests/testcase.py index f44c902..0d571aa 100644 --- a/tests/testcase.py +++ b/tests/testcase.py @@ -1,5 +1,5 @@ import re -from contextlib import contextmanager, asynccontextmanager +from contextlib import contextmanager from unittest import TestCase from sqlalchemy import MetaData, text @@ -97,46 +97,29 @@ class BaseAsynchTestCase(BaseTestCase): session = asynch_session @classmethod - @run_async - async def setUpClass(cls): + def setUpClass(cls): # System database is always present. - await system_asynch_session.execute( + run_async(system_asynch_session.execute)( text('DROP DATABASE IF EXISTS {}'.format(cls.database)) ) - await system_asynch_session.execute( + run_async(system_asynch_session.execute)( text('CREATE DATABASE {}'.format(cls.database)) ) version = ( - await system_asynch_session.execute(text('SELECT version()')) + run_async(system_asynch_session.execute)(text('SELECT version()')) ).fetchall() cls.server_version = tuple(int(x) for x in version[0][0].split('.')) - super(BaseTestCase, cls).setUpClass() - - @asynccontextmanager - async def create_table(self, table): - metadata = self.metadata() - await self.run_sync(metadata.drop_all) - await self.run_sync(metadata.create_all) - - try: - yield - finally: - await self.run_sync(metadata.drop_all) + def setUp(self): + self.connection = run_async(self.session.connection)() + super(BaseAsynchTestCase, self).setUp() - async def get_connection(self): - return await self.session.connection() + def _callTestMethod(self, method): + return run_async(method)() async def run_sync(self, f): - conn = await self.get_connection() - return await conn.run_sync(f) - - async def session_scalar(self, statement): - def wrapper(session): - return session.query(statement).scalar() - - return await self.session.run_sync(wrapper) + return await self.connection.run_sync(f) class HttpSessionTestCase(BaseTestCase): diff --git a/testsrequire.py b/testsrequire.py index baa9d53..10c1ffe 100644 --- a/testsrequire.py +++ b/testsrequire.py @@ -1,6 +1,7 @@ tests_require = [ 'pytest', + 'pytest-asyncio', 'sqlalchemy>=2.0.0,<2.1.0', 'greenlet>=2.0.1', 'alembic',