Skip to content

Commit

Permalink
Merge pull request #89 from yassun7010/fix_recorder
Browse files Browse the repository at this point in the history
fix: recorder
  • Loading branch information
yassun7010 authored Feb 16, 2024
2 parents 48797cb + 6501d33 commit ca4c05e
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 97 deletions.
34 changes: 20 additions & 14 deletions turu-core/src/turu/core/record/async_record_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,35 +23,39 @@ def __init__(
],
):
self._recorder = recorder
self._cursor: turu.core.async_cursor.AsyncCursor = cursor
self.__record_taregt_cursor: turu.core.async_cursor.AsyncCursor = cursor

@property
def rowcount(self) -> int:
return self._cursor.rowcount
return self.__record_taregt_cursor.rowcount

@property
def arraysize(self) -> int:
return self._cursor.arraysize
return self.__record_taregt_cursor.arraysize

@arraysize.setter
def arraysize(self, size: int) -> None:
self._cursor.arraysize = size
self.__record_taregt_cursor.arraysize = size

async def close(self) -> None:
await self._cursor.close()
await self.__record_taregt_cursor.close()
self._recorder.close()

async def execute(
self, operation: str, parameters: Optional[Parameters] = None, /
) -> "AsyncRecordCursor[turu.core.cursor.GenericRowType, Parameters]":
self._cursor = await self._cursor.execute(operation, parameters)
self.__record_taregt_cursor = await self.__record_taregt_cursor.execute(
operation, parameters
)

return self

async def executemany(
self, operation: str, seq_of_parameters: "Sequence[Parameters]", /
) -> "AsyncRecordCursor[turu.core.cursor.GenericRowType, Parameters]":
self._cursor = await self._cursor.executemany(operation, seq_of_parameters)
self.__record_taregt_cursor = await self.__record_taregt_cursor.executemany(
operation, seq_of_parameters
)

return self

Expand All @@ -62,7 +66,9 @@ async def execute_map(
parameters: Optional[Parameters] = None,
/,
) -> "AsyncRecordCursor[turu.core.cursor.GenericNewRowType, Parameters]":
self._cursor = await self._cursor.execute_map(row_type, operation, parameters)
self.__record_taregt_cursor = await self.__record_taregt_cursor.execute_map(
row_type, operation, parameters
)

return cast(AsyncRecordCursor, self)

Expand All @@ -73,14 +79,14 @@ async def executemany_map(
seq_of_parameters: Sequence[Parameters],
/,
) -> "AsyncRecordCursor[turu.core.cursor.GenericNewRowType, Parameters]":
self._cursor = await self._cursor.executemany_map(
self.__record_taregt_cursor = await self.__record_taregt_cursor.executemany_map(
row_type, operation, seq_of_parameters
)

return cast(AsyncRecordCursor, self)

async def fetchone(self) -> Optional[turu.core.cursor.GenericRowType]:
row = await self._cursor.fetchone()
row = await self.__record_taregt_cursor.fetchone()
if row is not None:
self._recorder.record([row])

Expand All @@ -89,14 +95,14 @@ async def fetchone(self) -> Optional[turu.core.cursor.GenericRowType]:
async def fetchmany(
self, size: Optional[int] = None
) -> List[turu.core.cursor.GenericRowType]:
rows = await self._cursor.fetchmany(size)
rows = await self.__record_taregt_cursor.fetchmany(size)

self._recorder.record(rows)

return rows

async def fetchall(self) -> List[turu.core.cursor.GenericRowType]:
rows = await self._cursor.fetchall()
rows = await self.__record_taregt_cursor.fetchall()

self._recorder.record(rows)

Expand All @@ -108,7 +114,7 @@ def __iter__(
return self

async def __anext__(self) -> turu.core.cursor.GenericRowType:
row = await self._cursor.__anext__()
row = await self.__record_taregt_cursor.__anext__()

self._recorder.record([row])

Expand All @@ -118,4 +124,4 @@ def __getattr__(self, name):
def _method_missing(*args):
return args

return getattr(self._cursor, name, _method_missing)
return getattr(self.__record_taregt_cursor, name, _method_missing)
34 changes: 20 additions & 14 deletions turu-core/src/turu/core/record/record_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,35 +20,39 @@ def __init__(
cursor: turu.core.cursor.Cursor[turu.core.cursor.GenericRowType, Parameters],
):
self._recorder = recorder
self._cursor: turu.core.cursor.Cursor = cursor
self.__record_taregt_cursor: turu.core.cursor.Cursor = cursor

@property
def rowcount(self) -> int:
return self._cursor.rowcount
return self.__record_taregt_cursor.rowcount

@property
def arraysize(self) -> int:
return self._cursor.arraysize
return self.__record_taregt_cursor.arraysize

@arraysize.setter
def arraysize(self, size: int) -> None:
self._cursor.arraysize = size
self.__record_taregt_cursor.arraysize = size

def close(self) -> None:
self._cursor.close()
self.__record_taregt_cursor.close()
self._recorder.close()

def execute(
self, operation: str, parameters: Optional[Parameters] = None, /
) -> "RecordCursor[turu.core.cursor.GenericRowType, Parameters]":
self._cursor = self._cursor.execute(operation, parameters)
self.__record_taregt_cursor = self.__record_taregt_cursor.execute(
operation, parameters
)

return self

def executemany(
self, operation: str, seq_of_parameters: "Sequence[Parameters]", /
) -> "RecordCursor[turu.core.cursor.GenericRowType, Parameters]":
self._cursor = self._cursor.executemany(operation, seq_of_parameters)
self.__record_taregt_cursor = self.__record_taregt_cursor.executemany(
operation, seq_of_parameters
)

return self

Expand All @@ -59,7 +63,9 @@ def execute_map(
parameters: Optional[Parameters] = None,
/,
) -> "RecordCursor[turu.core.cursor.GenericNewRowType, Parameters]":
self._cursor = self._cursor.execute_map(row_type, operation, parameters)
self.__record_taregt_cursor = self.__record_taregt_cursor.execute_map(
row_type, operation, parameters
)

return cast(RecordCursor, self)

Expand All @@ -70,14 +76,14 @@ def executemany_map(
seq_of_parameters: Sequence[Parameters],
/,
) -> "RecordCursor[turu.core.cursor.GenericNewRowType, Parameters]":
self._cursor = self._cursor.executemany_map(
self.__record_taregt_cursor = self.__record_taregt_cursor.executemany_map(
row_type, operation, seq_of_parameters
)

return cast(RecordCursor, self)

def fetchone(self) -> Optional[turu.core.cursor.GenericRowType]:
row = self._cursor.fetchone()
row = self.__record_taregt_cursor.fetchone()
if row is not None:
self._recorder.record([row])

Expand All @@ -86,14 +92,14 @@ def fetchone(self) -> Optional[turu.core.cursor.GenericRowType]:
def fetchmany(
self, size: Optional[int] = None
) -> List[turu.core.cursor.GenericRowType]:
rows = self._cursor.fetchmany(size)
rows = self.__record_taregt_cursor.fetchmany(size)

self._recorder.record(rows)

return rows

def fetchall(self) -> List[turu.core.cursor.GenericRowType]:
rows = self._cursor.fetchall()
rows = self.__record_taregt_cursor.fetchall()

self._recorder.record(rows)

Expand All @@ -103,7 +109,7 @@ def __iter__(self) -> "RecordCursor[turu.core.cursor.GenericRowType, Parameters]
return self

def __next__(self) -> turu.core.cursor.GenericRowType:
row = next(self._cursor)
row = next(self.__record_taregt_cursor)

self._recorder.record([row])

Expand All @@ -113,4 +119,4 @@ def __getattr__(self, name):
def _method_missing(*args):
return args

return getattr(self._cursor, name, _method_missing)
return getattr(self.__record_taregt_cursor, name, _method_missing)
42 changes: 8 additions & 34 deletions turu-snowflake/src/turu/snowflake/record/async_record_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class AsyncRecordCursor( # type: ignore[override]
],
):
async def fetch_pandas_all(self, **kwargs) -> GenericPandasDataFrame:
df = await self._sf_cursor.fetch_pandas_all(**kwargs)
df = await self.__sf_cursor.fetch_pandas_all(**kwargs)

if isinstance(self._recorder, turu.core.record.CsvRecorder):
if limit := self._recorder._options.get("limit"):
Expand All @@ -35,7 +35,7 @@ async def fetch_pandas_all(self, **kwargs) -> GenericPandasDataFrame:
async def fetch_pandas_batches(
self, **kwargs
) -> AsyncIterator[GenericPandasDataFrame]:
batches = self._sf_cursor.fetch_pandas_batches(**kwargs)
batches = self.__sf_cursor.fetch_pandas_batches(**kwargs)

if isinstance(self._recorder, turu.core.record.CsvRecorder):
if limit := self._recorder._options.get("limit"):
Expand All @@ -50,7 +50,7 @@ async def fetch_pandas_batches(
yield batch

async def fetch_arrow_all(self) -> GenericPyArrowTable:
table = await self._sf_cursor.fetch_arrow_all()
table = await self.__sf_cursor.fetch_arrow_all()

if isinstance(self._recorder, turu.core.record.CsvRecorder):
if limit := self._recorder._options.get("limit"):
Expand All @@ -65,7 +65,7 @@ async def fetch_arrow_all(self) -> GenericPyArrowTable:
return table

async def fetch_arrow_batches(self) -> AsyncIterator[GenericPyArrowTable]:
batches = self._sf_cursor.fetch_arrow_batches()
batches = self.__sf_cursor.fetch_arrow_batches()

if isinstance(self._recorder, turu.core.record.CsvRecorder):
if limit := self._recorder._options.get("limit"):
Expand All @@ -79,38 +79,12 @@ async def fetch_arrow_batches(self) -> AsyncIterator[GenericPyArrowTable]:
async for batch in batches:
yield batch

def use_warehouse(self, warehouse: str, /) -> "AsyncRecordCursor":
"""Use a warehouse in cursor."""

self._sf_cursor.use_warehouse(warehouse)

return self

def use_database(self, database: str, /) -> "AsyncRecordCursor":
"""Use a database in cursor."""

self._sf_cursor.use_database(database)

return self

def use_schema(self, schema: str, /) -> "AsyncRecordCursor":
"""Use a schema in cursor."""

self._sf_cursor.use_schema(schema)

return self

def use_role(self, role: str, /) -> "AsyncRecordCursor":
"""Use a role in cursor."""

self._sf_cursor.use_role(role)

return self

@property
def _sf_cursor(
def __sf_cursor(
self,
) -> turu.snowflake.async_cursor.AsyncCursor[
GenericRowType, GenericPandasDataFrame, GenericPyArrowTable
]:
return cast(turu.snowflake.async_cursor.AsyncCursor, self._cursor)
return cast(
turu.snowflake.async_cursor.AsyncCursor, self.__record_taregt_cursor
)
41 changes: 6 additions & 35 deletions turu-snowflake/src/turu/snowflake/record/record_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@
GenericPandasDataFrame,
GenericPyArrowTable,
)
from typing_extensions import Self


class RecordCursor( # type: ignore[override]
turu.core.record.RecordCursor,
Generic[GenericRowType, GenericPandasDataFrame, GenericPyArrowTable],
):
def fetch_pandas_all(self, **kwargs) -> GenericPandasDataFrame:
df = self._sf_cursor.fetch_pandas_all(**kwargs)
df = self.__sf_cursor.fetch_pandas_all(**kwargs)

if isinstance(self._recorder, turu.core.record.CsvRecorder):
if limit := self._recorder._options.get("limit"):
Expand All @@ -30,7 +29,7 @@ def fetch_pandas_all(self, **kwargs) -> GenericPandasDataFrame:
return df

def fetch_pandas_batches(self, **kwargs) -> Iterator[GenericPandasDataFrame]:
batches = self._sf_cursor.fetch_pandas_batches(**kwargs)
batches = self.__sf_cursor.fetch_pandas_batches(**kwargs)

if isinstance(self._recorder, turu.core.record.CsvRecorder):
if limit := self._recorder._options.get("limit"):
Expand All @@ -44,7 +43,7 @@ def fetch_pandas_batches(self, **kwargs) -> Iterator[GenericPandasDataFrame]:
return batches

def fetch_arrow_all(self) -> GenericPyArrowTable:
table = self._sf_cursor.fetch_arrow_all()
table = self.__sf_cursor.fetch_arrow_all()

if isinstance(self._recorder, turu.core.record.CsvRecorder):
if limit := self._recorder._options.get("limit"):
Expand All @@ -59,7 +58,7 @@ def fetch_arrow_all(self) -> GenericPyArrowTable:
return table

def fetch_arrow_batches(self) -> Iterator[GenericPyArrowTable]:
batches = self._sf_cursor.fetch_arrow_batches()
batches = self.__sf_cursor.fetch_arrow_batches()

if isinstance(self._recorder, turu.core.record.CsvRecorder):
if limit := self._recorder._options.get("limit"):
Expand All @@ -72,38 +71,10 @@ def fetch_arrow_batches(self) -> Iterator[GenericPyArrowTable]:

return batches

def use_warehouse(self, warehouse: str, /) -> Self:
"""Use a warehouse in cursor."""

self._sf_cursor.use_warehouse(warehouse)

return self

def use_database(self, database: str, /) -> Self:
"""Use a database in cursor."""

self._sf_cursor.use_database(database)

return self

def use_schema(self, schema: str, /) -> Self:
"""Use a schema in cursor."""

self._sf_cursor.use_schema(schema)

return self

def use_role(self, role: str, /) -> Self:
"""Use a role in cursor."""

self._sf_cursor.use_role(role)

return self

@property
def _sf_cursor(
def __sf_cursor(
self,
) -> turu.snowflake.cursor.Cursor[
GenericRowType, GenericPandasDataFrame, GenericPyArrowTable
]:
return cast(turu.snowflake.cursor.Cursor, self._cursor)
return cast(turu.snowflake.cursor.Cursor, self.__record_taregt_cursor)

0 comments on commit ca4c05e

Please sign in to comment.