Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
xzkostyan committed Aug 6, 2024
1 parent b0fb8ba commit 0936a7e
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 66 deletions.
11 changes: 3 additions & 8 deletions tests/drivers/asynch/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
from sqlalchemy.util.concurrency import greenlet_spawn

from tests.testcase import AsynchSessionTestCase
from tests.util import run_async


class CursorTestCase(AsynchSessionTestCase):
@run_async
async def test_execute_without_context(self):
raw = await self.session.bind.raw_connection()
cur = await greenlet_spawn(lambda: raw.cursor())
Expand All @@ -20,27 +18,24 @@ async def test_execute_without_context(self):

raw.close()

@run_async
async def test_execute_with_context(self):
rv = await self.session.execute(
text('SELECT * FROM system.numbers LIMIT 1')
)

self.assertEqual(len(rv.fetchall()), 1)

@run_async
async def test_check_iter_cursor(self):
rv = await self.session.execute(
text('SELECT number FROM system.numbers LIMIT 5')
)

self.assertListEqual(list(rv), [(x,) for x in range(5)])

@run_async
async def test_execute_with_stream(self):
conn = await self.get_connection()
async with conn.stream(
text("SELECT * FROM system.numbers LIMIT 10")
async with self.connection.stream(
text("SELECT * FROM system.numbers LIMIT 10"),
execution_options={'max_block_size': 1}
) as result:
idx = 0
async for r in result:
Expand Down
3 changes: 0 additions & 3 deletions tests/drivers/asynch/test_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
from asynch.errors import TypeMismatchError

from tests.testcase import AsynchSessionTestCase
from tests.util import run_async


class NativeInsertTestCase(AsynchSessionTestCase):
@run_async
async def test_rowcount_return1(self):
metadata = self.metadata()
table = Table(
Expand Down Expand Up @@ -37,7 +35,6 @@ async def test_rowcount_return1(self):
)
self.assertEqual(rv.rowcount, -1)

@run_async
async def test_types_check(self):
metadata = self.metadata()
table = Table(
Expand Down
3 changes: 0 additions & 3 deletions tests/drivers/asynch/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@
from clickhouse_sqlalchemy import engines, types, Table
from tests.session import mocked_engine
from tests.testcase import AsynchSessionTestCase
from tests.util import run_async


class SanityTestCase(AsynchSessionTestCase):

@run_async
async def test_sanity(self):
with mocked_engine(self.session) as engine:
metadata = self.metadata()
Expand Down
37 changes: 13 additions & 24 deletions tests/drivers/test_clickhouse_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_get_view_names_with_schema(self):
db_views = self.dialect.get_view_names(self.connection, test_database)
self.assertNotIn(self.table.name, db_views)

def test_reflecttable(self):
def test_reflect_table(self):
self.table.create(self.session.bind)
meta = self.metadata()
insp = inspect(self.session.bind)
Expand All @@ -95,7 +95,7 @@ def test_reflecttable(self):

self.assertEqual(self.table.name, reflected_table.name)

def test_reflecttable_with_schema(self):
def test_reflect_table_with_schema(self):
# Imitates calling sequence for clients like Superset that look
# across schemas.
meta = self.metadata()
Expand Down Expand Up @@ -146,22 +146,20 @@ class ClickHouseAsynchDialectTestCase(BaseAsynchTestCase):

session = asynch_session

@run_async
async def setUp(self):
super().setUp()
def setUp(self):
super(ClickHouseAsynchDialectTestCase, self).setUp()
self.test_metadata = self.metadata()
self.table = Table(
'test_exists_table',
self.test_metadata,
Column('x', types.Int32, primary_key=True),
engines.Memory()
)
await self.run_sync(self.test_metadata.drop_all)
run_async(self.connection.run_sync)(self.test_metadata.drop_all)

@run_async
async def tearDown(self):
await self.run_sync(self.test_metadata.drop_all)
super().tearDown()
def tearDown(self):
run_async(self.connection.run_sync)(self.test_metadata.drop_all)
super(ClickHouseAsynchDialectTestCase, self).tearDown()

async def run_inspector_method(self, method, *args, **kwargs):
def _run(conn):
Expand All @@ -170,7 +168,6 @@ def _run(conn):

return await self.run_sync(_run)

@run_async
async def test_has_table(self):
self.assertFalse(
await self.run_inspector_method('has_table', self.table.name)
Expand All @@ -182,7 +179,6 @@ async def test_has_table(self):
await self.run_inspector_method('has_table', self.table.name)
)

@run_async
async def test_has_table_with_schema(self):
self.assertFalse(
await self.run_inspector_method(
Expand All @@ -199,15 +195,13 @@ async def test_has_table_with_schema(self):
)
)

@run_async
async def test_get_table_names(self):
await self.run_sync(self.test_metadata.create_all)

db_tables = await self.run_inspector_method('get_table_names')

self.assertIn(self.table.name, db_tables)

@run_async
async def test_get_table_names_with_schema(self):
await self.run_sync(self.test_metadata.create_all)

Expand All @@ -218,15 +212,13 @@ async def test_get_table_names_with_schema(self):

self.assertIn('columns', db_tables)

@run_async
async def test_get_view_names(self):
await self.run_sync(self.test_metadata.create_all)

db_views = await self.run_inspector_method('get_view_names')

self.assertNotIn(self.table.name, db_views)

@run_async
async def test_get_view_names_with_schema(self):
await self.run_sync(self.test_metadata.create_all)

Expand All @@ -237,8 +229,7 @@ async def test_get_view_names_with_schema(self):

self.assertNotIn(self.table.name, db_views)

@run_async
async def test_reflecttable(self):
async def test_reflect_table(self):
await self.run_sync(self.test_metadata.create_all)
meta = self.metadata()

Expand All @@ -247,8 +238,7 @@ async def test_reflecttable(self):

self.assertEqual(self.table.name, reflected_table.name)

@run_async
async def test_reflecttable_with_schema(self):
async def test_reflect_table_with_schema(self):
# Imitates calling sequence for clients like Superset that look
# across schemas.
meta = self.metadata()
Expand All @@ -260,17 +250,15 @@ async def test_reflecttable_with_schema(self):
if self.server_version >= (18, 16, 0):
self.assertIsNone(reflected_table.engine)

@run_async
async def test_get_schema_names(self):
schemas = await self.run_inspector_method('get_schema_names')
self.assertIn(test_database, schemas)

def test_columns_compilation(self):
async def test_columns_compilation(self):
# should not raise UnsupportedCompilationError
col = Column('x', types.Nullable(types.Int32))
self.assertEqual(str(col.type), 'Nullable(Int32)')

@run_async
@require_server_version(19, 16, 2, is_async=True)
async def test_empty_set_expr(self):
numbers = Table(
Expand Down Expand Up @@ -310,7 +298,8 @@ def test_server_version_http(self):
def test_server_version_native(self):
return self._test_server_version_any(system_native_uri)

@run_async

class CachedServerVersionAsyncTestCase(BaseAsynchTestCase):
async def test_server_version_asynch(self):
engine_session = make_session(create_async_engine(
system_asynch_uri,
Expand Down
39 changes: 11 additions & 28 deletions tests/testcase.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from contextlib import contextmanager, asynccontextmanager
from contextlib import contextmanager
from unittest import TestCase

from sqlalchemy import MetaData, text
Expand Down Expand Up @@ -97,46 +97,29 @@ class BaseAsynchTestCase(BaseTestCase):
session = asynch_session

@classmethod
@run_async
async def setUpClass(cls):
def setUpClass(cls):
# System database is always present.
await system_asynch_session.execute(
run_async(system_asynch_session.execute)(
text('DROP DATABASE IF EXISTS {}'.format(cls.database))
)
await system_asynch_session.execute(
run_async(system_asynch_session.execute)(
text('CREATE DATABASE {}'.format(cls.database))
)

version = (
await system_asynch_session.execute(text('SELECT version()'))
run_async(system_asynch_session.execute)(text('SELECT version()'))
).fetchall()
cls.server_version = tuple(int(x) for x in version[0][0].split('.'))

super(BaseTestCase, cls).setUpClass()

@asynccontextmanager
async def create_table(self, table):
metadata = self.metadata()
await self.run_sync(metadata.drop_all)
await self.run_sync(metadata.create_all)

try:
yield
finally:
await self.run_sync(metadata.drop_all)
def setUp(self):
self.connection = run_async(self.session.connection)()
super(BaseAsynchTestCase, self).setUp()

async def get_connection(self):
return await self.session.connection()
def _callTestMethod(self, method):
return run_async(method)()

async def run_sync(self, f):
conn = await self.get_connection()
return await conn.run_sync(f)

async def session_scalar(self, statement):
def wrapper(session):
return session.query(statement).scalar()

return await self.session.run_sync(wrapper)
return await self.connection.run_sync(f)


class HttpSessionTestCase(BaseTestCase):
Expand Down
1 change: 1 addition & 0 deletions testsrequire.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

tests_require = [
'pytest',
'pytest-asyncio',
'sqlalchemy>=2.0.0,<2.1.0',
'greenlet>=2.0.1',
'alembic',
Expand Down

0 comments on commit 0936a7e

Please sign in to comment.