diff --git a/src/datachain/lib/arrow.py b/src/datachain/lib/arrow.py
index 950ce34e..4048848d 100644
--- a/src/datachain/lib/arrow.py
+++ b/src/datachain/lib/arrow.py
@@ -1,4 +1,3 @@
-import re
 from collections.abc import Sequence
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Any, Optional
@@ -13,6 +12,7 @@
 from datachain.lib.model_store import ModelStore
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import Generator
+from datachain.lib.utils import normalize_col_names
 
 if TYPE_CHECKING:
     from datasets.features.features import Features
@@ -128,7 +128,7 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = Non
     signal_schema = _get_datachain_schema(schema)
     if signal_schema:
         return signal_schema.values
-    columns = _convert_col_names(col_names)  # type: ignore[arg-type]
+    columns = list(normalize_col_names(col_names).keys())  # type: ignore[arg-type]
     hf_schema = _get_hf_schema(schema)
     if hf_schema:
         return {
@@ -143,19 +143,6 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = Non
     return output
 
 
-def _convert_col_names(col_names: Sequence[str]) -> list[str]:
-    default_column = 0
-    converted_col_names = []
-    for column in col_names:
-        column = column.lower()
-        column = re.sub("[^0-9a-z_]+", "", column)
-        if not column:
-            column = f"c{default_column}"
-            default_column += 1
-        converted_col_names.append(column)
-    return converted_col_names
-
-
 def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type:  # noqa: PLR0911
     """Convert pyarrow types to basic types."""
     from datetime import datetime
diff --git a/src/datachain/lib/data_model.py b/src/datachain/lib/data_model.py
index 67f95f29..9cb84a91 100644
--- a/src/datachain/lib/data_model.py
+++ b/src/datachain/lib/data_model.py
@@ -2,9 +2,10 @@
 from datetime import datetime
 from typing import ClassVar, Union, get_args, get_origin
 
-from pydantic import BaseModel, create_model
+from pydantic import BaseModel, Field, create_model
 
 from datachain.lib.model_store import ModelStore
+from datachain.lib.utils import normalize_col_names
 
 StandardType = Union[
     type[int],
@@ -60,7 +61,14 @@ def is_chain_type(t: type) -> bool:
 
 
 def dict_to_data_model(name: str, data_dict: dict[str, DataType]) -> type[BaseModel]:
-    fields = {name: (anno, ...) for name, anno in data_dict.items()}
+    # Get a map of normalized_name -> original_name
+    columns = normalize_col_names(list(data_dict.keys()))
+    # Reverse it for convenience: original_name -> normalized_name
+    columns = {v: k for k, v in columns.items()}
+
+    fields = {
+        columns[name]: (anno, Field(alias=name)) for name, anno in data_dict.items()
+    }
     return create_model(
         name,
         __base__=(DataModel,),  # type: ignore[call-overload]
diff --git a/src/datachain/lib/utils.py b/src/datachain/lib/utils.py
index cd11da9c..b61bc6fa 100644
--- a/src/datachain/lib/utils.py
+++ b/src/datachain/lib/utils.py
@@ -1,4 +1,6 @@
+import re
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 
 
 class AbstractUDF(ABC):
@@ -28,3 +30,36 @@ def __init__(self, message):
 class DataChainColumnError(DataChainParamsError):
     def __init__(self, col_name, msg):
         super().__init__(f"Error for column {col_name}: {msg}")
+
+
+def normalize_col_names(col_names: Sequence[str]) -> dict[str, str]:
+    gen_col_counter = 0
+    new_col_names = {}
+    org_col_names = set(col_names)
+
+    for org_column in col_names:
+        # Lowercase, collapse every run of non-alphanumeric characters into a
+        # single underscore, then drop leading/trailing underscores.
+        new_column = org_column.lower()
+        new_column = re.sub("[^0-9a-z]+", "_", new_column)
+        new_column = new_column.strip("_")
+
+        generated_column = new_column
+
+        # Add a generated "c<N>" prefix (or a bare "c<N>" if nothing is left)
+        # until the name is a valid identifier and clashes neither with an
+        # already assigned name nor with another original column name.
+        while (
+            not generated_column.isidentifier()
+            or generated_column in new_col_names
+            or (generated_column != org_column and generated_column in org_col_names)
+        ):
+            if new_column:
+                generated_column = f"c{gen_col_counter}_{new_column}"
+            else:
+                generated_column = f"c{gen_col_counter}"
+            gen_col_counter += 1
+
+        new_col_names[generated_column] = org_column
+
+    return new_col_names
diff --git a/tests/unit/lib/test_arrow.py b/tests/unit/lib/test_arrow.py
index 4d1414b9..f3a2ef9a 100644
--- a/tests/unit/lib/test_arrow.py
+++ b/tests/unit/lib/test_arrow.py
@@ -168,13 +168,21 @@ def test_parquet_convert_column_names():
         [
             ("UpperCaseCol", pa.int32()),
             ("dot.notation.col", pa.int32()),
             ("with-dashes", pa.int32()),
             ("with spaces", pa.int32()),
+            ("with-multiple--dashes", pa.int32()),
+            ("with__underscores", pa.int32()),
+            ("__leading__underscores", pa.int32()),
+            ("trailing__underscores__", pa.int32()),
         ]
     )
     assert list(schema_to_output(schema)) == [
         "uppercasecol",
-        "dotnotationcol",
-        "withdashes",
-        "withspaces",
+        "dot_notation_col",
+        "with_dashes",
+        "with_spaces",
+        "with_multiple_dashes",
+        "with_underscores",
+        "leading_underscores",
+        "trailing_underscores",
     ]
diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py
index c07f69ce..2841ec9b 100644
--- a/tests/unit/lib/test_datachain.py
+++ b/tests/unit/lib/test_datachain.py
@@ -36,6 +36,18 @@
     "city": ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix"],
 }
 
+DF_DATA_NESTED_NOT_NORMALIZED = {
+    "nAmE": [
+        {"first-SELECT": "Alice", "l--as@t": "Smith"},
+        {"l--as@t": "Jones", "first-SELECT": "Bob"},
+        {"first-SELECT": "Charlie", "l--as@t": "Brown"},
+        {"first-SELECT": "David", "l--as@t": "White"},
+        {"first-SELECT": "Eva", "l--as@t": "Black"},
+    ],
+    "AgE": [25, 30, 35, 40, 45],
+    "citY": ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix"],
+}
+
 DF_OTHER_DATA = {
     "last_name": ["Smith", "Jones"],
     "country": ["USA", "Russia"],
@@ -272,7 +284,9 @@ def test_listings(test_session, tmp_dir):
     assert listing.expires
     assert listing.version == 1
     assert listing.num_objects == 1
-    assert listing.size == 2912
+    # The exact number is unreliable here since it depends on the PyArrow version
+    assert listing.size > 1000
+    assert listing.size < 5000
     assert listing.status == 4
 
 
@@ -988,6 +1002,25 @@ def test_parse_tabular_format(tmp_dir, test_session):
     assert df1.equals(df)
 
 
+def test_parse_nested_json(tmp_dir, test_session):
+    df = pd.DataFrame(DF_DATA_NESTED_NOT_NORMALIZED)
+    path = tmp_dir / "test.jsonl"
+    path.write_text(df.to_json(orient="records", lines=True))
+    dc = DataChain.from_storage(path.as_uri(), session=test_session).parse_tabular(
+        format="json"
+    )
+    # Field names are normalized, values are preserved
+    # E.g. nAmE -> name, l--as@t -> l_as_t, etc.
+    df1 = dc.select("name", "age", "city").to_pandas()
+
+    assert df1["name"]["first_select"].to_list() == [
+        d["first-SELECT"] for d in df["nAmE"].to_list()
+    ]
+    assert df1["name"]["l_as_t"].to_list() == [
+        d["l--as@t"] for d in df["nAmE"].to_list()
+    ]
+
+
 def test_parse_tabular_partitions(tmp_dir, test_session):
     df = pd.DataFrame(DF_DATA)
     path = tmp_dir / "test.parquet"
diff --git a/tests/unit/lib/test_utils.py b/tests/unit/lib/test_utils.py
index 944ca720..83b423e4 100644
--- a/tests/unit/lib/test_utils.py
+++ b/tests/unit/lib/test_utils.py
@@ -5,6 +5,7 @@
 from pydantic import BaseModel
 
 from datachain.lib.convert.python_to_sql import python_to_sql
+from datachain.lib.utils import normalize_col_names
 from datachain.sql.types import JSON, Array, String
 
 
@@ -56,3 +57,72 @@ def test_convert_type_to_datachain_array(typ, expected):
 def test_convert_type_to_datachain_error(typ):
     with pytest.raises(TypeError):
         python_to_sql(typ)
+
+
+def test_normalize_column_names():
+    res = normalize_col_names(
+        [
+            "UpperCase",
+            "_underscore_start",
+            "double__underscore",
+            "1start_with_number",
+            "не_ascii_start",
+            " space_start",
+            "space_end ",
+            "dash-end-",
+            "-dash-start",
+            "--multiple--dash--",
+            "-_ mix_ -dash_ -",
+            "__2digit_after_underscore",
+            "",
+            "_-_- _---_ _",
+            "_-_- _---_ _1",
+        ]
+    )
+    assert list(res.keys()) == [
+        "uppercase",
+        "underscore_start",
+        "double_underscore",
+        "c0_1start_with_number",
+        "ascii_start",
+        "space_start",
+        "space_end",
+        "dash_end",
+        "dash_start",
+        "multiple_dash",
+        "mix_dash",
+        "c1_2digit_after_underscore",
+        "c2",
+        "c3",
+        "c4_1",
+    ]
+
+
+def test_normalize_column_names_case_repeat():
+    res = normalize_col_names(["UpperCase", "UpPerCase"])
+
+    assert list(res.keys()) == ["uppercase", "c0_uppercase"]
+
+
+def test_normalize_column_names_exists_after_normalize():
+    res = normalize_col_names(["1digit", "c0_1digit"])
+
+    assert list(res.keys()) == ["c1_1digit", "c0_1digit"]
+
+
+def test_normalize_column_names_normalized_repeat():
+    res = normalize_col_names(["column", "_column"])
+
+    assert list(res.keys()) == ["column", "c0_column"]
+
+
+def test_normalize_column_names_normalized_case_repeat():
+    res = normalize_col_names(["CoLuMn", "_column"])
+
+    assert res == {"column": "CoLuMn", "c0_column": "_column"}
+
+
+def test_normalize_column_names_repeat_generated_after_normalize():
+    res = normalize_col_names(["c0_CoLuMn", "_column", "column"])
+
+    assert res == {"c0_column": "c0_CoLuMn", "c1_column": "_column", "column": "column"}
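
Not part of the patch: a minimal usage sketch of the two new helpers, assuming the diff above is applied and pydantic v2's default populate-by-alias behavior. The `Person` model and the sample column names are hypothetical; the expected values mirror the unit tests.

    from datachain.lib.data_model import dict_to_data_model
    from datachain.lib.utils import normalize_col_names

    # Keys are the normalized names, values the original ones.
    assert normalize_col_names(["UpperCase", "l--as@t", "1digit"]) == {
        "uppercase": "UpperCase",
        "l_as_t": "l--as@t",
        "c0_1digit": "1digit",  # not an identifier, so it gets a generated prefix
    }

    # Generated models expose normalized field names, while the pydantic
    # aliases keep accepting the original names on input.
    Person = dict_to_data_model("Person", {"first-SELECT": str, "AgE": int})
    row = Person(**{"first-SELECT": "Alice", "AgE": 25})
    assert row.first_select == "Alice"
    assert row.age == 25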