From 00cdb0d2b0e0e9dba11d608cddb7d1081155da8b Mon Sep 17 00:00:00 2001 From: Vladimir Rudnyh Date: Mon, 28 Oct 2024 19:43:46 +0700 Subject: [PATCH] Finish SQL functions refactoring --- examples/computer_vision/openimage-detect.py | 4 +- examples/get_started/common_sql_functions.py | 9 +- examples/multimodal/clip_inference.py | 2 +- examples/multimodal/wds.py | 2 +- examples/multimodal/wds_filtered.py | 15 +- src/datachain/catalog/catalog.py | 14 +- src/datachain/cli.py | 2 +- src/datachain/client/fsspec.py | 19 +- src/datachain/data_storage/schema.py | 2 +- src/datachain/data_storage/warehouse.py | 2 +- src/datachain/lib/dc.py | 71 ++++--- src/datachain/lib/func/__init__.py | 16 ++ src/datachain/lib/func/aggregate.py | 25 +-- src/datachain/lib/func/array.py | 175 ++++++++++++++++ src/datachain/lib/func/conditional.py | 80 ++++++++ src/datachain/lib/func/func.py | 189 ++++++++++++++++-- src/datachain/lib/func/inner/__init__.py | 0 .../functions => lib/func/inner}/aggregate.py | 0 .../functions => lib/func/inner}/array.py | 0 src/datachain/lib/func/inner/base.py | 20 ++ .../func/inner}/conditional.py | 0 .../{sql/functions => lib/func/inner}/path.py | 0 .../functions => lib/func/inner}/random.py | 0 .../functions => lib/func/inner}/string.py | 0 src/datachain/lib/func/path.py | 109 ++++++++++ src/datachain/lib/func/random.py | 22 ++ src/datachain/lib/func/string.py | 153 ++++++++++++++ src/datachain/lib/listing.py | 2 +- src/datachain/listing.py | 2 +- src/datachain/nodes_fetcher.py | 4 +- src/datachain/query/dataset.py | 50 +++-- src/datachain/sql/__init__.py | 2 - src/datachain/sql/functions/__init__.py | 26 --- src/datachain/sql/functions2/__init__.py | 16 ++ src/datachain/sql/functions2/aggregate.py | 47 +++++ src/datachain/sql/functions2/array.py | 50 +++++ src/datachain/sql/functions2/base.py | 18 ++ src/datachain/sql/functions2/conditional.py | 9 + src/datachain/sql/functions2/path.py | 61 ++++++ src/datachain/sql/functions2/random.py | 12 ++ src/datachain/sql/functions2/string.py | 54 +++++ src/datachain/sql/selectable.py | 18 +- src/datachain/sql/sqlite/base.py | 4 +- tests/func/test_datachain.py | 11 +- tests/func/test_dataset_query.py | 2 +- tests/func/test_datasets.py | 2 +- tests/func/test_pull.py | 2 +- tests/unit/lib/test_datachain.py | 18 +- tests/unit/lib/test_sql_to_python.py | 3 - tests/unit/sql/test_array.py | 65 +++++- tests/unit/sql/test_conditional.py | 35 +++- tests/unit/sql/test_path.py | 21 +- tests/unit/sql/test_random.py | 4 +- tests/unit/sql/test_string.py | 4 +- tests/unit/test_session.py | 3 +- 55 files changed, 1268 insertions(+), 208 deletions(-) create mode 100644 src/datachain/lib/func/array.py create mode 100644 src/datachain/lib/func/conditional.py create mode 100644 src/datachain/lib/func/inner/__init__.py rename src/datachain/{sql/functions => lib/func/inner}/aggregate.py (100%) rename src/datachain/{sql/functions => lib/func/inner}/array.py (100%) create mode 100644 src/datachain/lib/func/inner/base.py rename src/datachain/{sql/functions => lib/func/inner}/conditional.py (100%) rename src/datachain/{sql/functions => lib/func/inner}/path.py (100%) rename src/datachain/{sql/functions => lib/func/inner}/random.py (100%) rename src/datachain/{sql/functions => lib/func/inner}/string.py (100%) create mode 100644 src/datachain/lib/func/path.py create mode 100644 src/datachain/lib/func/random.py create mode 100644 src/datachain/lib/func/string.py delete mode 100644 src/datachain/sql/functions/__init__.py create mode 100644 
src/datachain/sql/functions2/__init__.py create mode 100644 src/datachain/sql/functions2/aggregate.py create mode 100644 src/datachain/sql/functions2/array.py create mode 100644 src/datachain/sql/functions2/base.py create mode 100644 src/datachain/sql/functions2/conditional.py create mode 100644 src/datachain/sql/functions2/path.py create mode 100644 src/datachain/sql/functions2/random.py create mode 100644 src/datachain/sql/functions2/string.py diff --git a/examples/computer_vision/openimage-detect.py b/examples/computer_vision/openimage-detect.py index 066641fd0..1e28a9910 100644 --- a/examples/computer_vision/openimage-detect.py +++ b/examples/computer_vision/openimage-detect.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from datachain import C, DataChain, File -from datachain.sql.functions import path +from datachain.lib.func import path class BBox(BaseModel): @@ -54,7 +54,7 @@ def openimage_detect(args): .filter(C("file.path").glob("*.jpg") | C("file.path").glob("*.json")) .agg( openimage_detect, - partition_by=path.file_stem(C("file.path")), + partition_by=path.file_stem("file.path"), params=["file"], output={"file": File, "bbox": BBox}, ) diff --git a/examples/get_started/common_sql_functions.py b/examples/get_started/common_sql_functions.py index bb96f1f99..e077937a5 100644 --- a/examples/get_started/common_sql_functions.py +++ b/examples/get_started/common_sql_functions.py @@ -1,6 +1,5 @@ from datachain import C, DataChain -from datachain.sql import literal -from datachain.sql.functions import array, greatest, least, path, string +from datachain.lib.func import array, greatest, least, path, string def num_chars_udf(file): @@ -18,7 +17,7 @@ def num_chars_udf(file): ( dc.mutate( length=string.length(path.name(C("file.path"))), - parts=string.split(path.name(C("file.path")), literal(".")), + parts=string.split(path.name(C("file.path")), "."), ) .select("file.path", "length", "parts") .show(5) @@ -35,8 +34,8 @@ def num_chars_udf(file): chain = dc.mutate( - a=array.length(string.split(C("file.path"), literal("/"))), - b=array.length(string.split(path.name(C("file.path")), literal("0"))), + a=array.length(string.split("file.path", "/")), + b=array.length(string.split(path.name("file.path"), "0")), ) ( diff --git a/examples/multimodal/clip_inference.py b/examples/multimodal/clip_inference.py index b6a37dcf7..4fdd1b490 100644 --- a/examples/multimodal/clip_inference.py +++ b/examples/multimodal/clip_inference.py @@ -4,7 +4,7 @@ from torch.utils.data import DataLoader from datachain import C, DataChain -from datachain.sql.functions import path +from datachain.lib.func import path source = "gs://datachain-demo/50k-laion-files/000000/00000000*" diff --git a/examples/multimodal/wds.py b/examples/multimodal/wds.py index 6d016dbc6..ec5c285ad 100644 --- a/examples/multimodal/wds.py +++ b/examples/multimodal/wds.py @@ -1,9 +1,9 @@ import os from datachain import DataChain +from datachain.lib.func import path from datachain.lib.webdataset import process_webdataset from datachain.lib.webdataset_laion import WDSLaion, process_laion_meta -from datachain.sql.functions import path IMAGE_TARS = os.getenv( "IMAGE_TARS", "gs://datachain-demo/datacomp-small/shards/000000[0-5]*.tar" diff --git a/examples/multimodal/wds_filtered.py b/examples/multimodal/wds_filtered.py index a06b27657..e5d619647 100644 --- a/examples/multimodal/wds_filtered.py +++ b/examples/multimodal/wds_filtered.py @@ -1,9 +1,8 @@ import datachain.error from datachain import C, DataChain +from datachain.lib.func import array, 
greatest, least, string from datachain.lib.webdataset import process_webdataset from datachain.lib.webdataset_laion import WDSLaion -from datachain.sql import literal -from datachain.sql.functions import array, greatest, least, string name = "wds" try: @@ -20,14 +19,12 @@ wds.print_schema() filtered = ( - wds.filter(string.length(C("laion.txt")) > 5) - .filter(array.length(string.split(C("laion.txt"), literal(" "))) > 2) + wds.filter(string.length("laion.txt") > 5) + .filter(array.length(string.split("laion.txt", " ")) > 2) + .filter(least("laion.json.original_width", "laion.json.original_height") > 200) .filter( - least(C("laion.json.original_width"), C("laion.json.original_height")) > 200 - ) - .filter( - greatest(C("laion.json.original_width"), C("laion.json.original_height")) - / least(C("laion.json.original_width"), C("laion.json.original_height")) + greatest("laion.json.original_width", "laion.json.original_height") + / least("laion.json.original_width", "laion.json.original_height") < 3.0 ) .save() diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index 80776df74..a717c13b1 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -53,7 +53,6 @@ QueryScriptCancelError, QueryScriptRunError, ) -from datachain.listing import Listing from datachain.node import DirType, Node, NodeWithPath from datachain.nodes_thread_pool import NodesThreadPool from datachain.remote.studio import StudioClient @@ -76,6 +75,7 @@ from datachain.dataset import DatasetVersion from datachain.job import Job from datachain.lib.file import File + from datachain.listing import Listing logger = logging.getLogger("datachain") @@ -241,7 +241,7 @@ def do_task(self, urls): class NodeGroup: """Class for a group of nodes from the same source""" - listing: Listing + listing: "Listing" sources: list[DataSource] # The source path within the bucket @@ -596,8 +596,9 @@ def enlist_source( client_config=None, object_name="file", skip_indexing=False, - ) -> tuple[Listing, str]: + ) -> tuple["Listing", str]: from datachain.lib.dc import DataChain + from datachain.listing import Listing DataChain.from_storage( source, session=self.session, update=update, object_name=object_name @@ -664,7 +665,8 @@ def enlist_sources_grouped( no_glob: bool = False, client_config=None, ) -> list[NodeGroup]: - from datachain.query import DatasetQuery + from datachain.listing import Listing + from datachain.query.dataset import DatasetQuery def _row_to_node(d: dict[str, Any]) -> Node: del d["file__source"] @@ -872,7 +874,7 @@ def create_new_dataset_version( def update_dataset_version_with_warehouse_info( self, dataset: DatasetRecord, version: int, rows_dropped=False, **kwargs ) -> None: - from datachain.query import DatasetQuery + from datachain.query.dataset import DatasetQuery dataset_version = dataset.get_version(version) @@ -1173,7 +1175,7 @@ def listings(self): def ls_dataset_rows( self, name: str, version: int, offset=None, limit=None ) -> list[dict]: - from datachain.query import DatasetQuery + from datachain.query.dataset import DatasetQuery dataset = self.get_dataset(name) diff --git a/src/datachain/cli.py b/src/datachain/cli.py index 036952bce..72c23aaa1 100644 --- a/src/datachain/cli.py +++ b/src/datachain/cli.py @@ -872,7 +872,7 @@ def show( schema: bool = False, ) -> None: from datachain.lib.dc import DataChain - from datachain.query import DatasetQuery + from datachain.query.dataset import DatasetQuery from datachain.utils import show_records dataset = catalog.get_dataset(name) 
diff --git a/src/datachain/client/fsspec.py b/src/datachain/client/fsspec.py index f1f7b4090..2518dfa74 100644 --- a/src/datachain/client/fsspec.py +++ b/src/datachain/client/fsspec.py @@ -28,7 +28,6 @@ from datachain.cache import DataChainCache from datachain.client.fileslice import FileWrapper from datachain.error import ClientError as DataChainClientError -from datachain.lib.file import File from datachain.nodes_fetcher import NodesFetcher from datachain.nodes_thread_pool import NodeChunk from datachain.storage import StorageURI @@ -36,6 +35,8 @@ if TYPE_CHECKING: from fsspec.spec import AbstractFileSystem + from datachain.lib.file import File + logger = logging.getLogger("datachain") @@ -44,7 +45,7 @@ DATA_SOURCE_URI_PATTERN = re.compile(r"^[\w]+:\/\/.*$") -ResultQueue = asyncio.Queue[Optional[Sequence[File]]] +ResultQueue = asyncio.Queue[Optional[Sequence["File"]]] def _is_win_local_path(uri: str) -> bool: @@ -207,7 +208,7 @@ async def get_file(self, lpath, rpath, callback): async def scandir( self, start_prefix: str, method: str = "default" - ) -> AsyncIterator[Sequence[File]]: + ) -> AsyncIterator[Sequence["File"]]: try: impl = getattr(self, f"_fetch_{method}") except AttributeError: @@ -312,7 +313,7 @@ def get_full_path(self, rel_path: str) -> str: return f"{self.PREFIX}{self.name}/{rel_path}" @abstractmethod - def info_to_file(self, v: dict[str, Any], parent: str) -> File: ... + def info_to_file(self, v: dict[str, Any], parent: str) -> "File": ... def fetch_nodes( self, @@ -349,7 +350,7 @@ def do_instantiate_object(self, file: "File", dst: str) -> None: copy2(src, dst) def open_object( - self, file: File, use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK + self, file: "File", use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK ) -> BinaryIO: """Open a file, including files in tar archives.""" if use_cache and (cache_path := self.cache.get_path(file)): @@ -357,19 +358,19 @@ def open_object( assert not file.location return FileWrapper(self.fs.open(self.get_full_path(file.path)), cb) # type: ignore[return-value] - def download(self, file: File, *, callback: Callback = DEFAULT_CALLBACK) -> None: + def download(self, file: "File", *, callback: Callback = DEFAULT_CALLBACK) -> None: sync(get_loop(), functools.partial(self._download, file, callback=callback)) - async def _download(self, file: File, *, callback: "Callback" = None) -> None: + async def _download(self, file: "File", *, callback: "Callback" = None) -> None: if self.cache.contains(file): # Already in cache, so there's nothing to do. 
return await self._put_in_cache(file, callback=callback) - def put_in_cache(self, file: File, *, callback: "Callback" = None) -> None: + def put_in_cache(self, file: "File", *, callback: "Callback" = None) -> None: sync(get_loop(), functools.partial(self._put_in_cache, file, callback=callback)) - async def _put_in_cache(self, file: File, *, callback: "Callback" = None) -> None: + async def _put_in_cache(self, file: "File", *, callback: "Callback" = None) -> None: assert not file.location if file.etag: etag = await self.get_current_etag(file) diff --git a/src/datachain/data_storage/schema.py b/src/datachain/data_storage/schema.py index f34bdeeca..5f5817ce3 100644 --- a/src/datachain/data_storage/schema.py +++ b/src/datachain/data_storage/schema.py @@ -12,7 +12,7 @@ from sqlalchemy.sql import func as f from sqlalchemy.sql.expression import false, null, true -from datachain.sql.functions import path +from datachain.lib.func.inner import path from datachain.sql.types import Int, SQLType, UInt64 if TYPE_CHECKING: diff --git a/src/datachain/data_storage/warehouse.py b/src/datachain/data_storage/warehouse.py index 8acc870f6..01ecf36b0 100644 --- a/src/datachain/data_storage/warehouse.py +++ b/src/datachain/data_storage/warehouse.py @@ -20,8 +20,8 @@ from datachain.data_storage.schema import convert_rows_custom_column_types from datachain.data_storage.serializer import Serializable from datachain.dataset import DatasetRecord +from datachain.lib.func.inner import path as pathfunc from datachain.node import DirType, DirTypeGroup, Node, NodeWithPath, get_path -from datachain.sql.functions import path as pathfunc from datachain.sql.types import Int, SQLType from datachain.storage import StorageURI from datachain.utils import sql_escape_like diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index ba5129b21..e005542cf 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -35,6 +35,8 @@ from datachain.lib.file import ArrowRow, File, get_file_type from datachain.lib.file import ExportPlacement as FileExportPlacement from datachain.lib.func import Func +from datachain.lib.func.inner import path as pathfunc +from datachain.lib.func.inner.base import Function from datachain.lib.listing import ( list_bucket, ls, @@ -51,7 +53,6 @@ from datachain.query import Session from datachain.query.dataset import DatasetQuery, PartitionByType from datachain.query.schema import DEFAULT_DELIMITER, Column, ColumnMeta -from datachain.sql.functions import path as pathfunc from datachain.telemetry import telemetry from datachain.utils import batched_it, inside_notebook, row_to_nested_dict @@ -112,9 +113,25 @@ def __init__(self, name, msg): # noqa: D107 super().__init__(f"Dataset{name} from values error: {msg}") -def _get_merge_error_str(col: Union[str, sqlalchemy.ColumnElement]) -> str: +MergeColType = Union[str, Function, sqlalchemy.ColumnElement] + + +def _validate_merge_on( + on: Union[MergeColType, Sequence[MergeColType]], +) -> Sequence[MergeColType]: + if isinstance(on, (str, sqlalchemy.ColumnElement)): + return [on] + if isinstance(on, Function): + return [on.get_column()] + if isinstance(on, Sequence): + return [c.get_column() if isinstance(c, Function) else c for c in on] + + +def _get_merge_error_str(col: MergeColType) -> str: if isinstance(col, str): return col + if isinstance(col, Function): + return f"{col.name}()" if isinstance(col, sqlalchemy.Column): return col.name.replace(DEFAULT_DELIMITER, ".") if isinstance(col, sqlalchemy.ColumnElement) and hasattr(col, "name"): @@ -125,11 
+142,13 @@ def _get_merge_error_str(col: Union[str, sqlalchemy.ColumnElement]) -> str: class DatasetMergeError(DataChainParamsError): # noqa: D101 def __init__( # noqa: D107 self, - on: Sequence[Union[str, sqlalchemy.ColumnElement]], - right_on: Optional[Sequence[Union[str, sqlalchemy.ColumnElement]]], + on: Union[MergeColType, Sequence[MergeColType]], + right_on: Optional[Union[MergeColType, Sequence[MergeColType]]], msg: str, ): - def _get_str(on: Sequence[Union[str, sqlalchemy.ColumnElement]]) -> str: + def _get_str( + on: Union[MergeColType, Sequence[MergeColType]], + ) -> str: if not isinstance(on, Sequence): return str(on) # type: ignore[unreachable] return ", ".join([_get_merge_error_str(col) for col in on]) @@ -1322,25 +1341,16 @@ def remove_file_signals(self) -> "Self": # noqa: D102 def merge( self, right_ds: "DataChain", - on: Union[ - str, - sqlalchemy.ColumnElement, - Sequence[Union[str, sqlalchemy.ColumnElement]], - ], - right_on: Union[ - str, - sqlalchemy.ColumnElement, - Sequence[Union[str, sqlalchemy.ColumnElement]], - None, - ] = None, + on: Union[MergeColType, Sequence[MergeColType]], + right_on: Optional[Union[MergeColType, Sequence[MergeColType]]] = None, inner=False, rname="right_", ) -> "Self": """Merge two chains based on the specified criteria. Parameters: - right_ds : Chain to join with. - on : Predicate or list of Predicates to join on. If both chains have the + right_ds: Chain to join with. + on: Predicate or list of Predicates to join on. If both chains have the same predicates then this predicate is enough for the join. Otherwise, `right_on` parameter has to specify the predicates for the other chain. right_on: Optional predicate or list of Predicates @@ -1357,23 +1367,24 @@ def merge( if on is None: raise DatasetMergeError(["None"], None, "'on' must be specified") - if isinstance(on, (str, sqlalchemy.ColumnElement)): - on = [on] - elif not isinstance(on, Sequence): + on = _validate_merge_on(on) + if not on: raise DatasetMergeError( on, right_on, - f"'on' must be 'str' or 'Sequence' object but got type '{type(on)}'", + ( + "'on' must be 'str', 'Func' or 'Sequence' object " + f"but got type '{type(on)}'" + ), ) if right_on is not None: - if isinstance(right_on, (str, sqlalchemy.ColumnElement)): - right_on = [right_on] - elif not isinstance(right_on, Sequence): + right_on = _validate_merge_on(right_on) + if not right_on: raise DatasetMergeError( on, right_on, - "'right_on' must be 'str' or 'Sequence' object" + "'right_on' must be 'str', 'Func' or 'Sequence' object" f" but got type '{type(right_on)}'", ) @@ -1389,10 +1400,12 @@ def merge( def _resolve( ds: DataChain, - col: Union[str, sqlalchemy.ColumnElement], + col: Union[str, Function, sqlalchemy.ColumnElement], side: Union[str, None], ): try: + if isinstance(col, Function): + return ds.c(col.get_column()) return ds.c(col) if isinstance(col, (str, C)) else col except ValueError: if side: @@ -2298,9 +2311,9 @@ def filter(self, *args: Any) -> "Self": dc.filter(C("file.name").glob("*.jpg")) ``` - Using `datachain.sql.functions` + Using `datachain.lib.func` ```py - from datachain.sql.functions import string + from datachain.lib.func import string dc.filter(string.length(C("file.name")) > 5) ``` diff --git a/src/datachain/lib/func/__init__.py b/src/datachain/lib/func/__init__.py index ba6f08027..3d2a51ad0 100644 --- a/src/datachain/lib/func/__init__.py +++ b/src/datachain/lib/func/__init__.py @@ -1,3 +1,6 @@ +from sqlalchemy import literal + +from . 
import array, path, random, string from .aggregate import ( any_value, avg, @@ -12,21 +15,34 @@ row_number, sum, ) +from .array import cosine_distance, euclidean_distance, length, sip_hash_64 +from .conditional import greatest, least from .func import Func, window __all__ = [ "Func", "any_value", + "array", "avg", "collect", "concat", + "cosine_distance", "count", "dense_rank", + "euclidean_distance", "first", + "greatest", + "least", + "length", + "literal", "max", "min", + "path", + "random", "rank", "row_number", + "sip_hash_64", + "string", "sum", "window", ] diff --git a/src/datachain/lib/func/aggregate.py b/src/datachain/lib/func/aggregate.py index 00ae0077a..7d2c75180 100644 --- a/src/datachain/lib/func/aggregate.py +++ b/src/datachain/lib/func/aggregate.py @@ -2,9 +2,8 @@ from sqlalchemy import func as sa_func -from datachain.sql import functions as dc_func - from .func import Func +from .inner import aggregate def count(col: Optional[str] = None) -> Func: @@ -31,7 +30,9 @@ def count(col: Optional[str] = None) -> Func: Notes: - Result column will always be of type int. """ - return Func("count", inner=sa_func.count, col=col, result_type=int) + return Func( + "count", inner=sa_func.count, cols=[col] if col else None, result_type=int + ) def sum(col: str) -> Func: @@ -59,7 +60,7 @@ def sum(col: str) -> Func: - The `sum` function should be used on numeric columns. - Result column type will be the same as the input column type. """ - return Func("sum", inner=sa_func.sum, col=col) + return Func("sum", inner=sa_func.sum, cols=[col]) def avg(col: str) -> Func: @@ -87,7 +88,7 @@ def avg(col: str) -> Func: - The `avg` function should be used on numeric columns. - Result column will always be of type float. """ - return Func("avg", inner=dc_func.aggregate.avg, col=col, result_type=float) + return Func("avg", inner=aggregate.avg, cols=[col], result_type=float) def min(col: str) -> Func: @@ -115,7 +116,7 @@ def min(col: str) -> Func: - The `min` function can be used with numeric, date, and string columns. - Result column will have the same type as the input column. """ - return Func("min", inner=sa_func.min, col=col) + return Func("min", inner=sa_func.min, cols=[col]) def max(col: str) -> Func: @@ -143,7 +144,7 @@ def max(col: str) -> Func: - The `max` function can be used with numeric, date, and string columns. - Result column will have the same type as the input column. """ - return Func("max", inner=sa_func.max, col=col) + return Func("max", inner=sa_func.max, cols=[col]) def any_value(col: str) -> Func: @@ -174,7 +175,7 @@ def any_value(col: str) -> Func: - The result of `any_value` is non-deterministic, meaning it may return different values for different executions. """ - return Func("any_value", inner=dc_func.aggregate.any_value, col=col) + return Func("any_value", inner=aggregate.any_value, cols=[col]) def collect(col: str) -> Func: @@ -203,7 +204,7 @@ def collect(col: str) -> Func: - The `collect` function can be used with numeric and string columns. - Result column will have an array type. 
""" - return Func("collect", inner=dc_func.aggregate.collect, col=col, is_array=True) + return Func("collect", inner=aggregate.collect, cols=[col], is_array=True) def concat(col: str, separator="") -> Func: @@ -236,9 +237,9 @@ def concat(col: str, separator="") -> Func: """ def inner(arg): - return dc_func.aggregate.group_concat(arg, separator) + return aggregate.group_concat(arg, separator) - return Func("concat", inner=inner, col=col, result_type=str) + return Func("concat", inner=inner, cols=[col], result_type=str) def row_number() -> Func: @@ -350,4 +351,4 @@ def first(col: str) -> Func: in the specified order. - The result column will have the same type as the input column. """ - return Func("first", inner=sa_func.first_value, col=col, is_window=True) + return Func("first", inner=sa_func.first_value, cols=[col], is_window=True) diff --git a/src/datachain/lib/func/array.py b/src/datachain/lib/func/array.py new file mode 100644 index 000000000..c0ac52818 --- /dev/null +++ b/src/datachain/lib/func/array.py @@ -0,0 +1,175 @@ +from collections.abc import Sequence +from typing import Union + +from .func import Func +from .inner import array + + +def cosine_distance(*args: Union[str, Sequence]) -> Func: + """ + Computes the cosine distance between two vectors. + + The cosine distance is derived from the cosine similarity, which measures the angle + between two vectors. This function returns the dissimilarity between the vectors, + where 0 indicates identical vectors and values closer to 1 + indicate higher dissimilarity. + + Args: + args (str | Sequence): Two vectors to compute the cosine distance between. + If a string is provided, it is assumed to be the name of the column vector. + If a sequence is provided, it is assumed to be a vector of values. + + Returns: + Func: A Func object that represents the cosine_distance function. + + Example: + ```py + target_embedding = [0.1, 0.2, 0.3] + dc.mutate( + cos_dist1=func.cosine_distance("embedding", target_embedding), + cos_dist2=func.cosine_distance(target_embedding, [0.4, 0.5, 0.6]), + ) + ``` + + Notes: + - Ensure both vectors have the same number of elements. + - Result column will always be of type float. + """ + cols, func_args = [], [] + for arg in args: + if isinstance(arg, str): + cols.append(arg) + else: + func_args.append(list(arg)) + + if len(cols) + len(func_args) != 2: + raise ValueError("cosine_distance() requires exactly two arguments") + if not cols and len(func_args[0]) != len(func_args[1]): + raise ValueError("cosine_distance() requires vectors of the same length") + + return Func( + "cosine_distance", + inner=array.cosine_distance, + cols=cols, + args=func_args, + result_type=float, + ) + + +def euclidean_distance(*args: Union[str, Sequence]) -> Func: + """ + Computes the Euclidean distance between two vectors. + + The Euclidean distance is the straight-line distance between two points + in Euclidean space. This function returns the distance between the two vectors. + + Args: + args (str | Sequence): Two vectors to compute the Euclidean distance between. + If a string is provided, it is assumed to be the name of the column vector. + If a sequence is provided, it is assumed to be a vector of values. + + Returns: + Func: A Func object that represents the euclidean_distance function. 
+ + Example: + ```py + target_embedding = [0.1, 0.2, 0.3] + dc.mutate( + eu_dist1=func.euclidean_distance("embedding", target_embedding), + eu_dist2=func.euclidean_distance(target_embedding, [0.4, 0.5, 0.6]), + ) + ``` + + Notes: + - Ensure both vectors have the same number of elements. + - Result column will always be of type float. + """ + cols, func_args = [], [] + for arg in args: + if isinstance(arg, str): + cols.append(arg) + else: + func_args.append(list(arg)) + + if len(cols) + len(func_args) != 2: + raise ValueError("euclidean_distance() requires exactly two arguments") + if not cols and len(func_args[0]) != len(func_args[1]): + raise ValueError("euclidean_distance() requires vectors of the same length") + + return Func( + "euclidean_distance", + inner=array.euclidean_distance, + cols=cols, + args=func_args, + result_type=float, + ) + + +def length(arg: Union[str, Sequence, Func]) -> Func: + """ + Returns the length of the array. + + Args: + arg (str | Sequence | Func): Array to compute the length of. + If a string is provided, it is assumed to be the name of the array column. + If a sequence is provided, it is assumed to be an array of values. + If a Func is provided, it is assumed to be a function returning an array. + + Returns: + Func: A Func object that represents the array length function. + + Example: + ```py + dc.mutate( + len1=func.array.length("signal.values"), + len2=func.array.length([1, 2, 3, 4, 5]), + ) + ``` + + Note: + - Result column will always be of type int. + """ + if isinstance(arg, (str, Func)): + cols = [arg] + args = None + else: + cols = None + args = [arg] + + return Func("length", inner=array.length, cols=cols, args=args, result_type=int) + + +def sip_hash_64(arg: Union[str, Sequence]) -> Func: + """ + Computes the SipHash-64 hash of the array. + + Args: + arg (str | Sequence): Array to compute the SipHash-64 hash of. + If a string is provided, it is assumed to be the name of the array column. + If a sequence is provided, it is assumed to be an array of values. + + Returns: + Func: A Func object that represents the sip_hash_64 function. + + Example: + ```py + dc.mutate( + hash1=func.sip_hash_64("signal.values"), + hash2=func.sip_hash_64([1, 2, 3, 4, 5]), + ) + ``` + + Note: + - This function is only available for the ClickHouse warehouse. + - Result column will always be of type int. + """ + if isinstance(arg, str): + cols = [arg] + args = None + else: + cols = None + args = [arg] + + return Func( + "sip_hash_64", inner=array.sip_hash_64, cols=cols, args=args, result_type=int + ) diff --git a/src/datachain/lib/func/conditional.py b/src/datachain/lib/func/conditional.py new file mode 100644 index 000000000..aa8e0b236 --- /dev/null +++ b/src/datachain/lib/func/conditional.py @@ -0,0 +1,80 @@ +from typing import Union + +from .func import ColT, Func +from .inner import conditional + + +def greatest(*args: Union[ColT, float]) -> Func: + """ + Returns the greatest (largest) value from the given input values. + + Args: + args (ColT | str | int | float | Sequence): The values to compare. + If a string is provided, it is assumed to be the name of the column. + If a Func is provided, it is assumed to be a function returning a value. + If an int, float, or Sequence is provided, it is assumed to be a literal. + + Returns: + Func: A Func object that represents the greatest function. + + Example: + ```py + dc.mutate( + greatest=func.greatest("signal.value", 0), + ) + ``` + + Note: + - Result column will always be of the same type as the input columns. 
+ """ + cols, func_args = [], [] + + for arg in args: + if isinstance(arg, (str, Func)): + cols.append(arg) + else: + func_args.append(arg) + + return Func( + "greatest", + inner=conditional.greatest, + cols=cols, + args=func_args, + result_type=int, + ) + + +def least(*args: Union[ColT, float]) -> Func: + """ + Returns the least (smallest) value from the given input values. + + Args: + args (ColT | str | int | float | Sequence): The values to compare. + If a string is provided, it is assumed to be the name of the column. + If a Func is provided, it is assumed to be a function returning a value. + If an int, float, or Sequence is provided, it is assumed to be a literal. + + Returns: + Func: A Func object that represents the least function. + + Example: + ```py + dc.mutate( + least=func.least("signal.value", 0), + ) + ``` + + Note: + - Result column will always be of the same type as the input columns. + """ + cols, func_args = [], [] + + for arg in args: + if isinstance(arg, (str, Func)): + cols.append(arg) + else: + func_args.append(arg) + + return Func( + "least", inner=conditional.least, cols=cols, args=func_args, result_type=int + ) diff --git a/src/datachain/lib/func/func.py b/src/datachain/lib/func/func.py index 3e7373d52..8eca09bbd 100644 --- a/src/datachain/lib/func/func.py +++ b/src/datachain/lib/func/func.py @@ -1,17 +1,25 @@ +import inspect +from collections.abc import Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Union -from sqlalchemy import desc +from sqlalchemy import BindParameter, ColumnElement, desc +from sqlalchemy import func as sa_func from datachain.lib.convert.python_to_sql import python_to_sql from datachain.lib.utils import DataChainColumnError, DataChainParamsError from datachain.query.schema import Column, ColumnMeta +from .inner.base import Function + if TYPE_CHECKING: from datachain import DataType from datachain.lib.signal_schema import SignalSchema +ColT = Union[str, ColumnElement, "Func"] + + @dataclass class Window: """Represents a window specification for SQL window functions.""" @@ -58,14 +66,15 @@ def window(partition_by: str, order_by: str, desc: bool = False) -> Window: ) -class Func: +class Func(Function): """Represents a function to be applied to a column in a SQL query.""" def __init__( self, name: str, inner: Callable, - col: Optional[str] = None, + cols: Optional[Sequence[ColT]] = None, + args: Optional[Sequence[Any]] = None, result_type: Optional["DataType"] = None, is_array: bool = False, is_window: bool = False, @@ -73,7 +82,8 @@ def __init__( ) -> None: self.name = name self.inner = inner - self.col = col + self.cols = cols or [] + self.args = args or [] self.result_type = result_type self.is_array = is_array self.is_window = is_window @@ -89,7 +99,8 @@ def over(self, window: Window) -> "Func": return Func( "over", self.inner, - self.col, + self.cols, + self.args, self.result_type, self.is_array, self.is_window, @@ -97,20 +108,71 @@ def over(self, window: Window) -> "Func": ) @property - def db_col(self) -> Optional[str]: - return ColumnMeta.to_db_name(self.col) if self.col else None + def _db_cols(self) -> Sequence[ColT]: + return ( + [ + col + if isinstance(col, (Func, BindParameter)) + else ColumnMeta.to_db_name( + col.name if isinstance(col, ColumnElement) else col + ) + for col in self.cols + ] + if self.cols + else [] + ) - def db_col_type(self, signals_schema: "SignalSchema") -> Optional["DataType"]: - if not self.db_col: + def 
_db_col_type(self, signals_schema: "SignalSchema") -> Optional["DataType"]: + if not self._db_cols: return None - col_type: type = signals_schema.get_column_type(self.db_col) + + col_type: type = get_db_col_type(signals_schema, self._db_cols[0]) + for col in self._db_cols[1:]: + if get_db_col_type(signals_schema, col) != col_type: + raise DataChainColumnError( + str(self), + "Columns must have the same type to infer result type", + ) + return list[col_type] if self.is_array else col_type # type: ignore[valid-type] - def get_result_type(self, signals_schema: "SignalSchema") -> "DataType": + def __add__(self, other: Union[ColT, float]) -> "Func": + return sum(self, other) + + def __radd__(self, other: Union[ColT, float]) -> "Func": + return sum(other, self) + + def __sub__(self, other: Union[ColT, float]) -> "Func": + return sub(self, other) + + def __rsub__(self, other: Union[ColT, float]) -> "Func": + return sub(other, self) + + def __mul__(self, other: Union[ColT, float]) -> "Func": + return multiply(self, other) + + def __rmul__(self, other: Union[ColT, float]) -> "Func": + return multiply(other, self) + + def __truediv__(self, other: Union[ColT, float]) -> "Func": + return divide(self, other) + + def __rtruediv__(self, other: Union[ColT, float]) -> "Func": + return divide(other, self) + + def __gt__(self, other: Union[ColT, float]) -> "Func": + return gt(self, other) + + def __lt__(self, other: Union[ColT, float]) -> "Func": + return lt(self, other) + + def get_result_type( + self, signals_schema: Optional["SignalSchema"] = None + ) -> "DataType": if self.result_type: return self.result_type - if col_type := self.db_col_type(signals_schema): + if signals_schema and (col_type := self._db_col_type(signals_schema)): return col_type raise DataChainColumnError( @@ -119,16 +181,22 @@ def get_result_type(self, signals_schema: "SignalSchema") -> "DataType": ) def get_column( - self, signals_schema: "SignalSchema", label: Optional[str] = None + self, + signals_schema: Optional["SignalSchema"] = None, + label: Optional[str] = None, ) -> Column: col_type = self.get_result_type(signals_schema) sql_type = python_to_sql(col_type) - if self.col: - col = Column(self.db_col, sql_type) - func_col = self.inner(col) - else: - func_col = self.inner() + cols = [ + col.get_column(signals_schema) + if isinstance(col, Func) + else Column(col, sql_type) + if isinstance(col, str) + else col + for col in self._db_cols + ] + func_col = self.inner(*cols, *self.args) if self.is_window: if not self.window: @@ -144,9 +212,90 @@ def get_column( ), ) - func_col.type = sql_type + func_col.type = sql_type() if inspect.isclass(sql_type) else sql_type if label: func_col = func_col.label(label) return func_col + + +def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType": + if isinstance(col, Func): + return col.get_result_type(signals_schema) + + return signals_schema.get_column_type( + col.name if isinstance(col, ColumnElement) else col + ) + + +def sum(*args: Union[ColT, float]) -> Func: + """Computes the sum of the column.""" + cols, func_args = [], [] + for arg in args: + if isinstance(arg, (int, float)): + func_args.append(arg) + else: + cols.append(arg) + + return Func("sum", sa_func.sum, cols=cols, args=func_args) + + +def sub(*args: Union[ColT, float]) -> Func: + """Computes the diff of the column.""" + cols, func_args = [], [] + for arg in args: + if isinstance(arg, (int, float)): + func_args.append(arg) + else: + cols.append(arg) + + return Func("sub", sa_func.sub, cols=cols, args=func_args) + + 
+def multiply(*args: Union[ColT, float]) -> Func:
+    """Computes the product of the column."""
+    cols, func_args = [], []
+    for arg in args:
+        if isinstance(arg, (int, float)):
+            func_args.append(arg)
+        else:
+            cols.append(arg)
+
+    return Func("multiply", sa_func.multiply, cols=cols, args=func_args)
+
+
+def divide(*args: Union[ColT, float]) -> Func:
+    """Computes the division of the column."""
+    cols, func_args = [], []
+    for arg in args:
+        if isinstance(arg, (int, float)):
+            func_args.append(arg)
+        else:
+            cols.append(arg)
+
+    return Func("divide", sa_func.divide, cols=cols, args=func_args)
+
+
+def gt(*args: Union[ColT, float]) -> Func:
+    """Computes the greater than comparison of the column."""
+    cols, func_args = [], []
+    for arg in args:
+        if isinstance(arg, (int, float)):
+            func_args.append(arg)
+        else:
+            cols.append(arg)
+
+    return Func("gt", sa_func.gt, cols=cols, args=func_args)
+
+
+def lt(*args: Union[ColT, float]) -> Func:
+    """Computes the less than comparison of the column."""
+    cols, func_args = [], []
+    for arg in args:
+        if isinstance(arg, (int, float)):
+            func_args.append(arg)
+        else:
+            cols.append(arg)
+
+    return Func("lt", sa_func.lt, cols=cols, args=func_args)
diff --git a/src/datachain/lib/func/inner/__init__.py b/src/datachain/lib/func/inner/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/datachain/sql/functions/aggregate.py b/src/datachain/lib/func/inner/aggregate.py
similarity index 100%
rename from src/datachain/sql/functions/aggregate.py
rename to src/datachain/lib/func/inner/aggregate.py
diff --git a/src/datachain/sql/functions/array.py b/src/datachain/lib/func/inner/array.py
similarity index 100%
rename from src/datachain/sql/functions/array.py
rename to src/datachain/lib/func/inner/array.py
diff --git a/src/datachain/lib/func/inner/base.py b/src/datachain/lib/func/inner/base.py
new file mode 100644
index 000000000..3d746a7d5
--- /dev/null
+++ b/src/datachain/lib/func/inner/base.py
@@ -0,0 +1,20 @@
+from abc import ABCMeta, abstractmethod
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from datachain.lib.signal_schema import SignalSchema
+    from datachain.query.schema import Column
+
+
+class Function:
+    __metaclass__ = ABCMeta
+
+    name: str
+
+    @abstractmethod
+    def get_column(
+        self,
+        signals_schema: Optional["SignalSchema"] = None,
+        label: Optional[str] = None,
+    ) -> "Column":
+        pass
diff --git a/src/datachain/sql/functions/conditional.py b/src/datachain/lib/func/inner/conditional.py
similarity index 100%
rename from src/datachain/sql/functions/conditional.py
rename to src/datachain/lib/func/inner/conditional.py
diff --git a/src/datachain/sql/functions/path.py b/src/datachain/lib/func/inner/path.py
similarity index 100%
rename from src/datachain/sql/functions/path.py
rename to src/datachain/lib/func/inner/path.py
diff --git a/src/datachain/sql/functions/random.py b/src/datachain/lib/func/inner/random.py
similarity index 100%
rename from src/datachain/sql/functions/random.py
rename to src/datachain/lib/func/inner/random.py
diff --git a/src/datachain/sql/functions/string.py b/src/datachain/lib/func/inner/string.py
similarity index 100%
rename from src/datachain/sql/functions/string.py
rename to src/datachain/lib/func/inner/string.py
diff --git a/src/datachain/lib/func/path.py b/src/datachain/lib/func/path.py
new file mode 100644
index 000000000..ea3ba445a
--- /dev/null
+++ b/src/datachain/lib/func/path.py
@@ -0,0 +1,109 @@
+from .func import ColT, Func
+from .inner import path
+
+
+def parent(col: ColT) -> Func:
+    """
+    Returns the directory component of a posix-style path.
+
+    Args:
+        col (str | literal | Func): String to compute the path parent of.
+            If a string is provided, it is assumed to be the name of the column.
+            If a literal is provided, it is assumed to be a string literal.
+            If a Func is provided, it is assumed to be a function returning a string.
+
+    Returns:
+        Func: A Func object that represents the path parent function.
+
+    Example:
+        ```py
+        dc.mutate(
+            parent=func.path.parent("file.path"),
+        )
+        ```
+
+    Note:
+        - Result column will always be of type string.
+    """
+    return Func("parent", inner=path.parent, cols=[col], result_type=str)
+
+
+def name(col: ColT) -> Func:
+    """
+    Returns the final component of a posix-style path.
+
+    Args:
+        col (str | literal): String to compute the path name of.
+            If a string is provided, it is assumed to be the name of the column.
+            If a literal is provided, it is assumed to be a string literal.
+            If a Func is provided, it is assumed to be a function returning a string.
+
+    Returns:
+        Func: A Func object that represents the path name function.
+
+    Example:
+        ```py
+        dc.mutate(
+            file_name=func.path.name("file.path"),
+        )
+        ```
+
+    Note:
+        - Result column will always be of type string.
+    """
+
+    return Func("name", inner=path.name, cols=[col], result_type=str)
+
+
+def file_stem(col: ColT) -> Func:
+    """
+    Returns the path without the extension.
+
+    Args:
+        col (str | literal): String to compute the file stem of.
+            If a string is provided, it is assumed to be the name of the column.
+            If a literal is provided, it is assumed to be a string literal.
+            If a Func is provided, it is assumed to be a function returning a string.
+
+    Returns:
+        Func: A Func object that represents the file stem function.
+
+    Example:
+        ```py
+        dc.mutate(
+            file_stem=func.path.file_stem("file.path"),
+        )
+        ```
+
+    Note:
+        - Result column will always be of type string.
+    """
+
+    return Func("file_stem", inner=path.file_stem, cols=[col], result_type=str)
+
+
+def file_ext(col: ColT) -> Func:
+    """
+    Returns the extension of the given path.
+
+    Args:
+        col (str | literal): String to compute the file extension of.
+            If a string is provided, it is assumed to be the name of the column.
+            If a literal is provided, it is assumed to be a string literal.
+            If a Func is provided, it is assumed to be a function returning a string.
+
+    Returns:
+        Func: A Func object that represents the file extension function.
+
+    Example:
+        ```py
+        dc.mutate(
+            file_ext=func.path.file_ext("file.path"),
+        )
+        ```
+
+    Note:
+        - Result column will always be of type string.
+    """
+
+    return Func("file_ext", inner=path.file_ext, cols=[col], result_type=str)
diff --git a/src/datachain/lib/func/random.py b/src/datachain/lib/func/random.py
new file mode 100644
index 000000000..7160d093d
--- /dev/null
+++ b/src/datachain/lib/func/random.py
@@ -0,0 +1,22 @@
+from .func import Func
+from .inner import random
+
+
+def rand() -> Func:
+    """
+    Returns a random integer value.
+
+    Returns:
+        Func: A Func object that represents the rand function.
+
+    Example:
+        ```py
+        dc.mutate(
+            rnd=func.random.rand(),
+        )
+        ```
+
+    Note:
+        - Result column will always be of type integer.
+ """ + return Func("rand", inner=random.rand, result_type=int) diff --git a/src/datachain/lib/func/string.py b/src/datachain/lib/func/string.py new file mode 100644 index 000000000..de9f225f6 --- /dev/null +++ b/src/datachain/lib/func/string.py @@ -0,0 +1,153 @@ +from typing import Optional, Union, get_origin + +from sqlalchemy import literal + +from .func import Func +from .inner import string + + +def length(col: Union[str, Func]) -> Func: + """ + Returns the length of the string. + + Args: + col (str | literal | Func): String to compute the length of. + If a string is provided, it is assumed to be the name of the column. + If a literal is provided, it is assumed to be a string literal. + If a Func is provided, it is assumed to be a function returning a string. + + Returns: + Func: A Func object that represents the string length function. + + Example: + ```py + dc.mutate( + len1=func.string.length("file.path"), + len2=func.string.length("Random string"), + ) + ``` + + Note: + - Result column will always be of type int. + """ + return Func("length", inner=string.length, cols=[col], result_type=int) + + +def split(col: Union[str, Func], sep: str, limit: Optional[int] = None) -> Func: + """ + Takes a column and split character and returns an array of the parts. + + Args: + col (str | literal): Column to split. + If a string is provided, it is assumed to be the name of the column. + If a literal is provided, it is assumed to be a string literal. + If a Func is provided, it is assumed to be a function returning a string. + sep (str): Separator to split the string. + limit (int, optional): Maximum number of splits to perform. + + Returns: + Func: A Func object that represents the split function. + + Example: + ```py + dc.mutate( + path_parts=func.string.split("file.path", "/"), + str_words=func.string.length("Random string", " "), + ) + ``` + + Note: + - Result column will always be of type array of strings. + """ + + def inner(arg): + if limit is not None: + return string.split(arg, sep, limit) + return string.split(arg, sep) + + if get_origin(col) is literal: + cols = None + args = [col] + else: + cols = [col] + args = None + + return Func("split", inner=inner, cols=cols, args=args, result_type=list[str]) + + +def replace(col: Union[str, Func], pattern: str, replacement: str) -> Func: + """ + Replaces substring with another string. + + Args: + col (str | literal): Column to split. + If a string is provided, it is assumed to be the name of the column. + If a literal is provided, it is assumed to be a string literal. + If a Func is provided, it is assumed to be a function returning a string. + pattern (str): Pattern to replace. + replacement (str): Replacement string. + + Returns: + Func: A Func object that represents the replace function. + + Example: + ```py + dc.mutate( + signal=func.string.replace("signal.name", "pattern", "replacement), + ) + ``` + + Note: + - Result column will always be of type string. + """ + + def inner(arg): + return string.replace(arg, pattern, replacement) + + if get_origin(col) is literal: + cols = None + args = [col] + else: + cols = [col] + args = None + + return Func("replace", inner=inner, cols=cols, args=args, result_type=str) + + +def regexp_replace(col: Union[str, Func], regex: str, replacement: str) -> Func: + r""" + Replaces substring that match a regular expression. + + Args: + col (str | literal): Column to split. + If a string is provided, it is assumed to be the name of the column. + If a literal is provided, it is assumed to be a string literal. 
+ If a Func is provided, it is assumed to be a function returning a string. + regex (str): Regular expression pattern to replace. + replacement (str): Replacement string. + + Returns: + Func: A Func object that represents the regexp_replace function. + + Example: + ```py + dc.mutate( + signal=func.string.regexp_replace("signal.name", r"\d+", "X"), + ) + ``` + + Note: + - Result column will always be of type string. + """ + + def inner(arg): + return string.regexp_replace(arg, regex, replacement) + + if get_origin(col) is literal: + cols = None + args = [col] + else: + cols = [col] + args = None + + return Func("regexp_replace", inner=inner, cols=cols, args=args, result_type=str) diff --git a/src/datachain/lib/listing.py b/src/datachain/lib/listing.py index bfb87afc9..7d8678d1c 100644 --- a/src/datachain/lib/listing.py +++ b/src/datachain/lib/listing.py @@ -8,8 +8,8 @@ from datachain.asyn import iter_over_async from datachain.client import Client from datachain.lib.file import File +from datachain.lib.func.inner import path as pathfunc from datachain.query.schema import Column -from datachain.sql.functions import path as pathfunc from datachain.telemetry import telemetry from datachain.utils import uses_glob diff --git a/src/datachain/listing.py b/src/datachain/listing.py index bc61d8277..e7b22218b 100644 --- a/src/datachain/listing.py +++ b/src/datachain/listing.py @@ -8,8 +8,8 @@ from sqlalchemy.sql import func from tqdm import tqdm +from datachain.lib.func.inner import path as pathfunc from datachain.node import DirType, Node, NodeWithPath -from datachain.sql.functions import path as pathfunc from datachain.utils import suffix_to_number if TYPE_CHECKING: diff --git a/src/datachain/nodes_fetcher.py b/src/datachain/nodes_fetcher.py index b57f21542..5aa75dc3c 100644 --- a/src/datachain/nodes_fetcher.py +++ b/src/datachain/nodes_fetcher.py @@ -2,12 +2,12 @@ from collections.abc import Iterable from typing import TYPE_CHECKING -from datachain.node import Node from datachain.nodes_thread_pool import NodesThreadPool if TYPE_CHECKING: from datachain.cache import DataChainCache from datachain.client.fsspec import Client + from datachain.node import Node logger = logging.getLogger("datachain") @@ -22,7 +22,7 @@ def done_task(self, done): for task in done: task.result() - def do_task(self, chunk: Iterable[Node]) -> None: + def do_task(self, chunk: Iterable["Node"]) -> None: from fsspec import Callback class _CB(Callback): diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index 52c8f082b..cec1a0f4a 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -43,9 +43,10 @@ ) from datachain.dataset import DatasetStatus, RowDict from datachain.error import DatasetNotFoundError, QueryScriptCancelError +from datachain.lib.func.inner.base import Function +from datachain.lib.func.inner.random import rand from datachain.lib.udf import UDFAdapter from datachain.progress import CombinedDownloadCallback -from datachain.sql.functions import rand from datachain.utils import ( batched, determine_processes, @@ -65,8 +66,7 @@ from datachain.catalog import Catalog from datachain.data_storage import AbstractWarehouse from datachain.dataset import DatasetRecord - - from .udf import UDFResult + from datachain.lib.udf import UDFResult P = ParamSpec("P") @@ -685,6 +685,12 @@ def q(*columns): return step_result(q, new_query.selected_columns) + def parse_cols( + self, + cols: Sequence[Union[Function, ColumnElement]], + ) -> tuple[ColumnElement, ...]: + return 
tuple(c.get_column() if isinstance(c, Function) else c for c in cols) + @abstractmethod def apply_sql_clause(self, query): pass @@ -692,12 +698,14 @@ def apply_sql_clause(self, query): @frozen class SQLSelect(SQLClause): - args: tuple[Union[str, ColumnElement], ...] + args: tuple[Union[Function, ColumnElement], ...] def apply_sql_clause(self, query) -> Select: subquery = query.subquery() - - args = [subquery.c[str(c)] if isinstance(c, (str, C)) else c for c in self.args] + args = [ + subquery.c[str(c)] if isinstance(c, (str, C)) else c + for c in self.parse_cols(self.args) + ] if not args: args = subquery.c @@ -706,22 +714,25 @@ def apply_sql_clause(self, query) -> Select: @frozen class SQLSelectExcept(SQLClause): - args: tuple[str, ...] + args: tuple[Union[Function, ColumnElement], ...] def apply_sql_clause(self, query: Select) -> Select: subquery = query.subquery() - names = set(self.args) - args = [c for c in subquery.c if c.name not in names] + args = [c for c in subquery.c if c.name not in set(self.parse_cols(self.args))] return sqlalchemy.select(*args).select_from(subquery) @frozen class SQLMutate(SQLClause): - args: tuple[ColumnElement, ...] + args: tuple[Union[Function, ColumnElement], ...] def apply_sql_clause(self, query: Select) -> Select: original_subquery = query.subquery() - to_mutate = {c.name for c in self.args} + args = [ + original_subquery.c[str(c)] if isinstance(c, (str, C)) else c + for c in self.parse_cols(self.args) + ] + to_mutate = {c.name for c in args} prefix = f"mutate{token_hex(8)}_" cols = [ @@ -731,9 +742,7 @@ def apply_sql_clause(self, query: Select) -> Select: # this is needed for new column to be used in clauses # like ORDER BY, otherwise new column is not recognized subquery = ( - sqlalchemy.select(*cols, *self.args) - .select_from(original_subquery) - .subquery() + sqlalchemy.select(*cols, *args).select_from(original_subquery).subquery() ) return sqlalchemy.select(*subquery.c).select_from(subquery) @@ -741,21 +750,24 @@ def apply_sql_clause(self, query: Select) -> Select: @frozen class SQLFilter(SQLClause): - expressions: tuple[ColumnElement, ...] + expressions: tuple[Union[Function, ColumnElement], ...] def __and__(self, other): - return self.__class__(self.expressions + other) + expressions = self.parse_cols(self.expressions) + return self.__class__(expressions + other) def apply_sql_clause(self, query: Select) -> Select: - return query.filter(*self.expressions) + expressions = self.parse_cols(self.expressions) + return query.filter(*expressions) @frozen class SQLOrderBy(SQLClause): - args: tuple[ColumnElement, ...] + args: tuple[Union[Function, ColumnElement], ...] def apply_sql_clause(self, query: Select) -> Select: - return query.order_by(*self.args) + args = self.parse_cols(self.args) + return query.order_by(*args) @frozen diff --git a/src/datachain/sql/__init__.py b/src/datachain/sql/__init__.py index 4fc757e4c..4d812300e 100644 --- a/src/datachain/sql/__init__.py +++ b/src/datachain/sql/__init__.py @@ -1,13 +1,11 @@ from sqlalchemy.sql.elements import literal from sqlalchemy.sql.expression import column -from . 
import functions from .default import setup as default_setup from .selectable import select, values __all__ = [ "column", - "functions", "literal", "select", "values", diff --git a/src/datachain/sql/functions/__init__.py b/src/datachain/sql/functions/__init__.py deleted file mode 100644 index c8d4ef0de..000000000 --- a/src/datachain/sql/functions/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -from sqlalchemy.sql.expression import func - -from . import array, path, string -from .aggregate import avg -from .conditional import greatest, least -from .random import rand - -count = func.count -sum = func.sum -min = func.min -max = func.max - -__all__ = [ - "array", - "avg", - "count", - "func", - "greatest", - "least", - "max", - "min", - "path", - "rand", - "string", - "sum", -] diff --git a/src/datachain/sql/functions2/__init__.py b/src/datachain/sql/functions2/__init__.py new file mode 100644 index 000000000..b2156c5cf --- /dev/null +++ b/src/datachain/sql/functions2/__init__.py @@ -0,0 +1,16 @@ +from .conditional import greatest, least + +__all__ = [ + # "array", + # "avg", + # "count", + # "func", + "greatest", + "least", + # "max", + # "min", + # "path", + # "rand", + # "string", + # "sum", +] diff --git a/src/datachain/sql/functions2/aggregate.py b/src/datachain/sql/functions2/aggregate.py new file mode 100644 index 000000000..dab916a42 --- /dev/null +++ b/src/datachain/sql/functions2/aggregate.py @@ -0,0 +1,47 @@ +from sqlalchemy.sql.functions import GenericFunction, ReturnTypeFromArgs + +from datachain.sql.types import Float, String +from datachain.sql.utils import compiler_not_implemented + + +class avg(GenericFunction): # noqa: N801 + """ + Returns the average of the column. + """ + + type = Float() + package = "array" + name = "avg" + inherit_cache = True + + +class group_concat(GenericFunction): # noqa: N801 + """ + Returns the concatenated string of the column. + """ + + type = String() + package = "array" + name = "group_concat" + inherit_cache = True + + +class any_value(ReturnTypeFromArgs): # noqa: N801 + """ + Returns first value of the column. + """ + + inherit_cache = True + + +class collect(ReturnTypeFromArgs): # noqa: N801 + """ + Returns an array of the column. + """ + + inherit_cache = True + + +compiler_not_implemented(avg) +compiler_not_implemented(group_concat) +compiler_not_implemented(any_value) diff --git a/src/datachain/sql/functions2/array.py b/src/datachain/sql/functions2/array.py new file mode 100644 index 000000000..567162fe6 --- /dev/null +++ b/src/datachain/sql/functions2/array.py @@ -0,0 +1,50 @@ +from sqlalchemy.sql.functions import GenericFunction + +from datachain.sql.types import Float, Int64 +from datachain.sql.utils import compiler_not_implemented + + +class cosine_distance(GenericFunction): # noqa: N801 + """ + Takes a column and array and returns the cosine distance between them. + """ + + type = Float() + package = "array" + name = "cosine_distance" + inherit_cache = True + + +class euclidean_distance(GenericFunction): # noqa: N801 + """ + Takes a column and array and returns the Euclidean distance between them. + """ + + type = Float() + package = "array" + name = "euclidean_distance" + inherit_cache = True + + +class length(GenericFunction): # noqa: N801 + """ + Returns the length of the array. 
+ """ + + type = Int64() + package = "array" + name = "length" + inherit_cache = True + + +class sip_hash_64(GenericFunction): # noqa: N801 + type = Int64() + package = "hash" + name = "sip_hash_64" + inherit_cache = True + + +compiler_not_implemented(cosine_distance) +compiler_not_implemented(euclidean_distance) +compiler_not_implemented(length) +compiler_not_implemented(sip_hash_64) diff --git a/src/datachain/sql/functions2/base.py b/src/datachain/sql/functions2/base.py new file mode 100644 index 000000000..23de4dc33 --- /dev/null +++ b/src/datachain/sql/functions2/base.py @@ -0,0 +1,18 @@ +from abc import ABCMeta, abstractmethod +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from datachain.lib.signal_schema import SignalSchema + from datachain.query.schema import Column + + +class Function: + __metaclass__ = ABCMeta + + @abstractmethod + def get_column( + self, + signals_schema: Optional["SignalSchema"] = None, + label: Optional[str] = None, + ) -> "Column": + pass diff --git a/src/datachain/sql/functions2/conditional.py b/src/datachain/sql/functions2/conditional.py new file mode 100644 index 000000000..42b4fea3b --- /dev/null +++ b/src/datachain/sql/functions2/conditional.py @@ -0,0 +1,9 @@ +from sqlalchemy.sql.functions import ReturnTypeFromArgs + + +class greatest(ReturnTypeFromArgs): # noqa: N801 + inherit_cache = True + + +class least(ReturnTypeFromArgs): # noqa: N801 + inherit_cache = True diff --git a/src/datachain/sql/functions2/path.py b/src/datachain/sql/functions2/path.py new file mode 100644 index 000000000..687a3318b --- /dev/null +++ b/src/datachain/sql/functions2/path.py @@ -0,0 +1,61 @@ +""" +This module provides generic SQL functions for path logic. + +These need to be implemented using dialect-specific compilation rules. +See https://docs.sqlalchemy.org/en/14/core/compiler.html +""" + +from sqlalchemy.sql.functions import GenericFunction + +from datachain.sql.types import String +from datachain.sql.utils import compiler_not_implemented + + +class parent(GenericFunction): # noqa: N801 + """ + Returns the directory component of a posix-style path. + """ + + type = String() + package = "path" + name = "parent" + inherit_cache = True + + +class name(GenericFunction): # noqa: N801 + """ + Returns the final component of a posix-style path. + """ + + type = String() + package = "path" + name = "name" + inherit_cache = True + + +class file_stem(GenericFunction): # noqa: N801 + """ + Strips an extension from the given path. + """ + + type = String() + package = "path" + name = "file_stem" + inherit_cache = True + + +class file_ext(GenericFunction): # noqa: N801 + """ + Returns the extension of the given path. 
+ """ + + type = String() + package = "path" + name = "file_ext" + inherit_cache = True + + +compiler_not_implemented(parent) +compiler_not_implemented(name) +compiler_not_implemented(file_stem) +compiler_not_implemented(file_ext) diff --git a/src/datachain/sql/functions2/random.py b/src/datachain/sql/functions2/random.py new file mode 100644 index 000000000..29dc7e5f5 --- /dev/null +++ b/src/datachain/sql/functions2/random.py @@ -0,0 +1,12 @@ +from sqlalchemy.sql.functions import GenericFunction + +from datachain.sql.types import Int64 +from datachain.sql.utils import compiler_not_implemented + + +class rand(GenericFunction): # noqa: N801 + type = Int64() + inherit_cache = True + + +compiler_not_implemented(rand) diff --git a/src/datachain/sql/functions2/string.py b/src/datachain/sql/functions2/string.py new file mode 100644 index 000000000..4ccee8444 --- /dev/null +++ b/src/datachain/sql/functions2/string.py @@ -0,0 +1,54 @@ +from sqlalchemy.sql.functions import GenericFunction + +from datachain.sql.types import Array, Int64, String +from datachain.sql.utils import compiler_not_implemented + + +class length(GenericFunction): # noqa: N801 + """ + Returns the length of the string. + """ + + type = Int64() + package = "string" + name = "length" + inherit_cache = True + + +class split(GenericFunction): # noqa: N801 + """ + Takes a column and split character and returns an array of the parts. + """ + + type = Array(String()) + package = "string" + name = "split" + inherit_cache = True + + +class regexp_replace(GenericFunction): # noqa: N801 + """ + Replaces substring that match a regular expression. + """ + + type = String() + package = "string" + name = "regexp_replace" + inherit_cache = True + + +class replace(GenericFunction): # noqa: N801 + """ + Replaces substring with another string. 
+ """ + + type = String() + package = "string" + name = "replace" + inherit_cache = True + + +compiler_not_implemented(length) +compiler_not_implemented(split) +compiler_not_implemented(regexp_replace) +compiler_not_implemented(replace) diff --git a/src/datachain/sql/selectable.py b/src/datachain/sql/selectable.py index 20c0b5f0c..e0602124b 100644 --- a/src/datachain/sql/selectable.py +++ b/src/datachain/sql/selectable.py @@ -1,6 +1,8 @@ from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import expression, selectable +from datachain.lib.func.inner.base import Function + class Values(selectable.Values): def __init__(self, data, columns=None, **kwargs): @@ -9,7 +11,9 @@ def __init__(self, data, columns=None, **kwargs): columns = [expression.column(f"c{i}") for i in range(1, num_columns + 1)] else: columns = [ - expression.column(c) if isinstance(c, str) else c for c in columns + process_column_expression(c) + for c in columns + # expression.column(c) if isinstance(c, str) else c for c in columns ] super().__init__(*columns, **kwargs) self._data += tuple(data) @@ -19,13 +23,17 @@ def values(data, columns=None, **kwargs) -> Values: return Values(data, columns=columns, **kwargs) -def process_column_expressions(columns): - return [expression.column(c) if isinstance(c, str) else c for c in columns] +def process_column_expression(col): + if isinstance(col, Function): + return col.get_column() + if isinstance(col, str): + return expression.column(col) + return col def select(*columns, **kwargs) -> "expression.Select": - columns = process_column_expressions(columns) - return expression.select(*columns, **kwargs) + columns_processed = [process_column_expression(c) for c in columns] + return expression.select(*columns_processed, **kwargs) def base_values_compiler(column_name_func, element, compiler, **kwargs): diff --git a/src/datachain/sql/sqlite/base.py b/src/datachain/sql/sqlite/base.py index d99e49de1..e19a87223 100644 --- a/src/datachain/sql/sqlite/base.py +++ b/src/datachain/sql/sqlite/base.py @@ -14,8 +14,8 @@ from sqlalchemy.sql.expression import case from sqlalchemy.sql.functions import func -from datachain.sql.functions import aggregate, array, conditional, random, string -from datachain.sql.functions import path as sql_path +from datachain.lib.func.inner import aggregate, array, conditional, random, string +from datachain.lib.func.inner import path as sql_path from datachain.sql.selectable import Values, base_values_compiler from datachain.sql.sqlite.types import ( SQLiteTypeConverter, diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py index a9d968b82..71d029da3 100644 --- a/tests/func/test_datachain.py +++ b/tests/func/test_datachain.py @@ -15,19 +15,18 @@ from PIL import Image from sqlalchemy import Column -from datachain import DataModel +from datachain import DataModel, func from datachain.catalog.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE from datachain.data_storage.sqlite import SQLiteWarehouse from datachain.dataset import DatasetDependencyType, DatasetStats from datachain.lib.dc import C, DataChain from datachain.lib.file import File, ImageFile +from datachain.lib.func import path as pathfunc from datachain.lib.listing import LISTING_TTL, is_listing_dataset, parse_listing_uri from datachain.lib.tar import process_tar from datachain.lib.udf import Mapper from datachain.lib.utils import DataChainError from datachain.query.dataset import QueryStep -from datachain.sql.functions import path as pathfunc -from datachain.sql.functions.array import 
diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py
index a9d968b82..71d029da3 100644
--- a/tests/func/test_datachain.py
+++ b/tests/func/test_datachain.py
@@ -15,19 +15,18 @@
 from PIL import Image
 from sqlalchemy import Column

-from datachain import DataModel
+from datachain import DataModel, func
 from datachain.catalog.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
 from datachain.data_storage.sqlite import SQLiteWarehouse
 from datachain.dataset import DatasetDependencyType, DatasetStats
 from datachain.lib.dc import C, DataChain
 from datachain.lib.file import File, ImageFile
+from datachain.lib.func import path as pathfunc
 from datachain.lib.listing import LISTING_TTL, is_listing_dataset, parse_listing_uri
 from datachain.lib.tar import process_tar
 from datachain.lib.udf import Mapper
 from datachain.lib.utils import DataChainError
 from datachain.query.dataset import QueryStep
-from datachain.sql.functions import path as pathfunc
-from datachain.sql.functions.array import cosine_distance, euclidean_distance
 from tests.utils import (
     ANY_VALUE,
     NUM_TREE,
@@ -927,7 +926,7 @@ def get_result(chain):
     expected = [(f"{i:06d}", i) for i in range(100)]
     dc = (
         DataChain.from_storage(ctc.src_uri, session=ctc.session)
-        .mutate(name=pathfunc.name(C("file.path")))
+        .mutate(name=pathfunc.name("file.path"))
         .save()
     )
     # We test a few different orderings here, because we've had strange
@@ -1262,8 +1261,8 @@ def calc_emb(file):
         DataChain.from_storage(src_uri, session=session)
         .map(embedding=calc_emb, output={"embedding": list[float]})
         .mutate(
-            cos_dist=cosine_distance(C("embedding"), target_embedding),
-            eucl_dist=euclidean_distance(C("embedding"), target_embedding),
+            cos_dist=func.cosine_distance("embedding", target_embedding),
+            eucl_dist=func.euclidean_distance("embedding", target_embedding),
         )
         .order_by("file.path")
     )
diff --git a/tests/func/test_dataset_query.py b/tests/func/test_dataset_query.py
index b83330b61..a8797c915 100644
--- a/tests/func/test_dataset_query.py
+++ b/tests/func/test_dataset_query.py
@@ -13,8 +13,8 @@
     DatasetNotFoundError,
     DatasetVersionNotFoundError,
 )
+from datachain.lib.func.inner import path as pathfunc
 from datachain.query import C, DatasetQuery, Object, Stream
-from datachain.sql.functions import path as pathfunc
 from datachain.sql.types import String
 from tests.utils import assert_row_names, dataset_dependency_asdict
diff --git a/tests/func/test_datasets.py b/tests/func/test_datasets.py
index 9c54f0c0e..94a5564a4 100644
--- a/tests/func/test_datasets.py
+++ b/tests/func/test_datasets.py
@@ -15,7 +15,7 @@
 from datachain.lib.dc import DataChain
 from datachain.lib.file import File
 from datachain.lib.listing import parse_listing_uri
-from datachain.query import DatasetQuery
+from datachain.query.dataset import DatasetQuery
 from datachain.sql.types import Float32, Int, Int64
 from tests.utils import assert_row_names, dataset_dependency_asdict
diff --git a/tests/func/test_pull.py b/tests/func/test_pull.py
index 53e818963..5c20551cf 100644
--- a/tests/func/test_pull.py
+++ b/tests/func/test_pull.py
@@ -107,7 +107,7 @@ def remote_dataset_version(schema, dataset_rows):
         "schema": schema,
         "sources": "",
         "query_script": (
-            'from datachain.query import DatasetQuery\nDatasetQuery(path="s3://ldb-public")',
+            'from datachain.query.dataset import DatasetQuery\nDatasetQuery(path="s3://ldb-public")',
         ),
         "created_by_id": 1,
     }
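# Illustrative usage sketch, not part of the diff above. The functional test
# now builds the distance expressions from plain column names; the same pattern
# in isolation (the embedding values here are made up):
from datachain import func

target_embedding = [0.1, 0.2, 0.3]
cos_dist = func.cosine_distance("embedding", target_embedding)
eucl_dist = func.euclidean_distance("embedding", target_embedding)
# both are attached to a chain via .mutate(cos_dist=cos_dist, eucl_dist=eucl_dist)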
diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py
index 2841ec9b2..8621d163d 100644
--- a/tests/unit/lib/test_datachain.py
+++ b/tests/unit/lib/test_datachain.py
@@ -1809,7 +1809,7 @@ def test_order_by_with_nested_columns(test_session, with_function):
         file=[File(path=name) for name in names], session=test_session
     )
     if with_function:
-        from datachain.sql.functions import rand
+        from datachain.lib.func.inner.random import rand

         dc = dc.order_by("file.path", rand())
     else:
@@ -1832,7 +1832,7 @@ def test_order_by_descending(test_session, with_function):
         file=[File(path=name) for name in names], session=test_session
     )
     if with_function:
-        from datachain.sql.functions import rand
+        from datachain.lib.func.inner.random import rand

         dc = dc.order_by("file.path", rand(), descending=True)
     else:
@@ -2187,22 +2187,20 @@ def test_mutate_with_multiplication(test_session):


 def test_mutate_with_sql_func(test_session):
-    from datachain.sql import functions as func
+    from datachain import func

     ds = DataChain.from_values(id=[1, 2], session=test_session)
-    assert (
-        ds.mutate(new=func.avg(ds.column("id"))).signals_schema.values["new"] is float
-    )
+    assert ds.mutate(new=func.avg("id")).signals_schema.values["new"] is float


 def test_mutate_with_complex_expression(test_session):
-    from datachain.sql import functions as func
+    from datachain import func

     ds = DataChain.from_values(id=[1, 2], name=["Jim", "Jon"], session=test_session)
     assert (
-        ds.mutate(
-            new=(func.sum(ds.column("id"))) * (5 - func.min(ds.column("id")))
-        ).signals_schema.values["new"]
+        ds.mutate(new=func.sum("id") * (5 - func.min("id"))).signals_schema.values[
+            "new"
+        ]
         is int
     )
diff --git a/tests/unit/lib/test_sql_to_python.py b/tests/unit/lib/test_sql_to_python.py
index 85c973ac9..80565ef22 100644
--- a/tests/unit/lib/test_sql_to_python.py
+++ b/tests/unit/lib/test_sql_to_python.py
@@ -3,7 +3,6 @@
 from datachain import Column
 from datachain.lib.convert.sql_to_python import sql_to_python
-from datachain.sql import functions as func
 from datachain.sql.types import Float, Int64, String


@@ -15,8 +14,6 @@
         (Column("score", Float), float),
         # SQL expression
         (Column("age", Int64) - 2, int),
-        # SQL function
-        (func.avg(Column("age", Int64)), float),
         # Default type
         (Column("null", NullType), str),
     ],
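# Illustrative usage sketch, not part of the diff above. The namespaced helpers
# compose, as in test_length_on_split further below: split a literal string and
# take the array length of the result.
from datachain import func
from datachain.sql import select

query = select(func.array.length(func.string.split(func.literal("a/b/c"), "/")))
# -> 3 once executed against a warehouse connection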
diff --git a/tests/unit/sql/test_array.py b/tests/unit/sql/test_array.py
index 5448c82c6..fa265a223 100644
--- a/tests/unit/sql/test_array.py
+++ b/tests/unit/sql/test_array.py
@@ -1,12 +1,65 @@
-from datachain.sql import literal, select
-from datachain.sql.functions import array, string
+import math
+
+import pytest
+
+from datachain import func
+from datachain.sql import select
+
+
+def test_cosine_distance(warehouse):
+    query = select(
+        func.cosine_distance((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 5, 6)),
+        func.cosine_distance([3.0, 5.0, 1.0], (3.0, 5.0, 1.0)),
+        func.cosine_distance((1, 0), [0, 10]),
+        func.cosine_distance([0.0, 10.0], [1.0, 0.0]),
+    )
+    result = tuple(warehouse.db.execute(query))
+    assert result == ((0.0, 0.0, 1.0, 1.0),)
+
+
+def test_euclidean_distance(warehouse):
+    query = select(
+        func.euclidean_distance((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 5, 6)),
+        func.euclidean_distance([3.0, 5.0, 1.0], (3.0, 5.0, 1.0)),
+        func.euclidean_distance((1, 0), [0, 1]),
+        func.euclidean_distance([1.0, 1.0, 1.0], [2.0, 2.0, 2.0]),
+    )
+    result = tuple(warehouse.db.execute(query))
+    assert result == ((0.0, 0.0, math.sqrt(2), math.sqrt(3)),)
+
+
+@pytest.mark.parametrize(
+    "args",
+    [
+        [],
+        ["signal"],
+        [[1, 2]],
+        [[1, 2], [1, 2], [1, 2]],
+        ["signal1", "signal2", "signal3"],
+        ["signal1", "signal2", [1, 2]],
+    ],
+)
+def test_cosine_euclidean_distance_error_args(warehouse, args):
+    with pytest.raises(ValueError, match="requires exactly two arguments"):
+        func.cosine_distance(*args)
+
+    with pytest.raises(ValueError, match="requires exactly two arguments"):
+        func.euclidean_distance(*args)
+
+
+def test_cosine_euclidean_distance_error_vectors_length(warehouse):
+    with pytest.raises(ValueError, match="requires vectors of the same length"):
+        func.cosine_distance([1], [1, 2])
+
+    with pytest.raises(ValueError, match="requires vectors of the same length"):
+        func.euclidean_distance([1], [1, 2])


 def test_length(warehouse):
     query = select(
-        array.length(["abc", "def", "g", "hi"]),
-        array.length([3.0, 5.0, 1.0, 6.0, 1.0]),
-        array.length([[1, 2, 3], [4, 5, 6]]),
+        func.length(["abc", "def", "g", "hi"]),
+        func.length([3.0, 5.0, 1.0, 6.0, 1.0]),
+        func.length([[1, 2, 3], [4, 5, 6]]),
     )
     result = tuple(warehouse.db.execute(query))
     assert result == ((4, 5, 2),)
@@ -14,7 +67,7 @@ def test_length(warehouse):

 def test_length_on_split(warehouse):
     query = select(
-        array.length(string.split(literal("abc/def/g/hi"), literal("/"))),
+        func.array.length(func.string.split(func.literal("abc/def/g/hi"), "/")),
     )
     result = tuple(warehouse.db.execute(query))
     assert result == ((4,),)
diff --git a/tests/unit/sql/test_conditional.py b/tests/unit/sql/test_conditional.py
index cae3f2433..db64511fd 100644
--- a/tests/unit/sql/test_conditional.py
+++ b/tests/unit/sql/test_conditional.py
@@ -1,20 +1,27 @@
 import pytest

-from datachain.sql import column, select, values
-from datachain.sql import literal as lit
-from datachain.sql.functions import greatest, least
+from datachain import func
+from datachain.sql import select, values


 @pytest.mark.parametrize(
     "args,expected",
     [
-        ([lit("abc"), lit("bcd"), lit("Abc"), lit("cd")], "cd"),
+        (
+            [
+                func.literal("abc"),
+                func.literal("bcd"),
+                func.literal("Abc"),
+                func.literal("cd"),
+            ],
+            "cd",
+        ),
         ([3, 1, 2.0, 3.1, 2.5, -1], 3.1),
         ([4], 4),
     ],
 )
 def test_greatest(warehouse, args, expected):
-    query = select(greatest(*args))
+    query = select(func.greatest(*args))
     result = tuple(warehouse.db.execute(query))
     assert result == ((expected,),)

@@ -22,13 +29,21 @@ def test_greatest(warehouse, args, expected):
 @pytest.mark.parametrize(
     "args,expected",
     [
-        ([lit("abc"), lit("bcd"), lit("Abc"), lit("cd")], "Abc"),
+        (
+            [
+                func.literal("abc"),
+                func.literal("bcd"),
+                func.literal("Abc"),
+                func.literal("cd"),
+            ],
+            "Abc",
+        ),
         ([3, 1, 2.0, 3.1, 2.5, -1], -1),
         ([4], 4),
     ],
 )
 def test_least(warehouse, args, expected):
-    query = select(least(*args))
+    query = select(func.least(*args))
     result = tuple(warehouse.db.execute(query))
     assert result == ((expected,),)

@@ -36,9 +51,9 @@ def test_least(warehouse, args, expected):
 @pytest.mark.parametrize(
     "expr,expected",
     [
-        (greatest(column("a")), [(3,), (8,), (9,)]),
-        (least(column("a")), [(3,), (8,), (9,)]),
-        (least(column("a"), column("b")), [(3,), (7,), (1,)]),
+        (func.greatest("a"), [(3,), (8,), (9,)]),
+        (func.least("a"), [(3,), (8,), (9,)]),
+        (func.least("a", "b"), [(3,), (7,), (1,)]),
     ],
 )
 def test_conditionals_with_multiple_rows(warehouse, expr, expected):
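# Illustrative usage sketch, not part of the diff above. The path helpers are
# exercised below through the func namespace; for example:
from datachain import func
from datachain.sql import select

query = select(
    func.path.parent(func.literal("some/long/path")),  # -> "some/long"
    func.path.name(func.literal("some/long/path")),    # -> "path"
)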
""" fn = getattr(func_base, func_name) - expr = fn(literal("file:///some/file/path")) + expr = fn(func.literal("file:///some/file/path")) with pytest.raises(NotImplementedError, match=re.escape(f"path.{func_name}")): expr.compile() @pytest.mark.parametrize("path", PATHS) def test_parent(warehouse, path): - query = select(f.path.parent(literal(path))) + query = select(func.path.parent(func.literal(path))) result = tuple(warehouse.db.execute(query)) assert result == ((split_parent(path)[0],),) @pytest.mark.parametrize("path", PATHS) def test_name(warehouse, path): - query = select(f.path.name(literal(path))) + query = select(func.path.name(func.literal(path))) result = tuple(warehouse.db.execute(query)) assert result == ((split_parent(path)[1],),) @pytest.mark.parametrize("path", EXT_PATHS) def test_file_stem(warehouse, path): - query = select(sql_path.file_stem(literal(path))) + query = select(func.path.file_stem(func.literal(path))) result = tuple(warehouse.db.execute(query)) assert result == ((file_stem(path),),) @pytest.mark.parametrize("path", EXT_PATHS) def test_file_ext(warehouse, path): - query = select(sql_path.file_ext(literal(path))) + query = select(func.path.file_ext(func.literal(path))) result = tuple(warehouse.db.execute(query)) assert result == ((file_ext(path),),) diff --git a/tests/unit/sql/test_random.py b/tests/unit/sql/test_random.py index 6e486fbc7..46ca8493d 100644 --- a/tests/unit/sql/test_random.py +++ b/tests/unit/sql/test_random.py @@ -1,8 +1,8 @@ +from datachain import func from datachain.sql import select -from datachain.sql.functions import rand def test_rand(warehouse): - query = select(rand()) + query = select(func.random.rand()) result = tuple(warehouse.db.execute(query)) assert isinstance(result[0][0], int) diff --git a/tests/unit/sql/test_string.py b/tests/unit/sql/test_string.py index 2f67d181d..bd61c85ec 100644 --- a/tests/unit/sql/test_string.py +++ b/tests/unit/sql/test_string.py @@ -1,7 +1,7 @@ import pytest -from datachain.sql import literal, select -from datachain.sql.functions import string +from datachain.lib.func import literal, string +from datachain.sql import select def test_length(warehouse): diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 33c680ee7..8d9f9fbdf 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -4,7 +4,8 @@ import sqlalchemy as sa from datachain.error import DatasetNotFoundError -from datachain.query import DatasetQuery, Session +from datachain.query.dataset import DatasetQuery +from datachain.query.session import Session from datachain.sql.types import String