datachain: support mutating existing column (#537)
skshetry authored Oct 28, 2024
1 parent cfe3d9c commit 1949d56
Showing 4 changed files with 16 additions and 23 deletions.
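For context, a minimal sketch of the behavior this commit enables, pieced together from the updated test further down; the top-level `datachain` imports and the implicit default session are assumptions of the sketch, not part of the diff:

```python
from datachain import Column, DataChain

ds = DataChain.from_values(ids=[1, 2, 3])
# Reusing an existing column name previously raised DataChainColumnError;
# after this change the expression replaces the column in place.
ds = ds.mutate(ids=Column("ids") + 1)
assert list(ds.collect()) == [(2,), (3,), (4,)]
```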
8 changes: 0 additions & 8 deletions src/datachain/lib/dc.py
@@ -1111,14 +1111,6 @@ def mutate(self, **kwargs) -> "Self":
             )
             ```
         """
-        existing_columns = set(self.signals_schema.values.keys())
-        for col_name in kwargs:
-            if col_name in existing_columns:
-                raise DataChainColumnError(
-                    col_name,
-                    "Cannot modify existing column with mutate(). "
-                    "Use a different name for the new column.",
-                )
         for col_name, expr in kwargs.items():
             if not isinstance(expr, (Column, Func)) and isinstance(expr.type, NullType):
                 raise DataChainColumnError(
9 changes: 4 additions & 5 deletions src/datachain/lib/signal_schema.py
@@ -20,6 +20,7 @@
 )
 
 from pydantic import BaseModel, create_model
+from sqlalchemy import ColumnElement
 from typing_extensions import Literal as LiteralEx
 
 from datachain.lib.convert.python_to_sql import python_to_sql
@@ -491,16 +492,14 @@ def mutate(self, args_map: dict) -> "SignalSchema":
                 # renaming existing signal
                 del new_values[value.name]
                 new_values[name] = self.values[value.name]
-            elif name in self.values:
-                # changing the type of existing signal, e.g File -> ImageFile
-                del new_values[name]
-                new_values[name] = args_map[name]
             elif isinstance(value, Func):
                 # adding new signal with function
                 new_values[name] = value.get_result_type(self)
-            else:
+            elif isinstance(value, ColumnElement):
                 # adding new signal
                 new_values[name] = sql_to_python(value)
+            else:
+                new_values[name] = value
 
         return SignalSchema(new_values)
 
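The schema-side change hinges on that `isinstance(value, ColumnElement)` check: a SQL expression still gets its Python type via `sql_to_python`, while anything else (for example a model type used to change a signal's type) is now stored as-is by the final `else`. A small illustration using plain SQLAlchemy, not datachain's own `Column` wrapper (which the real code checks first):

```python
# Illustrative only: which mutate() values the new isinstance checks catch.
from sqlalchemy import Column, ColumnElement, Integer

ids = Column("ids", Integer)

print(isinstance(ids + 1, ColumnElement))  # True  -> typed via sql_to_python()
print(isinstance(int, ColumnElement))      # False -> falls through to the new `else`
```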
10 changes: 9 additions & 1 deletion src/datachain/query/dataset.py
Expand Up @@ -10,6 +10,7 @@
from collections.abc import Generator, Iterable, Iterator, Sequence
from copy import copy
from functools import wraps
from secrets import token_hex
from typing import (
TYPE_CHECKING,
Any,
@@ -720,10 +721,17 @@ class SQLMutate(SQLClause):
 
     def apply_sql_clause(self, query: Select) -> Select:
         original_subquery = query.subquery()
+        to_mutate = {c.name for c in self.args}
+
+        prefix = f"mutate{token_hex(8)}_"
+        cols = [
+            c.label(prefix + c.name) if c.name in to_mutate else c
+            for c in original_subquery.c
+        ]
         # this is needed for new column to be used in clauses
         # like ORDER BY, otherwise new column is not recognized
         subquery = (
-            sqlalchemy.select(*original_subquery.c, *self.args)
+            sqlalchemy.select(*cols, *self.args)
             .select_from(original_subquery)
             .subquery()
         )
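To see why `SQLMutate` now aliases the shadowed columns, here is a standalone SQLAlchemy sketch of the same relabel-and-reselect pattern; the throwaway table and the final filtering step are illustrative assumptions, since the hunk above is truncated before whatever datachain does downstream:

```python
import sqlalchemy as sa
from secrets import token_hex

t = sa.table("t", sa.column("ids"), sa.column("name"))
original_subquery = sa.select(t).subquery()

# The replacement expression keeps the original column name.
new_ids = (original_subquery.c.ids + 1).label("ids")

# Hide the shadowed original behind a unique alias so the intermediate subquery
# has no duplicate "ids" and the new column stays usable in ORDER BY and friends.
prefix = f"mutate{token_hex(8)}_"
cols = [
    c.label(prefix + c.name) if c.name == "ids" else c
    for c in original_subquery.c
]
subquery = sa.select(*cols, new_ids).select_from(original_subquery).subquery()

# Presumably only the non-prefixed columns get selected again afterwards.
final = sa.select(*(c for c in subquery.c if not c.name.startswith(prefix)))
print(final)
```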
12 changes: 3 additions & 9 deletions tests/func/test_datachain.py
@@ -24,7 +24,7 @@
 from datachain.lib.listing import LISTING_TTL, is_listing_dataset, parse_listing_uri
 from datachain.lib.tar import process_tar
 from datachain.lib.udf import Mapper
-from datachain.lib.utils import DataChainColumnError, DataChainError
+from datachain.lib.utils import DataChainError
 from datachain.query.dataset import QueryStep
 from datachain.sql.functions import path as pathfunc
 from datachain.sql.functions.array import cosine_distance, euclidean_distance
@@ -491,15 +491,9 @@ def test_from_storage_check_rows(tmp_dir, test_session):
 
 def test_mutate_existing_column(test_session):
     ds = DataChain.from_values(ids=[1, 2, 3], session=test_session)
+    ds = ds.mutate(ids=Column("ids") + 1)
 
-    with pytest.raises(DataChainColumnError) as excinfo:
-        ds.mutate(ids=Column("ids") + 1)
-
-    assert (
-        str(excinfo.value)
-        == "Error for column ids: Cannot modify existing column with mutate()."
-        " Use a different name for the new column."
-    )
+    assert list(ds.collect()) == [(2,), (3,), (4,)]
 
 
 @pytest.mark.parametrize(
