Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Finish SQL functions refactoring #543

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions docs/references/sql.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ for operations like [`DataChain.filter`](datachain.md#datachain.lib.dc.DataChain
and [`DataChain.mutate`](datachain.md#datachain.lib.dc.DataChain.mutate). Import
these functions from `datachain.sql.functions`.

::: datachain.sql.functions.avg
::: datachain.sql.functions.count
::: datachain.sql.functions.greatest
::: datachain.sql.functions.least
::: datachain.sql.functions.max
::: datachain.sql.functions.min
::: datachain.sql.functions.rand
::: datachain.sql.functions.sum
::: datachain.sql.functions.array
::: datachain.sql.functions.path
::: datachain.sql.functions.string
::: datachain.lib.func.avg
::: datachain.lib.func.count
::: datachain.lib.func.greatest
::: datachain.lib.func.least
::: datachain.lib.func.max
::: datachain.lib.func.min
::: datachain.lib.func.rand
::: datachain.lib.func.sum
::: datachain.lib.func.array
::: datachain.lib.func.path
::: datachain.lib.func.string
4 changes: 2 additions & 2 deletions examples/computer_vision/openimage-detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic import BaseModel

from datachain import C, DataChain, File
from datachain.sql.functions import path
from datachain.lib.func import path


class BBox(BaseModel):
Expand Down Expand Up @@ -54,7 +54,7 @@ def openimage_detect(args):
.filter(C("file.path").glob("*.jpg") | C("file.path").glob("*.json"))
.agg(
openimage_detect,
partition_by=path.file_stem(C("file.path")),
partition_by=path.file_stem("file.path"),
params=["file"],
output={"file": File, "bbox": BBox},
)
Expand Down
9 changes: 4 additions & 5 deletions examples/get_started/common_sql_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from datachain import C, DataChain
from datachain.sql import literal
from datachain.sql.functions import array, greatest, least, path, string
from datachain.lib.func import array, greatest, least, path, string


def num_chars_udf(file):
Expand All @@ -18,7 +17,7 @@ def num_chars_udf(file):
(
dc.mutate(
length=string.length(path.name(C("file.path"))),
parts=string.split(path.name(C("file.path")), literal(".")),
parts=string.split(path.name(C("file.path")), "."),
)
.select("file.path", "length", "parts")
.show(5)
Expand All @@ -35,8 +34,8 @@ def num_chars_udf(file):


chain = dc.mutate(
a=array.length(string.split(C("file.path"), literal("/"))),
b=array.length(string.split(path.name(C("file.path")), literal("0"))),
a=array.length(string.split("file.path", "/")),
b=array.length(string.split(path.name("file.path"), "0")),
)

(
Expand Down
7 changes: 3 additions & 4 deletions examples/multimodal/clip_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from torch.nn.functional import cosine_similarity
from torch.utils.data import DataLoader

from datachain import C, DataChain
from datachain.sql.functions import path
from datachain import C, DataChain, func

source = "gs://datachain-demo/50k-laion-files/000000/00000000*"

Expand All @@ -18,8 +17,8 @@ def create_dataset():
)
return imgs.merge(
captions,
on=path.file_stem(imgs.c("file.path")),
right_on=path.file_stem(captions.c("file.path")),
on=func.path.file_stem(imgs.c("file.path")),
right_on=func.path.file_stem(captions.c("file.path")),
)


Expand Down
2 changes: 1 addition & 1 deletion examples/multimodal/wds.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os

from datachain import DataChain
from datachain.lib.func import path
from datachain.lib.webdataset import process_webdataset
from datachain.lib.webdataset_laion import WDSLaion, process_laion_meta
from datachain.sql.functions import path

IMAGE_TARS = os.getenv(
"IMAGE_TARS", "gs://datachain-demo/datacomp-small/shards/000000[0-5]*.tar"
Expand Down
16 changes: 6 additions & 10 deletions examples/multimodal/wds_filtered.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import datachain.error
from datachain import C, DataChain
from datachain import C, DataChain, func
from datachain.lib.webdataset import process_webdataset
from datachain.lib.webdataset_laion import WDSLaion
from datachain.sql import literal
from datachain.sql.functions import array, greatest, least, string

name = "wds"
try:
Expand All @@ -20,14 +18,12 @@
wds.print_schema()

filtered = (
wds.filter(string.length(C("laion.txt")) > 5)
.filter(array.length(string.split(C("laion.txt"), literal(" "))) > 2)
wds.filter(func.string.length("laion.txt") > 5)
.filter(func.array.length(func.string.split("laion.txt", " ")) > 2)
.filter(func.least("laion.json.original_width", "laion.json.original_height") > 200)
.filter(
least(C("laion.json.original_width"), C("laion.json.original_height")) > 200
)
.filter(
greatest(C("laion.json.original_width"), C("laion.json.original_height"))
/ least(C("laion.json.original_width"), C("laion.json.original_height"))
func.greatest("laion.json.original_width", "laion.json.original_height")
/ func.least("laion.json.original_width", "laion.json.original_height")
< 3.0
)
.save()
Expand Down
14 changes: 8 additions & 6 deletions src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
QueryScriptCancelError,
QueryScriptRunError,
)
from datachain.listing import Listing
from datachain.node import DirType, Node, NodeWithPath
from datachain.nodes_thread_pool import NodesThreadPool
from datachain.remote.studio import StudioClient
Expand All @@ -76,6 +75,7 @@
from datachain.dataset import DatasetVersion
from datachain.job import Job
from datachain.lib.file import File
from datachain.listing import Listing

logger = logging.getLogger("datachain")

Expand Down Expand Up @@ -241,7 +241,7 @@ def do_task(self, urls):
class NodeGroup:
"""Class for a group of nodes from the same source"""

listing: Listing
listing: "Listing"
sources: list[DataSource]

# The source path within the bucket
Expand Down Expand Up @@ -596,8 +596,9 @@ def enlist_source(
client_config=None,
object_name="file",
skip_indexing=False,
) -> tuple[Listing, str]:
) -> tuple["Listing", str]:
from datachain.lib.dc import DataChain
from datachain.listing import Listing

DataChain.from_storage(
source, session=self.session, update=update, object_name=object_name
Expand Down Expand Up @@ -664,7 +665,8 @@ def enlist_sources_grouped(
no_glob: bool = False,
client_config=None,
) -> list[NodeGroup]:
from datachain.query import DatasetQuery
from datachain.listing import Listing
from datachain.query.dataset import DatasetQuery

def _row_to_node(d: dict[str, Any]) -> Node:
del d["file__source"]
Expand Down Expand Up @@ -872,7 +874,7 @@ def create_new_dataset_version(
def update_dataset_version_with_warehouse_info(
self, dataset: DatasetRecord, version: int, rows_dropped=False, **kwargs
) -> None:
from datachain.query import DatasetQuery
from datachain.query.dataset import DatasetQuery

dataset_version = dataset.get_version(version)

Expand Down Expand Up @@ -1173,7 +1175,7 @@ def listings(self):
def ls_dataset_rows(
self, name: str, version: int, offset=None, limit=None
) -> list[dict]:
from datachain.query import DatasetQuery
from datachain.query.dataset import DatasetQuery

dataset = self.get_dataset(name)

Expand Down
2 changes: 1 addition & 1 deletion src/datachain/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,7 @@
schema: bool = False,
) -> None:
from datachain.lib.dc import DataChain
from datachain.query import DatasetQuery
from datachain.query.dataset import DatasetQuery

Check warning on line 875 in src/datachain/cli.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/cli.py#L875

Added line #L875 was not covered by tests
from datachain.utils import show_records

dataset = catalog.get_dataset(name)
Expand Down
19 changes: 10 additions & 9 deletions src/datachain/client/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,15 @@
from datachain.cache import DataChainCache
from datachain.client.fileslice import FileWrapper
from datachain.error import ClientError as DataChainClientError
from datachain.lib.file import File
from datachain.nodes_fetcher import NodesFetcher
from datachain.nodes_thread_pool import NodeChunk
from datachain.storage import StorageURI

if TYPE_CHECKING:
from fsspec.spec import AbstractFileSystem

from datachain.lib.file import File


logger = logging.getLogger("datachain")

Expand All @@ -44,7 +45,7 @@

DATA_SOURCE_URI_PATTERN = re.compile(r"^[\w]+:\/\/.*$")

ResultQueue = asyncio.Queue[Optional[Sequence[File]]]
ResultQueue = asyncio.Queue[Optional[Sequence["File"]]]


def _is_win_local_path(uri: str) -> bool:
Expand Down Expand Up @@ -207,7 +208,7 @@ async def get_file(self, lpath, rpath, callback):

async def scandir(
self, start_prefix: str, method: str = "default"
) -> AsyncIterator[Sequence[File]]:
) -> AsyncIterator[Sequence["File"]]:
try:
impl = getattr(self, f"_fetch_{method}")
except AttributeError:
Expand Down Expand Up @@ -312,7 +313,7 @@ def get_full_path(self, rel_path: str) -> str:
return f"{self.PREFIX}{self.name}/{rel_path}"

@abstractmethod
def info_to_file(self, v: dict[str, Any], parent: str) -> File: ...
def info_to_file(self, v: dict[str, Any], parent: str) -> "File": ...

def fetch_nodes(
self,
Expand Down Expand Up @@ -349,27 +350,27 @@ def do_instantiate_object(self, file: "File", dst: str) -> None:
copy2(src, dst)

def open_object(
self, file: File, use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
self, file: "File", use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
) -> BinaryIO:
"""Open a file, including files in tar archives."""
if use_cache and (cache_path := self.cache.get_path(file)):
return open(cache_path, mode="rb") # noqa: SIM115
assert not file.location
return FileWrapper(self.fs.open(self.get_full_path(file.path)), cb) # type: ignore[return-value]

def download(self, file: File, *, callback: Callback = DEFAULT_CALLBACK) -> None:
def download(self, file: "File", *, callback: Callback = DEFAULT_CALLBACK) -> None:
sync(get_loop(), functools.partial(self._download, file, callback=callback))

async def _download(self, file: File, *, callback: "Callback" = None) -> None:
async def _download(self, file: "File", *, callback: "Callback" = None) -> None:
if self.cache.contains(file):
# Already in cache, so there's nothing to do.
return
await self._put_in_cache(file, callback=callback)

def put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
def put_in_cache(self, file: "File", *, callback: "Callback" = None) -> None:
sync(get_loop(), functools.partial(self._put_in_cache, file, callback=callback))

async def _put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
async def _put_in_cache(self, file: "File", *, callback: "Callback" = None) -> None:
assert not file.location
if file.etag:
etag = await self.get_current_etag(file)
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/data_storage/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sqlalchemy.sql import func as f
from sqlalchemy.sql.expression import false, null, true

from datachain.sql.functions import path
from datachain.lib.func.inner import path
from datachain.sql.types import Int, SQLType, UInt64

if TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/data_storage/warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from datachain.data_storage.schema import convert_rows_custom_column_types
from datachain.data_storage.serializer import Serializable
from datachain.dataset import DatasetRecord
from datachain.lib.func.inner import path as pathfunc
from datachain.node import DirType, DirTypeGroup, Node, NodeWithPath, get_path
from datachain.sql.functions import path as pathfunc
from datachain.sql.types import Int, SQLType
from datachain.storage import StorageURI
from datachain.utils import sql_escape_like
Expand Down
Loading