Skip to content

Commit

Permalink
refactor(python): Replace copy/paste import handling with `import_optional` utility function (#15906)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Apr 26, 2024
1 parent e541e96 commit 131354c
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 132 deletions.
19 changes: 5 additions & 14 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3033,19 +3033,10 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]:
return catalog, schema, tbl # type: ignore[return-value]

if engine == "adbc":
try:
import adbc_driver_manager

adbc_version = parse_version(
getattr(adbc_driver_manager, "__version__", "0.0")
)
except ModuleNotFoundError as exc:
msg = (
"adbc_driver_manager not found"
"\n\nInstall Polars with: pip install adbc_driver_manager"
)
raise ModuleNotFoundError(msg) from exc

adbc_driver_manager = import_optional("adbc_driver_manager")
adbc_version = parse_version(
getattr(adbc_driver_manager, "__version__", "0.0")
)
from polars.io.database._utils import _open_adbc_connection

if if_table_exists == "fail":
Expand Down Expand Up @@ -3110,7 +3101,7 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]:
try:
from sqlalchemy import create_engine
except ModuleNotFoundError as exc:
msg = "sqlalchemy not found\n\nInstall with: pip install polars[sqlalchemy]"
msg = "'sqlalchemy' not found\n\nInstall with: pip install polars[sqlalchemy]"
raise ModuleNotFoundError(msg) from exc

# note: the catalog (database) should be a part of the connection string
Expand Down
30 changes: 22 additions & 8 deletions py-polars/polars/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,10 @@ def _check_for_pydantic(obj: Any, *, check_type: bool = True) -> bool:

def import_optional(
module_name: str,
err_prefix: str = "Required package",
err_suffix: str = "not installed",
err_prefix: str = "required package",
err_suffix: str = "not found",
min_version: str | tuple[int, ...] | None = None,
install_message: str | None = None,
) -> Any:
"""
Import an optional dependency, returning the module.
Expand All @@ -243,26 +244,39 @@ def import_optional(
Error suffix to use in the raised exception (follows the module name).
min_version : {str, tuple[int]}, optional
If a minimum module version is required, specify it here.
install_message : str, optional
Override the standard "Please install it using..." exception message fragment.
Examples
--------
>>> from polars.dependencies import import_optional
>>> import_optional(
... "definitely_a_real_module",
... err_prefix="super-important package",
... ) # doctest: +SKIP
ModuleNotFoundError: super-important package 'definitely_a_real_module' not found.
Please install it using the command `pip install definitely_a_real_module`.
"""
from polars._utils.various import parse_version
from polars.exceptions import ModuleUpgradeRequired

module_root = module_name.split(".", 1)[0]
try:
module = import_module(module_name)
except ImportError:
prefix = f"{err_prefix.strip(' ')} " if err_prefix else ""
suffix = f" {err_prefix.strip(' ')}" if err_suffix else ""
err_message = (
f"{prefix}'{module_name}'{suffix}.\n"
f"Please install it using the command `pip install {module_name}`."
suffix = f" {err_suffix.strip(' ')}" if err_suffix else ""
err_message = f"{prefix}'{module_name}'{suffix}.\n" + (
install_message
or f"Please install it using the command `pip install {module_root}`."
)
raise ImportError(err_message) from None
raise ModuleNotFoundError(err_message) from None

if min_version:
min_version = parse_version(min_version)
mod_version = parse_version(module.__version__)
if mod_version < min_version:
msg = f"requires module_name {min_version} or higher, found {mod_version}"
msg = f"requires {module_root} {min_version} or higher; found {mod_version}"
raise ModuleUpgradeRequired(msg)

return module
Expand Down
49 changes: 17 additions & 32 deletions py-polars/polars/io/database/_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

import re
import sys
from importlib import import_module
from typing import TYPE_CHECKING, Any

from polars.convert import from_arrow
from polars.dependencies import import_optional

if TYPE_CHECKING:
import sys
from collections.abc import Coroutine

if sys.version_info >= (3, 10):
Expand All @@ -29,22 +29,14 @@ def _run_async(co: Coroutine[Any, Any, Any]) -> Any:
import asyncio

from polars._utils.unstable import issue_unstable_warning
from polars.dependencies import import_optional

issue_unstable_warning(
"Use of asynchronous connections is currently considered unstable"
" and unexpected issues may arise; if this happens, please report them."
"Use of asynchronous connections is currently considered unstable "
"and unexpected issues may arise; if this happens, please report them."
)
try:
import nest_asyncio

nest_asyncio.apply()
except ModuleNotFoundError as _err:
msg = (
"Executing using async drivers requires the `nest_asyncio` package."
"\n\nPlease run: pip install nest_asyncio"
)
raise ModuleNotFoundError(msg) from None

nest_asyncio = import_optional("nest_asyncio")
nest_asyncio.apply()
return asyncio.run(co)


Expand All @@ -57,12 +49,7 @@ def _read_sql_connectorx(
protocol: str | None = None,
schema_overrides: SchemaDict | None = None,
) -> DataFrame:
try:
import connectorx as cx
except ModuleNotFoundError:
msg = "connectorx is not installed" "\n\nPlease run: pip install connectorx"
raise ModuleNotFoundError(msg) from None

cx = import_optional("connectorx")
try:
tbl = cx.read_sql(
conn=connection_uri,
Expand Down Expand Up @@ -100,17 +87,15 @@ def _open_adbc_connection(connection_uri: str) -> Any:
module_suffix_map: dict[str, str] = {
"postgres": "postgresql",
}
try:
module_suffix = module_suffix_map.get(driver_name, driver_name)
module_name = f"adbc_driver_{module_suffix}.dbapi"
import_module(module_name)
adbc_driver = sys.modules[module_name]
except ImportError:
msg = (
f"ADBC {driver_name} driver not detected"
f"\n\nIf ADBC supports this database, please run: pip install adbc-driver-{driver_name} pyarrow"
)
raise ModuleNotFoundError(msg) from None
module_suffix = module_suffix_map.get(driver_name, driver_name)
module_name = f"adbc_driver_{module_suffix}.dbapi"

adbc_driver = import_optional(
module_name,
err_prefix="ADBC",
err_suffix="driver not detected",
install_message=f"If ADBC supports this database, please run: pip install adbc-driver-{driver_name} pyarrow",
)

# some backends require the driver name to be stripped from the URI
if driver_name in ("sqlite", "snowflake"):
Expand Down
15 changes: 6 additions & 9 deletions py-polars/polars/io/database/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from polars._utils.deprecation import issue_deprecation_warning
from polars.datatypes import N_INFER_DEFAULT
from polars.dependencies import import_optional
from polars.exceptions import InvalidOperationError
from polars.io.database._cursor_proxies import ODBCCursorProxy
from polars.io.database._executor import ConnectionExecutor
Expand Down Expand Up @@ -224,15 +225,11 @@ def read_database( # noqa: D417
if isinstance(connection, str):
# check for odbc connection string
if re.search(r"\bdriver\s*=\s*{[^}]+?}", connection, re.IGNORECASE):
try:
import arrow_odbc # noqa: F401
except ModuleNotFoundError:
msg = (
"use of an ODBC connection string requires the `arrow-odbc` package"
"\n\nPlease run: pip install arrow-odbc"
)
raise ModuleNotFoundError(msg) from None

_ = import_optional(
module_name="arrow_odbc",
err_prefix="use of ODBC connection string requires the",
err_suffix="package",
)
connection = ODBCCursorProxy(connection)
else:
# otherwise looks like a call to read_database_uri
Expand Down
36 changes: 17 additions & 19 deletions py-polars/polars/io/ipc/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
normalize_filepath,
)
from polars._utils.wrap import wrap_df, wrap_ldf
from polars.dependencies import _PYARROW_AVAILABLE
from polars.dependencies import import_optional
from polars.io._utils import (
is_glob_pattern,
is_local_file,
Expand Down Expand Up @@ -98,14 +98,16 @@ def read_ipc(
source, use_pyarrow=use_pyarrow, storage_options=storage_options
) as data:
if use_pyarrow:
if not _PYARROW_AVAILABLE:
msg = "pyarrow is required when using `read_ipc(..., use_pyarrow=True)`"
raise ModuleNotFoundError(msg)

import pyarrow as pa
import pyarrow.feather

tbl = pa.feather.read_table(data, memory_map=memory_map, columns=columns)
pyarrow_feather = import_optional(
"pyarrow.feather",
err_prefix="",
err_suffix="is required when using 'read_ipc(..., use_pyarrow=True)'",
)
tbl = pyarrow_feather.read_table(
data,
memory_map=memory_map,
columns=columns,
)
df = pl.DataFrame._from_arrow(tbl, rechunk=rechunk)
if row_index_name is not None:
df = df.with_row_index(row_index_name, row_index_offset)
Expand Down Expand Up @@ -225,16 +227,12 @@ def read_ipc_stream(
source, use_pyarrow=use_pyarrow, storage_options=storage_options
) as data:
if use_pyarrow:
if not _PYARROW_AVAILABLE:
msg = (
"'pyarrow' is required when using"
" 'read_ipc_stream(..., use_pyarrow=True)'"
)
raise ModuleNotFoundError(msg)

import pyarrow as pa

with pa.ipc.RecordBatchStreamReader(data) as reader:
pyarrow_ipc = import_optional(
"pyarrow.ipc",
err_prefix="",
err_suffix="is required when using 'read_ipc_stream(..., use_pyarrow=True)'",
)
with pyarrow_ipc.RecordBatchStreamReader(data) as reader:
tbl = reader.read_all()
df = pl.DataFrame._from_arrow(tbl, rechunk=rechunk)
if row_index_name is not None:
Expand Down
16 changes: 7 additions & 9 deletions py-polars/polars/io/parquet/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from polars._utils.wrap import wrap_df, wrap_ldf
from polars.convert import from_arrow
from polars.dependencies import _PYARROW_AVAILABLE
from polars.dependencies import import_optional
from polars.io._utils import (
is_local_file,
is_supported_cloud,
Expand Down Expand Up @@ -209,21 +209,19 @@ def _read_parquet_with_pyarrow(
pyarrow_options: dict[str, Any] | None = None,
memory_map: bool = True,
) -> DataFrame:
if not _PYARROW_AVAILABLE:
msg = "'pyarrow' is required when using `read_parquet(..., use_pyarrow=True)`"
raise ModuleNotFoundError(msg)

import pyarrow as pa
import pyarrow.parquet

pyarrow_parquet = import_optional(
"pyarrow.parquet",
err_prefix="",
err_suffix="is required when using `read_parquet(..., use_pyarrow=True)`",
)
pyarrow_options = pyarrow_options or {}

with prepare_file_arg(
source, # type: ignore[arg-type]
use_pyarrow=True,
storage_options=storage_options,
) as source_prep:
pa_table = pa.parquet.read_table(
pa_table = pyarrow_parquet.read_table(
source_prep,
memory_map=memory_map,
columns=columns,
Expand Down
Loading

0 comments on commit 131354c

Please sign in to comment.