Skip to content

Commit

Permalink
Merge pull request #612 from zen-xu/fix-read-sql-annotation
Browse files (browse the repository at this point in the history)
fix: correct read_sql return type annotation
  • Loading branch information
wangxiaoying authored Apr 19, 2024
2 parents c5d79ad + 7158e80 commit 832cc84
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 15 deletions.
13 changes: 7 additions & 6 deletions connectorx-python/connectorx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from __future__ import annotations


import importlib
from importlib.metadata import version

from typing import Any, Literal, TYPE_CHECKING, overload
from typing import Literal, TYPE_CHECKING, overload

from .connectorx import (
read_sql as _read_sql,
Expand All @@ -20,6 +19,10 @@
import dask.dataframe as dd
import pyarrow as pa

# only for typing hints
from .connectorx import _DataframeInfos, _ArrowInfos


__version__ = version(__name__)

import os
Expand Down Expand Up @@ -394,9 +397,7 @@ def read_sql(
return df


def reconstruct_arrow(
result: tuple[list[str], list[list[tuple[int, int]]]],
) -> pa.Table:
def reconstruct_arrow(result: _ArrowInfos) -> pa.Table:
import pyarrow as pa

names, ptrs = result
Expand All @@ -412,7 +413,7 @@ def reconstruct_arrow(
return pa.Table.from_batches(rbs)


def reconstruct_pandas(df_infos: dict[str, Any]) -> pd.DataFrame:
def reconstruct_pandas(df_infos: _DataframeInfos) -> pd.DataFrame:
import pandas as pd

data = df_infos["data"]
Expand Down
27 changes: 18 additions & 9 deletions connectorx-python/connectorx/connectorx.pyi
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
from __future__ import annotations

from typing import overload, Literal, Any, TypeAlias
import pandas as pd
from typing import overload, Literal, Any, TypeAlias, TypedDict
import numpy as np

_ArrowArrayPtr: TypeAlias = int
_ArrowSchemaPtr: TypeAlias = int
_Column: TypeAlias = str
_Header: TypeAlias = str

# Stub for the per-block descriptor objects listed under the "block_infos"
# key of _DataframeInfos.
# NOTE(review): presumably `cids` holds the column ids that belong to the
# block and `dt` is an integer dtype code — confirm against the native module.
class PandasBlockInfo:
cids: list[int]
dt: int

# Typed shape of the dict accepted by reconstruct_pandas() and returned by the
# first read_sql overload in this stub.
# Keys: "data" (numpy arrays, or tuples of numpy arrays), "headers" (list of
# header strings) and "block_infos" (list of PandasBlockInfo).
# NOTE(review): presumably the three lists are index-aligned per block —
# confirm against reconstruct_pandas.
class _DataframeInfos(TypedDict):
data: list[tuple[np.ndarray, ...] | np.ndarray]
headers: list[_Header]
block_infos: list[PandasBlockInfo]

_ArrowInfos = tuple[list[_Header], list[list[tuple[_ArrowArrayPtr, _ArrowSchemaPtr]]]]

@overload
def read_sql(
Expand All @@ -14,21 +25,19 @@ def read_sql(
protocol: str | None,
queries: list[str] | None,
partition_query: dict[str, Any] | None,
) -> pd.DataFrame: ...
) -> _DataframeInfos: ...
@overload
def read_sql(
conn: str,
return_type: Literal["arrow", "arrow2"],
protocol: str | None,
queries: list[str] | None,
partition_query: dict[str, Any] | None,
) -> tuple[list[_Column], list[list[tuple[_ArrowArrayPtr, _ArrowSchemaPtr]]]]: ...
) -> _ArrowInfos: ...
# Stub: takes a connection string and a partition-query dict and returns a
# list of strings.
# NOTE(review): presumably the returned strings are the per-partition SQL
# queries — confirm against the native implementation.
def partition_sql(conn: str, partition_query: dict[str, Any]) -> list[str]: ...
def read_sql2(
sql: str, db_map: dict[str, str]
) -> tuple[list[_Column], list[list[tuple[_ArrowArrayPtr, _ArrowSchemaPtr]]]]: ...
def read_sql2(sql: str, db_map: dict[str, str]) -> _ArrowInfos: ...
def get_meta(
conn: str,
protocol: Literal["csv", "binary", "cursor", "simple", "text"] | None,
query: str,
) -> dict[str, Any]: ...
) -> _DataframeInfos: ...

0 comments on commit 832cc84

Please sign in to comment.