Commit
Fix CLI to work with DataChain new listing (#517)
Fix the CLI to work with `DataChain`'s new listing codebase
ilongin authored Oct 28, 2024
1 parent 1949d56 commit 0eabe20
Showing 23 changed files with 461 additions and 653 deletions.
237 changes: 49 additions & 188 deletions src/datachain/catalog/catalog.py

Large diffs are not rendered by default.

12 changes: 0 additions & 12 deletions src/datachain/cli.py
@@ -249,12 +249,6 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
         action="store_true",
         help="AWS anon (aka awscli's --no-sign-request)",
     )
-    parent_parser.add_argument(
-        "--ttl",
-        type=human_time_type,
-        default=TTL_HUMAN,
-        help="Time-to-live of data source cache. Negative equals forever.",
-    )
     parent_parser.add_argument(
         "-u", "--update", action="count", default=0, help="Update cache"
     )
@@ -1011,7 +1005,6 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
             edatachain_only=False,
             no_edatachain_file=True,
             no_glob=args.no_glob,
-            ttl=args.ttl,
         )
     elif args.command == "clone":
         catalog.clone(
@@ -1021,7 +1014,6 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
             update=bool(args.update),
             recursive=bool(args.recursive),
             no_glob=args.no_glob,
-            ttl=args.ttl,
             no_cp=args.no_cp,
             edatachain=args.edatachain,
             edatachain_file=args.edatachain_file,
@@ -1047,7 +1039,6 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
             args.sources,
             long=bool(args.long),
             remote=args.remote,
-            ttl=args.ttl,
             update=bool(args.update),
             client_config=client_config,
         )
@@ -1081,15 +1072,13 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
             show_bytes=args.bytes,
             depth=args.depth,
             si=args.si,
-            ttl=args.ttl,
             update=bool(args.update),
             client_config=client_config,
         )
     elif args.command == "find":
         results_found = False
         for result in catalog.find(
             args.sources,
-            ttl=args.ttl,
             update=bool(args.update),
             names=args.name,
             inames=args.iname,
@@ -1107,7 +1096,6 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR09
         index(
             catalog,
             args.sources,
-            ttl=args.ttl,
             update=bool(args.update),
         )
     elif args.command == "completion":
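All six hunks above delete the same dead `ttl=args.ttl` pass-through; the flag itself disappears from the shared parent parser in the first hunk. For readers unfamiliar with the argparse pattern involved, a minimal standalone sketch (names illustrative, not from this repo) of why one removed add_argument call strips the flag from every subcommand at once:

from argparse import ArgumentParser

# Options on a parent parser are inherited by every subparser that
# lists it in `parents`, so one add_argument removal propagates to
# all subcommands in a single edit.
parent = ArgumentParser(add_help=False)
parent.add_argument("-u", "--update", action="count", default=0, help="Update cache")

root = ArgumentParser(prog="demo")
sub = root.add_subparsers(dest="command")
sub.add_parser("ls", parents=[parent])    # inherits --update
sub.add_parser("find", parents=[parent])  # inherits --update

args = root.parse_args(["ls", "--update"])
print(args.command, args.update)  # -> ls 1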
99 changes: 66 additions & 33 deletions src/datachain/data_storage/schema.py
@@ -26,6 +26,13 @@
 from sqlalchemy.sql.elements import ColumnElement


+DEFAULT_DELIMITER = "__"
+
+
+def col_name(name: str, object_name: str = "file") -> str:
+    return f"{object_name}{DEFAULT_DELIMITER}{name}"
+
+
 def dedup_columns(columns: Iterable[sa.Column]) -> list[sa.Column]:
     """
     Removes duplicate columns from a list of columns.
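A quick sanity check of what the new helper produces (plain Python, no dependencies; the outputs follow directly from the definition above):

DEFAULT_DELIMITER = "__"

def col_name(name: str, object_name: str = "file") -> str:
    return f"{object_name}{DEFAULT_DELIMITER}{name}"

print(col_name("path"))       # -> file__path
print(col_name("id", "sys"))  # -> sys__id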
@@ -76,64 +83,81 @@ def convert_rows_custom_column_types(


 class DirExpansion:
-    @staticmethod
-    def base_select(q):
+    def __init__(self, object_name: str):
+        self.object_name = object_name
+
+    def col_name(self, name: str, object_name: Optional[str] = None) -> str:
+        object_name = object_name or self.object_name
+        return col_name(name, object_name)
+
+    def c(self, query, name: str, object_name: Optional[str] = None) -> str:
+        return getattr(query.c, self.col_name(name, object_name=object_name))
+
+    def base_select(self, q):
         return sa.select(
-            q.c.sys__id,
-            false().label("is_dir"),
-            q.c.source,
-            q.c.path,
-            q.c.version,
-            q.c.location,
+            self.c(q, "id", object_name="sys"),
+            false().label(self.col_name("is_dir")),
+            self.c(q, "source"),
+            self.c(q, "path"),
+            self.c(q, "version"),
+            self.c(q, "location"),
         )

-    @staticmethod
-    def apply_group_by(q):
+    def apply_group_by(self, q):
         return (
             sa.select(
                 f.min(q.c.sys__id).label("sys__id"),
-                q.c.is_dir,
-                q.c.source,
-                q.c.path,
-                q.c.version,
-                f.max(q.c.location).label("location"),
+                self.c(q, "is_dir"),
+                self.c(q, "source"),
+                self.c(q, "path"),
+                self.c(q, "version"),
+                f.max(self.c(q, "location")).label(self.col_name("location")),
             )
             .select_from(q)
-            .group_by(q.c.source, q.c.path, q.c.is_dir, q.c.version)
-            .order_by(q.c.source, q.c.path, q.c.is_dir, q.c.version)
+            .group_by(
+                self.c(q, "source"),
+                self.c(q, "path"),
+                self.c(q, "is_dir"),
+                self.c(q, "version"),
+            )
+            .order_by(
+                self.c(q, "source"),
+                self.c(q, "path"),
+                self.c(q, "is_dir"),
+                self.c(q, "version"),
+            )
         )

-    @classmethod
-    def query(cls, q):
-        q = cls.base_select(q).cte(recursive=True)
-        parent = path.parent(q.c.path)
+    def query(self, q):
+        q = self.base_select(q).cte(recursive=True)
+        parent = path.parent(self.c(q, "path"))
         q = q.union_all(
             sa.select(
                 sa.literal(-1).label("sys__id"),
-                true().label("is_dir"),
-                q.c.source,
-                parent.label("path"),
-                sa.literal("").label("version"),
-                null().label("location"),
+                true().label(self.col_name("is_dir")),
+                self.c(q, "source"),
+                parent.label(self.col_name("path")),
+                sa.literal("").label(self.col_name("version")),
+                null().label(self.col_name("location")),
             ).where(parent != "")
         )
-        return cls.apply_group_by(q)
+        return self.apply_group_by(q)


 class DataTable:
-    dataset_dir_expansion = staticmethod(DirExpansion.query)
-
     def __init__(
         self,
         name: str,
         engine: "Engine",
         metadata: Optional["sa.MetaData"] = None,
         column_types: Optional[dict[str, SQLType]] = None,
+        object_name: str = "file",
     ):
         self.name: str = name
         self.engine = engine
         self.metadata: sa.MetaData = metadata if metadata is not None else sa.MetaData()
         self.column_types: dict[str, SQLType] = column_types or {}
+        self.object_name = object_name

     @staticmethod
     def copy_column(
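With the prefix carried on the instance, callers now construct a DirExpansion per object name rather than calling static methods. A hedged usage sketch against a toy table (this assumes `DirExpansion` is importable from `datachain.data_storage.schema` as patched here; the column layout is illustrative):

import sqlalchemy as sa

from datachain.data_storage.schema import DirExpansion  # as introduced in this commit

# Toy table with the prefixed layout the new code expects.
metadata = sa.MetaData()
t = sa.Table(
    "rows",
    metadata,
    sa.Column("sys__id", sa.Integer),
    sa.Column("file__source", sa.String),
    sa.Column("file__path", sa.String),
    sa.Column("file__version", sa.String),
    sa.Column("file__location", sa.String),
)

q = sa.select(t).subquery()
exp = DirExpansion("file")  # previously DirExpansion.query(q), with no prefix awareness
base = exp.base_select(q)   # resolves sys__id, file__source, ... via the prefix helpers
print([c.name for c in base.selected_columns])
# expected: ['sys__id', 'file__is_dir', 'file__source', 'file__path', 'file__version', 'file__location']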
@@ -204,9 +228,18 @@ def get_table(self) -> "sa.Table":
     def columns(self) -> "ReadOnlyColumnCollection[str, sa.Column[Any]]":
         return self.table.columns

-    @property
-    def c(self):
-        return self.columns
+    def col_name(self, name: str, object_name: Optional[str] = None) -> str:
+        object_name = object_name or self.object_name
+        return col_name(name, object_name)
+
+    def without_object(
+        self, column_name: str, object_name: Optional[str] = None
+    ) -> str:
+        object_name = object_name or self.object_name
+        return column_name.removeprefix(f"{object_name}{DEFAULT_DELIMITER}")
+
+    def c(self, name: str, object_name: Optional[str] = None):
+        return getattr(self.columns, self.col_name(name, object_name=object_name))

     @property
     def table(self) -> "sa.Table":
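The visible API shift in DataTable: `c` goes from a property aliasing `columns` to a prefix-resolving method. A hypothetical before/after for a table with `file__`-prefixed columns (call sites invented for illustration, not taken from this diff):

# Before this commit, `c` was a plain alias for `columns`:
#     dt.c.file__path
# After, the object-name prefix is resolved for the caller:
#     dt.c("path")                     # the file__path column
#     dt.c("id", object_name="sys")    # the sys__id column
#     dt.without_object("file__path")  # -> "path"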
@@ -246,7 +279,7 @@ def sys_columns():
         ]

     def dir_expansion(self):
-        return self.dataset_dir_expansion(self)
+        return DirExpansion(self.object_name)


PARTITION_COLUMN_ID = "partition_id"
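Note the contract change in this last hunk: `dir_expansion()` used to return the already-expanded query (via the static `DirExpansion.query`), whereas it now returns a `DirExpansion` instance and leaves callers to invoke `.query(...)` themselves; presumably this is what the un-rendered `catalog.py` changes adapt to.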
8 changes: 4 additions & 4 deletions src/datachain/data_storage/sqlite.py
@@ -643,7 +643,7 @@ def get_dataset_sources(
         self, dataset: DatasetRecord, version: int
     ) -> list[StorageURI]:
         dr = self.dataset_rows(dataset, version)
-        query = dr.select(dr.c.file__source).distinct()
+        query = dr.select(dr.c("source", object_name="file")).distinct()
         cur = self.db.cursor()
         cur.row_factory = sqlite3.Row  # type: ignore[assignment]

@@ -671,13 +671,13 @@ def merge_dataset_rows(
             # destination table doesn't exist, create it
             self.create_dataset_rows_table(
                 self.dataset_table_name(dst.name, dst_version),
-                columns=src_dr.c,
+                columns=src_dr.columns,
             )
             dst_empty = True

         dst_dr = self.dataset_rows(dst, dst_version).table
-        merge_fields = [c.name for c in src_dr.c if c.name != "sys__id"]
-        select_src = select(*(getattr(src_dr.c, f) for f in merge_fields))
+        merge_fields = [c.name for c in src_dr.columns if c.name != "sys__id"]
+        select_src = select(*(getattr(src_dr.columns, f) for f in merge_fields))

         if dst_empty:
             # we don't need union, but just select from source to destination
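Since `c` is no longer a property, generic all-columns iteration moves to `.columns`, which on a SQLAlchemy table supports both iteration and attribute access. A self-contained sketch of the pattern `merge_dataset_rows` relies on (toy table, not the real dataset schema):

import sqlalchemy as sa

# Toy stand-in for a dataset rows table.
metadata = sa.MetaData()
rows = sa.Table(
    "rows",
    metadata,
    sa.Column("sys__id", sa.Integer, primary_key=True),
    sa.Column("file__source", sa.String),
    sa.Column("file__path", sa.String),
)

# Same pattern as merge_dataset_rows: copy every column except sys__id,
# letting the destination assign fresh autoincrement ids.
merge_fields = [c.name for c in rows.columns if c.name != "sys__id"]
select_src = sa.select(*(getattr(rows.columns, f) for f in merge_fields))
print(select_src)
# SELECT rows.file__source, rows.file__path FROM rows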
