Finish SQL functions refactoring
dreadatour committed Oct 30, 2024
1 parent 529c297 commit 26fb69c
Showing 48 changed files with 1,147 additions and 250 deletions.
22 changes: 11 additions & 11 deletions docs/references/sql.md
@@ -5,14 +5,14 @@ for operations like [`DataChain.filter`](datachain.md#datachain.lib.dc.DataChain
and [`DataChain.mutate`](datachain.md#datachain.lib.dc.DataChain.mutate). Import
these functions from `datachain.sql.functions`.

-::: datachain.sql.functions.avg
-::: datachain.sql.functions.count
-::: datachain.sql.functions.greatest
-::: datachain.sql.functions.least
-::: datachain.sql.functions.max
-::: datachain.sql.functions.min
-::: datachain.sql.functions.rand
-::: datachain.sql.functions.sum
-::: datachain.sql.functions.array
-::: datachain.sql.functions.path
-::: datachain.sql.functions.string
+::: datachain.lib.func.avg
+::: datachain.lib.func.count
+::: datachain.lib.func.greatest
+::: datachain.lib.func.least
+::: datachain.lib.func.max
+::: datachain.lib.func.min
+::: datachain.lib.func.rand
+::: datachain.lib.func.sum
+::: datachain.lib.func.array
+::: datachain.lib.func.path
+::: datachain.lib.func.string
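With this commit, the documented helpers resolve from `datachain.lib.func` rather than `datachain.sql.functions`. A minimal usage sketch of the new import path, assembled only from call sites touched in this diff (the storage URI and column names are illustrative):

```python
from datachain import C, DataChain
from datachain.lib.func import path, string

# Illustrative source; any indexed storage location would do.
dc = DataChain.from_storage("gs://datachain-demo/dogs-and-cats/")

# The same helpers as before the refactor, imported from the new module.
(
    dc.mutate(
        name=path.name(C("file.path")),
        length=string.length(path.name(C("file.path"))),
    )
    .select("file.path", "name", "length")
    .show(5)
)
```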
4 changes: 2 additions & 2 deletions examples/computer_vision/openimage-detect.py
@@ -4,7 +4,7 @@
from pydantic import BaseModel

from datachain import C, DataChain, File
-from datachain.sql.functions import path
+from datachain.lib.func import path


class BBox(BaseModel):
@@ -54,7 +54,7 @@ def openimage_detect(args):
    .filter(C("file.path").glob("*.jpg") | C("file.path").glob("*.json"))
    .agg(
        openimage_detect,
-        partition_by=path.file_stem(C("file.path")),
+        partition_by=path.file_stem("file.path"),
        params=["file"],
        output={"file": File, "bbox": BBox},
    )
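The updated `partition_by` argument also shows an ergonomic change that recurs throughout this commit: the refactored helpers accept a plain string where a column reference is expected, so the `C(...)` wrapper becomes optional. A small sketch of the two spellings, whose equivalence is inferred from the mixed usage in this diff:

```python
from datachain import C
from datachain.lib.func import path

# Both forms appear in the updated examples, so they are presumably equivalent:
stem_explicit = path.file_stem(C("file.path"))  # explicit column reference
stem_plain = path.file_stem("file.path")        # bare string resolved as a column name
```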
9 changes: 4 additions & 5 deletions examples/get_started/common_sql_functions.py
@@ -1,6 +1,5 @@
from datachain import C, DataChain
-from datachain.sql import literal
-from datachain.sql.functions import array, greatest, least, path, string
+from datachain.lib.func import array, greatest, least, path, string


def num_chars_udf(file):
@@ -18,7 +17,7 @@ def num_chars_udf(file):
(
    dc.mutate(
        length=string.length(path.name(C("file.path"))),
-        parts=string.split(path.name(C("file.path")), literal(".")),
+        parts=string.split(path.name(C("file.path")), "."),
    )
    .select("file.path", "length", "parts")
    .show(5)
@@ -35,8 +34,8 @@ def num_chars_udf(file):


chain = dc.mutate(
-    a=array.length(string.split(C("file.path"), literal("/"))),
-    b=array.length(string.split(path.name(C("file.path")), literal("0"))),
+    a=array.length(string.split("file.path", "/")),
+    b=array.length(string.split(path.name("file.path"), "0")),
)

(
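Besides the new import path, this example drops `datachain.sql.literal`: the refactored functions take a plain Python string as a literal argument, while a bare string in a column position still names a column. A before/after sketch reconstructed from the hunks above:

```python
from datachain import C
from datachain.lib.func import array, path, string

# Before: string.split(path.name(C("file.path")), literal("."))
# After: a plain "." serves as the separator literal; no literal() wrapper.
parts = string.split(path.name(C("file.path")), ".")

# Before: array.length(string.split(C("file.path"), literal("/")))
# After: the first bare string names the column, the second is the separator.
depth = array.length(string.split("file.path", "/"))
```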
7 changes: 3 additions & 4 deletions examples/multimodal/clip_inference.py
@@ -3,8 +3,7 @@
from torch.nn.functional import cosine_similarity
from torch.utils.data import DataLoader

-from datachain import C, DataChain
-from datachain.sql.functions import path
+from datachain import C, DataChain, func

source = "gs://datachain-demo/50k-laion-files/000000/00000000*"

@@ -18,8 +17,8 @@ def create_dataset():
    )
    return imgs.merge(
        captions,
-        on=path.file_stem(imgs.c("file.path")),
-        right_on=path.file_stem(captions.c("file.path")),
+        on=func.path.file_stem(imgs.c("file.path")),
+        right_on=func.path.file_stem(captions.c("file.path")),
    )
2 changes: 1 addition & 1 deletion examples/multimodal/wds.py
@@ -1,9 +1,9 @@
import os

from datachain import DataChain
+from datachain.lib.func import path
from datachain.lib.webdataset import process_webdataset
from datachain.lib.webdataset_laion import WDSLaion, process_laion_meta
-from datachain.sql.functions import path

IMAGE_TARS = os.getenv(
"IMAGE_TARS", "gs://datachain-demo/datacomp-small/shards/000000[0-5]*.tar"
Expand Down
16 changes: 6 additions & 10 deletions examples/multimodal/wds_filtered.py
@@ -1,9 +1,7 @@
import datachain.error
-from datachain import C, DataChain
+from datachain import C, DataChain, func
from datachain.lib.webdataset import process_webdataset
from datachain.lib.webdataset_laion import WDSLaion
-from datachain.sql import literal
-from datachain.sql.functions import array, greatest, least, string

name = "wds"
try:
@@ -20,14 +18,12 @@
wds.print_schema()

filtered = (
-    wds.filter(string.length(C("laion.txt")) > 5)
-    .filter(array.length(string.split(C("laion.txt"), literal(" "))) > 2)
+    wds.filter(func.string.length("laion.txt") > 5)
+    .filter(func.array.length(func.string.split("laion.txt", " ")) > 2)
+    .filter(func.least("laion.json.original_width", "laion.json.original_height") > 200)
-    .filter(
-        least(C("laion.json.original_width"), C("laion.json.original_height")) > 200
-    )
    .filter(
-        greatest(C("laion.json.original_width"), C("laion.json.original_height"))
-        / least(C("laion.json.original_width"), C("laion.json.original_height"))
+        func.greatest("laion.json.original_width", "laion.json.original_height")
+        / func.least("laion.json.original_width", "laion.json.original_height")
        < 3.0
    )
    .save()
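This example switches to the other import style the refactor provides: a single `func` namespace re-exported from the top-level package, instead of individual names from `datachain.lib.func`. Both spellings appear in this commit and look interchangeable; a brief sketch:

```python
# Namespace style, as in wds_filtered.py:
from datachain import func

shorter_side = func.least("laion.json.original_width", "laion.json.original_height")

# Direct style, as in common_sql_functions.py:
from datachain.lib.func import least

same_expression = least("laion.json.original_width", "laion.json.original_height")
```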
14 changes: 8 additions & 6 deletions src/datachain/catalog/catalog.py
@@ -53,7 +53,6 @@
    QueryScriptCancelError,
    QueryScriptRunError,
)
-from datachain.listing import Listing
from datachain.node import DirType, Node, NodeWithPath
from datachain.nodes_thread_pool import NodesThreadPool
from datachain.remote.studio import StudioClient
@@ -76,6 +75,7 @@
    from datachain.dataset import DatasetVersion
    from datachain.job import Job
    from datachain.lib.file import File
+    from datachain.listing import Listing

logger = logging.getLogger("datachain")

@@ -241,7 +241,7 @@ def do_task(self, urls):
class NodeGroup:
    """Class for a group of nodes from the same source"""

-    listing: Listing
+    listing: "Listing"
    sources: list[DataSource]

    # The source path within the bucket
@@ -596,8 +596,9 @@ def enlist_source(
        client_config=None,
        object_name="file",
        skip_indexing=False,
-    ) -> tuple[Listing, str]:
+    ) -> tuple["Listing", str]:
        from datachain.lib.dc import DataChain
+        from datachain.listing import Listing

        DataChain.from_storage(
            source, session=self.session, update=update, object_name=object_name
@@ -664,7 +665,8 @@ def enlist_sources_grouped(
        no_glob: bool = False,
        client_config=None,
    ) -> list[NodeGroup]:
-        from datachain.query import DatasetQuery
+        from datachain.listing import Listing
+        from datachain.query.dataset import DatasetQuery

        def _row_to_node(d: dict[str, Any]) -> Node:
            del d["file__source"]
@@ -872,7 +874,7 @@ def create_new_dataset_version(
    def update_dataset_version_with_warehouse_info(
        self, dataset: DatasetRecord, version: int, rows_dropped=False, **kwargs
    ) -> None:
-        from datachain.query import DatasetQuery
+        from datachain.query.dataset import DatasetQuery

        dataset_version = dataset.get_version(version)

@@ -1173,7 +1175,7 @@ def listings(self):
    def ls_dataset_rows(
        self, name: str, version: int, offset=None, limit=None
    ) -> list[dict]:
-        from datachain.query import DatasetQuery
+        from datachain.query.dataset import DatasetQuery

        dataset = self.get_dataset(name)
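Alongside the `datachain.query` → `datachain.query.dataset` rename, this file (like `client/fsspec.py` below) defers imports that previously ran at module load: type-only names move under `TYPE_CHECKING` with quoted annotations, and runtime users import inside the function body, which sidesteps circular imports between these modules. A generic sketch of the pattern; the function body is illustrative, not code from this commit:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers only; never imported at runtime.
    from datachain.listing import Listing


def enlist(source: str) -> "Listing":  # quoted: the name is absent at runtime
    # Deferred import, resolved only when the function is called.
    from datachain.listing import Listing

    return Listing.from_source(source)  # hypothetical constructor, illustration only
```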
2 changes: 1 addition & 1 deletion src/datachain/cli.py
@@ -872,7 +872,7 @@ def show(
    schema: bool = False,
) -> None:
    from datachain.lib.dc import DataChain
-    from datachain.query import DatasetQuery
+    from datachain.query.dataset import DatasetQuery
    from datachain.utils import show_records

    dataset = catalog.get_dataset(name)
19 changes: 10 additions & 9 deletions src/datachain/client/fsspec.py
@@ -28,14 +28,15 @@
from datachain.cache import DataChainCache
from datachain.client.fileslice import FileWrapper
from datachain.error import ClientError as DataChainClientError
-from datachain.lib.file import File
from datachain.nodes_fetcher import NodesFetcher
from datachain.nodes_thread_pool import NodeChunk
from datachain.storage import StorageURI

if TYPE_CHECKING:
    from fsspec.spec import AbstractFileSystem
+
+    from datachain.lib.file import File


logger = logging.getLogger("datachain")

@@ -44,7 +45,7 @@

DATA_SOURCE_URI_PATTERN = re.compile(r"^[\w]+:\/\/.*$")

-ResultQueue = asyncio.Queue[Optional[Sequence[File]]]
+ResultQueue = asyncio.Queue[Optional[Sequence["File"]]]


def _is_win_local_path(uri: str) -> bool:
@@ -207,7 +208,7 @@ async def get_file(self, lpath, rpath, callback):

    async def scandir(
        self, start_prefix: str, method: str = "default"
-    ) -> AsyncIterator[Sequence[File]]:
+    ) -> AsyncIterator[Sequence["File"]]:
        try:
            impl = getattr(self, f"_fetch_{method}")
        except AttributeError:
@@ -312,7 +313,7 @@ def get_full_path(self, rel_path: str) -> str:
        return f"{self.PREFIX}{self.name}/{rel_path}"

    @abstractmethod
-    def info_to_file(self, v: dict[str, Any], parent: str) -> File: ...
+    def info_to_file(self, v: dict[str, Any], parent: str) -> "File": ...

    def fetch_nodes(
        self,
@@ -349,27 +350,27 @@ def do_instantiate_object(self, file: "File", dst: str) -> None:
        copy2(src, dst)

    def open_object(
-        self, file: File, use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
+        self, file: "File", use_cache: bool = True, cb: Callback = DEFAULT_CALLBACK
    ) -> BinaryIO:
        """Open a file, including files in tar archives."""
        if use_cache and (cache_path := self.cache.get_path(file)):
            return open(cache_path, mode="rb")  # noqa: SIM115
        assert not file.location
        return FileWrapper(self.fs.open(self.get_full_path(file.path)), cb)  # type: ignore[return-value]

-    def download(self, file: File, *, callback: Callback = DEFAULT_CALLBACK) -> None:
+    def download(self, file: "File", *, callback: Callback = DEFAULT_CALLBACK) -> None:
        sync(get_loop(), functools.partial(self._download, file, callback=callback))

-    async def _download(self, file: File, *, callback: "Callback" = None) -> None:
+    async def _download(self, file: "File", *, callback: "Callback" = None) -> None:
        if self.cache.contains(file):
            # Already in cache, so there's nothing to do.
            return
        await self._put_in_cache(file, callback=callback)

-    def put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
+    def put_in_cache(self, file: "File", *, callback: "Callback" = None) -> None:
        sync(get_loop(), functools.partial(self._put_in_cache, file, callback=callback))

-    async def _put_in_cache(self, file: File, *, callback: "Callback" = None) -> None:
+    async def _put_in_cache(self, file: "File", *, callback: "Callback" = None) -> None:
        assert not file.location
        if file.etag:
            etag = await self.get_current_etag(file)
2 changes: 1 addition & 1 deletion src/datachain/data_storage/schema.py
@@ -12,7 +12,7 @@
from sqlalchemy.sql import func as f
from sqlalchemy.sql.expression import false, null, true

-from datachain.sql.functions import path
+from datachain.lib.func.inner import path
from datachain.sql.types import Int, SQLType, UInt64

if TYPE_CHECKING:
2 changes: 1 addition & 1 deletion src/datachain/data_storage/warehouse.py
@@ -20,8 +20,8 @@
from datachain.data_storage.schema import convert_rows_custom_column_types
from datachain.data_storage.serializer import Serializable
from datachain.dataset import DatasetRecord
+from datachain.lib.func.inner import path as pathfunc
from datachain.node import DirType, DirTypeGroup, Node, NodeWithPath, get_path
-from datachain.sql.functions import path as pathfunc
from datachain.sql.types import Int, SQLType
from datachain.storage import StorageURI
from datachain.utils import sql_escape_like
(Diff truncated: the remaining changed files are not shown.)