diff --git a/connectorx-python/connectorx/__init__.py b/connectorx-python/connectorx/__init__.py index c1d5f6c76..368f052dd 100644 --- a/connectorx-python/connectorx/__init__.py +++ b/connectorx-python/connectorx/__init__.py @@ -2,8 +2,10 @@ import importlib -from importlib.metadata import version +import urllib.parse +from importlib.metadata import version +from pathlib import Path from typing import Any, Literal, TYPE_CHECKING, overload, Generic, TypeVar from .connectorx import ( @@ -488,15 +490,14 @@ def try_import_module(name: str): raise ValueError(f"You need to install {name.split('.')[0]} first") -_BackendWithoutSqliteT = TypeVar( - "_BackendWithoutSqliteT", +_ServerBackendT = TypeVar( + "_ServerBackendT", bound=Literal[ "postgres", "postgresql", "mysql", "mssql", "oracle", - "bigquery", "duckdb", ], ) @@ -517,33 +518,46 @@ def __new__( def __new__( cls, *, - backend: _BackendWithoutSqliteT, + backend: Literal["bigquery"], + db_path: str | Path, + ) -> Connection[Literal["bigquery"]]: ... + + @overload + def __new__( + cls, + *, + backend: _ServerBackendT, username: str, password: str = "", server: str, port: int, database: str = "", - ) -> Connection[_BackendWithoutSqliteT]: ... + database_options: dict[str, str] | None = None, + ) -> Connection[_ServerBackendT]: ... def __new__( cls, *, - backend: Literal["sqlite"] | _BackendWithoutSqliteT, + backend: str, username: str = "", password: str = "", server: str = "", port: int | None = None, database: str = "", - db_path: str = "", - ) -> Connection[Literal["sqlite"]] | Connection[_BackendWithoutSqliteT]: + database_options: dict[str, str] | None = None, + db_path: str | Path = "", + ) -> Connection: self = super().__new__(cls) if backend == "sqlite": + db_path = urllib.parse.quote(str(db_path)) self.connection = f"{backend}://{db_path}" else: self.connection = ( f"{backend}://{username}:{password}@{server}:{port}/{database}" ) - return self # type: ignore + if database_options: + self.connection += "?" + urllib.parse.urlencode(database_options) + return self def __str__(self) -> str: return self.connection