From 07786ef4e0d53afe74e2a7f27cbd5c3bb331416f Mon Sep 17 00:00:00 2001 From: Devin Smith Date: Thu, 15 Aug 2024 16:04:39 -0700 Subject: [PATCH] feat!: Add TableDefinition wrapper for python (#5892) * `deephaven.table.TableDefinition`: new python wrapper for `io.deephaven.engine.table.TableDefinition` * `deephaven.column.ColumnDefinition`: new python wrapper for `io.deephaven.engine.table.ColumnDefinition` * `deephaven.table.TableDefinitionLike`: new type alias to allow for consistent public APIs and conversions into `deephaven.table.TableDefinition` * `deephaven.column.Column`: deprecated for removal * `deephaven.jcompat.j_table_definition`: deprecated for removal * `deephaven.stream.kafka.consumer.json_spec`: `cols_defs` specified as `List[Tuple[str, DType]]` deprecated for removal Fixes #4822 BREAKING CHANGE: `deephaven.column.InputColumn` no longer inherits from `deephaven.column.Column`; as such, it no longer exposes Column's attributes. This is unlikely to affect users as InputColumn is really a structure meant to support `new_table`. --- py/server/deephaven/_table_reader.py | 20 +- py/server/deephaven/column.py | 184 +++++++++--- py/server/deephaven/experimental/iceberg.py | 13 +- py/server/deephaven/jcompat.py | 46 +-- py/server/deephaven/learn/__init__.py | 4 +- py/server/deephaven/numpy.py | 18 +- py/server/deephaven/pandas.py | 14 +- py/server/deephaven/parquet.py | 18 +- py/server/deephaven/stream/kafka/consumer.py | 28 +- py/server/deephaven/stream/table_publisher.py | 12 +- py/server/deephaven/table.py | 256 ++++++++++++----- py/server/deephaven/table_factory.py | 13 +- py/server/tests/test_barrage.py | 4 +- py/server/tests/test_column.py | 40 ++- py/server/tests/test_csv.py | 9 +- py/server/tests/test_data_index.py | 2 +- py/server/tests/test_dbc.py | 10 +- py/server/tests/test_experiments.py | 8 +- py/server/tests/test_iceberg.py | 12 +- py/server/tests/test_numpy.py | 26 +- py/server/tests/test_parquet.py | 20 +- py/server/tests/test_partitioned_table.py | 9 +- py/server/tests/test_pt_proxy.py | 18 +- py/server/tests/test_table.py | 58 ++-- py/server/tests/test_table_definition.py | 271 ++++++++++++++++++ py/server/tests/test_table_factory.py | 36 ++- py/server/tests/test_table_iterator.py | 24 +- py/server/tests/test_table_listener.py | 5 +- py/server/tests/test_updateby.py | 12 +- py/server/tests/test_vectorization.py | 2 +- 30 files changed, 854 insertions(+), 338 deletions(-) create mode 100644 py/server/tests/test_table_definition.py diff --git a/py/server/deephaven/_table_reader.py b/py/server/deephaven/_table_reader.py index 4c9265745c5..49fa82c9772 100644 --- a/py/server/deephaven/_table_reader.py +++ b/py/server/deephaven/_table_reader.py @@ -9,7 +9,7 @@ import numpy as np from deephaven import update_graph -from deephaven.column import Column +from deephaven.column import ColumnDefinition from deephaven.jcompat import to_sequence from deephaven.numpy import _column_to_numpy_array from deephaven.table import Table @@ -18,7 +18,7 @@ T = TypeVar('T') -def _col_defs(table: Table, cols: Union[str, Sequence[str]]) -> Sequence[Column]: +def _col_defs(table: Table, cols: Union[str, Sequence[str]]) -> Sequence[ColumnDefinition]: if not cols: col_defs = table.columns else: @@ -31,7 +31,7 @@ def _col_defs(table: Table, cols: Union[str, Sequence[str]]) -> Sequence[Column] def _table_reader_all(table: Table, cols: Optional[Union[str, Sequence[str]]] = None, *, - emitter: Callable[[Sequence[Column], jpy.JType], T], row_set: jpy.JType, + emitter: 
Callable[[Sequence[ColumnDefinition], jpy.JType], T], row_set: jpy.JType, prev: bool = False) -> T: """ Reads all the rows in the given row set of a table. The emitter converts the Java data into a desired Python object. @@ -103,7 +103,7 @@ def _table_reader_all_dict(table: Table, cols: Optional[Union[str, Sequence[str] def _table_reader_chunk(table: Table, cols: Optional[Union[str, Sequence[str]]] = None, *, - emitter: Callable[[Sequence[Column], jpy.JType], Iterable[T]], row_set: jpy.JType, + emitter: Callable[[Sequence[ColumnDefinition], jpy.JType], Iterable[T]], row_set: jpy.JType, chunk_size: int = 2048, prev: bool = False) \ -> Generator[T, None, None]: """ Returns a generator that reads one chunk of rows at a time from the table. The emitter converts the Java chunk @@ -178,7 +178,7 @@ def _table_reader_chunk_dict(table: Table, cols: Optional[Union[str, Sequence[st Raises: ValueError """ - def _emitter(col_defs: Sequence[Column], j_array: jpy.JType) -> Generator[Dict[str, np.ndarray], None, None]: + def _emitter(col_defs: Sequence[ColumnDefinition], j_array: jpy.JType) -> Generator[Dict[str, np.ndarray], None, None]: yield {col_def.name: _column_to_numpy_array(col_def, j_array[i]) for i, col_def in enumerate(col_defs)} return _table_reader_chunk(table, cols, emitter=_emitter, row_set=row_set, chunk_size=chunk_size, prev=prev) @@ -210,9 +210,9 @@ def _table_reader_chunk_tuple(table: Table, cols: Optional[Union[str, Sequence[s Raises: ValueError """ - named_tuple_class = namedtuple(tuple_name, cols or [col.name for col in table.columns], rename=False) + named_tuple_class = namedtuple(tuple_name, cols or table.column_names, rename=False) - def _emitter(col_defs: Sequence[Column], j_array: jpy.JType) -> Generator[Tuple[np.ndarray], None, None]: + def _emitter(col_defs: Sequence[ColumnDefinition], j_array: jpy.JType) -> Generator[Tuple[np.ndarray], None, None]: yield named_tuple_class._make([_column_to_numpy_array(col_def, j_array[i]) for i, col_def in enumerate(col_defs)]) return _table_reader_chunk(table, cols, emitter=_emitter, row_set=table.j_table.getRowSet(), chunk_size=chunk_size, prev=False) @@ -242,7 +242,7 @@ def _table_reader_row_dict(table: Table, cols: Optional[Union[str, Sequence[str] Raises: ValueError """ - def _emitter(col_defs: Sequence[Column], j_array: jpy.JType) -> Iterable[Dict[str, Any]]: + def _emitter(col_defs: Sequence[ColumnDefinition], j_array: jpy.JType) -> Iterable[Dict[str, Any]]: make_dict = lambda values: {col_def.name: value for col_def, value in zip(col_defs, values)} mvs = [memoryview(j_array[i]) if col_def.data_type.is_primitive else j_array[i] for i, col_def in enumerate(col_defs)] return map(make_dict, zip(*mvs)) @@ -275,9 +275,9 @@ def _table_reader_row_tuple(table: Table, cols: Optional[Union[str, Sequence[str Raises: ValueError """ - named_tuple_class = namedtuple(tuple_name, cols or [col.name for col in table.columns], rename=False) + named_tuple_class = namedtuple(tuple_name, cols or table.column_names, rename=False) - def _emitter(col_defs: Sequence[Column], j_array: jpy.JType) -> Iterable[Tuple[Any, ...]]: + def _emitter(col_defs: Sequence[ColumnDefinition], j_array: jpy.JType) -> Iterable[Tuple[Any, ...]]: mvs = [memoryview(j_array[i]) if col_def.data_type.is_primitive else j_array[i] for i, col_def in enumerate(col_defs)] return map(named_tuple_class._make, zip(*mvs)) diff --git a/py/server/deephaven/column.py b/py/server/deephaven/column.py index bbb5f5008b9..88839233ff0 100644 --- a/py/server/deephaven/column.py +++ 
b/py/server/deephaven/column.py @@ -4,16 +4,17 @@ """ This module implements the Column class and functions that work with Columns. """ -from dataclasses import dataclass, field from enum import Enum -from typing import Sequence, Any +from functools import cached_property +from typing import Sequence, Any, Optional +from warnings import warn import jpy import deephaven.dtypes as dtypes from deephaven import DHError -from deephaven.dtypes import DType -from deephaven.dtypes import _instant_array +from deephaven.dtypes import DType, _instant_array, from_jtype +from deephaven._wrapper import JObjectWrapper _JColumnHeader = jpy.get_type("io.deephaven.qst.column.header.ColumnHeader") _JColumn = jpy.get_type("io.deephaven.qst.column.Column") @@ -32,46 +33,151 @@ def __repr__(self): return self.name -@dataclass -class Column: - """ A Column object represents a column definition in a Deephaven Table. """ - name: str - data_type: DType - component_type: DType = None - column_type: ColumnType = ColumnType.NORMAL +class ColumnDefinition(JObjectWrapper): + """A Deephaven column definition.""" - @property - def j_column_header(self): - return _JColumnHeader.of(self.name, self.data_type.qst_type) + j_object_type = _JColumnDefinition + + def __init__(self, j_column_definition: jpy.JType): + self.j_column_definition = j_column_definition @property - def j_column_definition(self): - if hasattr(self.data_type.j_type, 'jclass'): - j_data_type = self.data_type.j_type.jclass - else: - j_data_type = self.data_type.qst_type.clazz() - j_component_type = self.component_type.qst_type.clazz() if self.component_type else None - j_column_type = self.column_type.value - return _JColumnDefinition.fromGenericType(self.name, j_data_type, j_component_type, j_column_type) - - -@dataclass -class InputColumn(Column): - """ An InputColumn represents a user defined column with some input data. """ - input_data: Any = field(default=None) - - def __post_init__(self): + def j_object(self) -> jpy.JType: + return self.j_column_definition + + @cached_property + def name(self) -> str: + """The column name.""" + return self.j_column_definition.getName() + + @cached_property + def data_type(self) -> DType: + """The column data type.""" + return from_jtype(self.j_column_definition.getDataType()) + + @cached_property + def component_type(self) -> Optional[DType]: + """The column component type.""" + return from_jtype(self.j_column_definition.getComponentType()) + + @cached_property + def column_type(self) -> ColumnType: + """The column type.""" + return ColumnType(self.j_column_definition.getColumnType()) + + +class Column(ColumnDefinition): + """A Column object represents a column definition in a Deephaven Table. Deprecated for removal next release, prefer col_def.""" + + def __init__( + self, + name: str, + data_type: DType, + component_type: DType = None, + column_type: ColumnType = ColumnType.NORMAL, + ): + """Deprecated for removal next release, prefer col_def.""" + warn( + "Column is deprecated for removal next release, prefer col_def", + DeprecationWarning, + stacklevel=2, + ) + super().__init__( + col_def(name, data_type, component_type, column_type).j_column_definition + ) + + +class InputColumn: + """An InputColumn represents a user defined column with some input data.""" + + def __init__( + self, + name: str = None, + data_type: DType = None, + component_type: DType = None, + column_type: ColumnType = ColumnType.NORMAL, + input_data: Any = None, + ): + """Creates an InputColumn. 
+
+        Args:
+            name (str): the column name
+            data_type (DType): the column data type
+            component_type (Optional[DType]): the column component type, None by default
+            column_type (ColumnType): the column type, NORMAL by default
+            input_data (Any): the input data, None by default
+
+        Returns:
+            a new InputColumn
+
+        Raises:
+            DHError
+        """
         try:
-            if self.input_data is None:
-                self.j_column = _JColumn.empty(self.j_column_header)
-            else:
-                if self.data_type.is_primitive:
-                    self.j_column = _JColumn.ofUnsafe(self.name, dtypes.array(self.data_type, self.input_data,
-                                                                              remap=dtypes.null_remap(self.data_type)))
-                else:
-                    self.j_column = _JColumn.of(self.j_column_header, dtypes.array(self.data_type, self.input_data))
+            self._column_definition = col_def(
+                name, data_type, component_type, column_type
+            )
+            self.j_column = self._to_j_column(input_data)
         except Exception as e:
-            raise DHError(e, f"failed to create an InputColumn ({self.name}).") from e
+            raise DHError(e, f"failed to create an InputColumn ({name}).") from e
+
+    def _to_j_column(self, input_data: Any = None) -> jpy.JType:
+        if input_data is None:
+            return _JColumn.empty(
+                _JColumnHeader.of(
+                    self._column_definition.name,
+                    self._column_definition.data_type.qst_type,
+                )
+            )
+        if self._column_definition.data_type.is_primitive:
+            return _JColumn.ofUnsafe(
+                self._column_definition.name,
+                dtypes.array(
+                    self._column_definition.data_type,
+                    input_data,
+                    remap=dtypes.null_remap(self._column_definition.data_type),
+                ),
+            )
+        return _JColumn.of(
+            _JColumnHeader.of(
+                self._column_definition.name, self._column_definition.data_type.qst_type
+            ),
+            dtypes.array(self._column_definition.data_type, input_data),
+        )
+
+
+def col_def(
+    name: str,
+    data_type: DType,
+    component_type: Optional[DType] = None,
+    column_type: ColumnType = ColumnType.NORMAL,
+) -> ColumnDefinition:
+    """Creates a ColumnDefinition.
+ + Args: + name (str): the column name + data_type (DType): the column data type + component_type (Optional[DType]): the column component type, None by default + column_type (ColumnType): the column type, ColumnType.NORMAL by default + + Returns: + a new ColumnDefinition + + Raises: + DHError + """ + try: + return ColumnDefinition( + _JColumnDefinition.fromGenericType( + name, + data_type.j_type.jclass + if hasattr(data_type.j_type, "jclass") + else data_type.qst_type.clazz(), + component_type.qst_type.clazz() if component_type else None, + column_type.value, + ) + ) + except Exception as e: + raise DHError(e, f"failed to create a ColumnDefinition ({name}).") from e def bool_col(name: str, data: Sequence) -> InputColumn: diff --git a/py/server/deephaven/experimental/iceberg.py b/py/server/deephaven/experimental/iceberg.py index 7506bc95a25..0a99d3f1880 100644 --- a/py/server/deephaven/experimental/iceberg.py +++ b/py/server/deephaven/experimental/iceberg.py @@ -8,13 +8,8 @@ from deephaven import DHError from deephaven._wrapper import JObjectWrapper -from deephaven.column import Column -from deephaven.dtypes import DType from deephaven.experimental import s3 - -from deephaven.jcompat import j_table_definition - -from deephaven.table import Table +from deephaven.table import Table, TableDefinition, TableDefinitionLike _JIcebergInstructions = jpy.get_type("io.deephaven.iceberg.util.IcebergInstructions") _JIcebergCatalogAdapter = jpy.get_type("io.deephaven.iceberg.util.IcebergCatalogAdapter") @@ -39,14 +34,14 @@ class IcebergInstructions(JObjectWrapper): j_object_type = _JIcebergInstructions def __init__(self, - table_definition: Optional[Union[Dict[str, DType], List[Column]]] = None, + table_definition: Optional[TableDefinitionLike] = None, data_instructions: Optional[s3.S3Instructions] = None, column_renames: Optional[Dict[str, str]] = None): """ Initializes the instructions using the provided parameters. Args: - table_definition (Optional[Union[Dict[str, DType], List[Column], None]]): the table definition; if omitted, + table_definition (Optional[TableDefinitionLike]): the table definition; if omitted, the definition is inferred from the Iceberg schema. Setting a definition guarantees the returned table will have that definition. This is useful for specifying a subset of the Iceberg schema columns. data_instructions (Optional[s3.S3Instructions]): Special instructions for reading data files, useful when @@ -62,7 +57,7 @@ def __init__(self, builder = self.j_object_type.builder() if table_definition is not None: - builder.tableDefinition(j_table_definition(table_definition)) + builder.tableDefinition(TableDefinition(table_definition).j_table_definition) if data_instructions is not None: builder.dataInstructions(data_instructions.j_object) diff --git a/py/server/deephaven/jcompat.py b/py/server/deephaven/jcompat.py index d807cb472f3..98e6a04565e 100644 --- a/py/server/deephaven/jcompat.py +++ b/py/server/deephaven/jcompat.py @@ -5,7 +5,8 @@ """ This module provides Java compatibility support including convenience functions to create some widely used Java data structures from corresponding Python ones in order to be able to call Java methods. 
""" -from typing import Any, Callable, Dict, Iterable, List, Sequence, Set, TypeVar, Union, Optional +from typing import Any, Callable, Dict, Iterable, List, Sequence, Set, TypeVar, Union, Optional, Mapping +from warnings import warn import jpy import numpy as np @@ -14,7 +15,7 @@ from deephaven import dtypes, DHError from deephaven._wrapper import unwrap, wrap_j_object, JObjectWrapper from deephaven.dtypes import DType, _PRIMITIVE_DTYPE_NULL_MAP -from deephaven.column import Column +from deephaven.column import ColumnDefinition _NULL_BOOLEAN_AS_BYTE = jpy.get_type("io.deephaven.util.BooleanUtils").NULL_BOOLEAN_AS_BYTE _JPrimitiveArrayConversionUtility = jpy.get_type("io.deephaven.integrations.common.PrimitiveArrayConversionUtility") @@ -327,11 +328,20 @@ def _j_array_to_series(dtype: DType, j_array: jpy.JType, conv_null: bool) -> pd. return s -def j_table_definition(table_definition: Union[Dict[str, DType], List[Column], None]) -> Optional[jpy.JType]: - """Produce a Deephaven TableDefinition from user input. +# Note: unable to import TableDefinitionLike due to circular ref (table.py -> agg.py -> jcompat.py) +def j_table_definition( + table_definition: Union[ + "TableDefinition", + Mapping[str, dtypes.DType], + Iterable[ColumnDefinition], + jpy.JType, + None, + ], +) -> Optional[jpy.JType]: + """Deprecated for removal next release, prefer TableDefinition. Produce a Deephaven TableDefinition from user input. Args: - table_definition (Union[Dict[str, DType], List[Column], None]): the table definition as a dictionary of column + table_definition (Optional[TableDefinitionLike]): the table definition as a dictionary of column names and their corresponding data types or a list of Column objects Returns: @@ -340,22 +350,18 @@ def j_table_definition(table_definition: Union[Dict[str, DType], List[Column], N Raises: DHError """ - if table_definition is None: - return None - elif isinstance(table_definition, Dict): - return _JTableDefinition.of( - [ - Column(name=name, data_type=dtype).j_column_definition - for name, dtype in table_definition.items() - ] - ) - elif isinstance(table_definition, List): - return _JTableDefinition.of( - [col.j_column_definition for col in table_definition] - ) - else: - raise DHError(f"Unexpected table_definition type: {type(table_definition)}") + warn( + "j_table_definition is deprecated for removal next release, prefer TableDefinition", + DeprecationWarning, + stacklevel=2, + ) + from deephaven.table import TableDefinition + return ( + TableDefinition(table_definition).j_table_definition + if table_definition + else None + ) class AutoCloseable(JObjectWrapper): """A context manager wrapper to allow Java AutoCloseable to be used in with statements. 
diff --git a/py/server/deephaven/learn/__init__.py b/py/server/deephaven/learn/__init__.py index 54261ad1d21..45439cb6bf0 100644 --- a/py/server/deephaven/learn/__init__.py +++ b/py/server/deephaven/learn/__init__.py @@ -71,7 +71,7 @@ def _validate(inputs: Input, outputs: Output, table: Table): input_columns_list = [input_.input.getColNames()[i] for input_ in inputs for i in range(len(input_.input.getColNames()))] input_columns = set(input_columns_list) - table_columns = {col.name for col in table.columns} + table_columns = set(table.definition.keys()) if table_columns >= input_columns: if outputs is not None: output_columns_list = [output.output.getColName() for output in outputs] @@ -99,7 +99,7 @@ def _create_non_conflicting_col_name(table: Table, base_col_name: str) -> str: Returns: column name that is not present in the table. """ - table_col_names = set([col.name for col in table.columns]) + table_col_names = set(table.definition.keys()) if base_col_name not in table_col_names: return base_col_name else: diff --git a/py/server/deephaven/numpy.py b/py/server/deephaven/numpy.py index c87f24ea40c..33dcef3bc14 100644 --- a/py/server/deephaven/numpy.py +++ b/py/server/deephaven/numpy.py @@ -8,10 +8,10 @@ import jpy import numpy as np -from deephaven.dtypes import DType, BusinessCalendar +from deephaven.dtypes import BusinessCalendar from deephaven import DHError, dtypes, new_table -from deephaven.column import Column, InputColumn +from deephaven.column import InputColumn, ColumnDefinition from deephaven.dtypes import DType from deephaven.jcompat import _j_array_to_numpy_array from deephaven.table import Table @@ -27,11 +27,11 @@ def _to_column_name(name: str) -> str: return re.sub(r"\s+", "_", tmp_name) -def _column_to_numpy_array(col_def: Column, j_array: jpy.JType) -> np.ndarray: +def _column_to_numpy_array(col_def: ColumnDefinition, j_array: jpy.JType) -> np.ndarray: """ Produces a numpy array from the given Java array and the Table column definition. 
Args: - col_def (Column): the column definition + col_def (ColumnDefinition): the column definition j_array (jpy.JType): the Java array Returns: @@ -48,7 +48,7 @@ def _column_to_numpy_array(col_def: Column, j_array: jpy.JType) -> np.ndarray: raise DHError(e, f"failed to create a numpy array for the column {col_def.name}") from e -def _columns_to_2d_numpy_array(col_def: Column, j_arrays: List[jpy.JType]) -> np.ndarray: +def _columns_to_2d_numpy_array(col_def: ColumnDefinition, j_arrays: List[jpy.JType]) -> np.ndarray: """ Produces a 2d numpy array from the given Java arrays of the same component type and the Table column definition """ try: @@ -95,15 +95,15 @@ def to_numpy(table: Table, cols: List[str] = None) -> np.ndarray: if table.is_refreshing: table = table.snapshot() - col_def_dict = {col.name: col for col in table.columns} + table_def = table.definition if not cols: - cols = list(col_def_dict.keys()) + cols = list(table_def.keys()) else: - diff_set = set(cols) - set(col_def_dict.keys()) + diff_set = set(cols) - set(table_def.keys()) if diff_set: raise DHError(message=f"columns - {list(diff_set)} not found") - col_defs = [col_def_dict[col] for col in cols] + col_defs = [table_def[col] for col in cols] if len(set([col_def.data_type for col_def in col_defs])) != 1: raise DHError(message="columns must be of the same data type.") diff --git a/py/server/deephaven/pandas.py b/py/server/deephaven/pandas.py index d946de5e391..8a2be32c53a 100644 --- a/py/server/deephaven/pandas.py +++ b/py/server/deephaven/pandas.py @@ -11,7 +11,7 @@ import pyarrow as pa from deephaven import DHError, new_table, dtypes, arrow -from deephaven.column import Column +from deephaven.column import ColumnDefinition from deephaven.constants import NULL_BYTE, NULL_SHORT, NULL_INT, NULL_LONG, NULL_FLOAT, NULL_DOUBLE, NULL_CHAR from deephaven.jcompat import _j_array_to_series from deephaven.numpy import _make_input_column @@ -22,12 +22,12 @@ _is_dtype_backend_supported = pd.__version__ >= "2.0.0" -def _column_to_series(table: Table, col_def: Column, conv_null: bool) -> pd.Series: +def _column_to_series(table: Table, col_def: ColumnDefinition, conv_null: bool) -> pd.Series: """Produce a copy of the specified column as a pandas.Series object. Args: table (Table): the table - col_def (Column): the column definition + col_def (ColumnDefinition): the column definition conv_null (bool): whether to check for Deephaven nulls in the data and automatically replace them with pd.NA. 
@@ -133,17 +133,17 @@ def to_pandas(table: Table, cols: List[str] = None, if table.is_refreshing: table = table.snapshot() - col_def_dict = {col.name: col for col in table.columns} + table_def = table.definition if not cols: - cols = list(col_def_dict.keys()) + cols = list(table_def.keys()) else: - diff_set = set(cols) - set(col_def_dict.keys()) + diff_set = set(cols) - set(table_def.keys()) if diff_set: raise DHError(message=f"columns - {list(diff_set)} not found") data = {} for col in cols: - series = _column_to_series(table, col_def_dict[col], conv_null) + series = _column_to_series(table, table_def[col], conv_null) data[col] = series return pd.DataFrame(data=data, columns=cols, copy=False) diff --git a/py/server/deephaven/parquet.py b/py/server/deephaven/parquet.py index 61614c37061..0e7de6af920 100644 --- a/py/server/deephaven/parquet.py +++ b/py/server/deephaven/parquet.py @@ -11,10 +11,8 @@ import jpy from deephaven import DHError -from deephaven.column import Column -from deephaven.dtypes import DType -from deephaven.jcompat import j_array_list, j_table_definition -from deephaven.table import Table, PartitionedTable +from deephaven.jcompat import j_array_list +from deephaven.table import Table, TableDefinition, TableDefinitionLike, PartitionedTable from deephaven.experimental import s3 _JParquetTools = jpy.get_type("io.deephaven.parquet.table.ParquetTools") @@ -69,7 +67,7 @@ def _build_parquet_instructions( generate_metadata_files: Optional[bool] = None, base_name: Optional[str] = None, file_layout: Optional[ParquetFileLayout] = None, - table_definition: Optional[Union[Dict[str, DType], List[Column]]] = None, + table_definition: Optional[TableDefinitionLike] = None, index_columns: Optional[Sequence[Sequence[str]]] = None, special_instructions: Optional[s3.S3Instructions] = None, ): @@ -135,7 +133,7 @@ def _build_parquet_instructions( builder.setFileLayout(_j_file_layout(file_layout)) if table_definition is not None: - builder.setTableDefinition(j_table_definition(table_definition)) + builder.setTableDefinition(TableDefinition(table_definition).j_table_definition) if index_columns: builder.addAllIndexColumns(_j_list_of_list_of_string(index_columns)) @@ -166,7 +164,7 @@ def read( is_legacy_parquet: bool = False, is_refreshing: bool = False, file_layout: Optional[ParquetFileLayout] = None, - table_definition: Union[Dict[str, DType], List[Column], None] = None, + table_definition: Optional[TableDefinitionLike] = None, special_instructions: Optional[s3.S3Instructions] = None, ) -> Table: """ Reads in a table from a single parquet, metadata file, or directory with recognized layout. 
@@ -235,7 +233,7 @@ def delete(path: str) -> None: def write( table: Table, path: str, - table_definition: Optional[Union[Dict[str, DType], List[Column]]] = None, + table_definition: Optional[TableDefinitionLike] = None, col_instructions: Optional[List[ColumnInstruction]] = None, compression_codec_name: Optional[str] = None, max_dictionary_keys: Optional[int] = None, @@ -302,7 +300,7 @@ def write( def write_partitioned( table: Union[Table, PartitionedTable], destination_dir: str, - table_definition: Optional[Union[Dict[str, DType], List[Column]]] = None, + table_definition: Optional[TableDefinitionLike] = None, col_instructions: Optional[List[ColumnInstruction]] = None, compression_codec_name: Optional[str] = None, max_dictionary_keys: Optional[int] = None, @@ -388,7 +386,7 @@ def write_partitioned( def batch_write( tables: List[Table], paths: List[str], - table_definition: Optional[Union[Dict[str, DType], List[Column]]] = None, + table_definition: Optional[TableDefinitionLike] = None, col_instructions: Optional[List[ColumnInstruction]] = None, compression_codec_name: Optional[str] = None, max_dictionary_keys: Optional[int] = None, diff --git a/py/server/deephaven/stream/kafka/consumer.py b/py/server/deephaven/stream/kafka/consumer.py index 01d33e7463b..abfe1a4f818 100644 --- a/py/server/deephaven/stream/kafka/consumer.py +++ b/py/server/deephaven/stream/kafka/consumer.py @@ -9,11 +9,11 @@ from deephaven import dtypes from deephaven._wrapper import JObjectWrapper -from deephaven.column import Column +from deephaven.column import col_def from deephaven.dherror import DHError from deephaven.dtypes import DType from deephaven.jcompat import j_hashmap, j_properties, j_array_list -from deephaven.table import Table, PartitionedTable +from deephaven.table import Table, TableDefinition, TableDefinitionLike, PartitionedTable _JKafkaTools = jpy.get_type("io.deephaven.kafka.KafkaTools") _JKafkaTools_Consume = jpy.get_type("io.deephaven.kafka.KafkaTools$Consume") @@ -427,13 +427,13 @@ def avro_spec( raise DHError(e, "failed to create a Kafka key/value spec") from e -def json_spec(col_defs: Union[Dict[str, DType], List[Tuple[str, DType]]], mapping: Dict = None) -> KeyValueSpec: +def json_spec(col_defs: Union[TableDefinitionLike, List[Tuple[str, DType]]], mapping: Dict = None) -> KeyValueSpec: """Creates a spec for how to use JSON data when consuming a Kafka stream to a Deephaven table. Args: - col_defs (Union[Dict[str, DType], List[Tuple[str, DType]]): the column definitions, either a map of column - names and Deephaven types, or a list of tuples with two elements, a string for column name and a Deephaven - type for column data type. + col_defs (Union[TableDefinitionLike, List[Tuple[str, DType]]): the table definition, preferably specified as + TableDefinitionLike. A list of tuples with two elements, a string for column name and a Deephaven type for + column data type also works, but is deprecated for removal. mapping (Dict): a dict mapping JSON fields to column names defined in the col_defs argument. Fields starting with a '/' character are interpreted as a JSON Pointer (see RFC 6901, ISSN: 2070-1721 for details, essentially nested fields are represented like "/parent/nested"). 
@@ -448,10 +448,20 @@ def json_spec(col_defs: Union[Dict[str, DType], List[Tuple[str, DType]]], mappin DHError """ try: - if isinstance(col_defs, dict): - col_defs = [Column(k, v).j_column_definition for k, v in col_defs.items()] + try: + table_def = TableDefinition(col_defs) + except DHError: + table_def = None + + if table_def: + col_defs = [col.j_column_definition for col in table_def.values()] else: - col_defs = [Column(*t).j_column_definition for t in col_defs] + warn( + 'json_spec col_defs for List[Tuple[str, DType]] is deprecated for removal, prefer TableDefinitionLike', + DeprecationWarning, + stacklevel=2, + ) + col_defs = [col_def(*t).j_column_definition for t in col_defs] if mapping is None: return KeyValueSpec(j_spec=_JKafkaTools_Consume.jsonSpec(col_defs)) diff --git a/py/server/deephaven/stream/table_publisher.py b/py/server/deephaven/stream/table_publisher.py index a6c65f47885..57ead700e79 100644 --- a/py/server/deephaven/stream/table_publisher.py +++ b/py/server/deephaven/stream/table_publisher.py @@ -8,11 +8,9 @@ from typing import Callable, Dict, Optional, Tuple, Union, List from deephaven._wrapper import JObjectWrapper -from deephaven.column import Column -from deephaven.dtypes import DType from deephaven.execution_context import get_exec_ctx -from deephaven.jcompat import j_lambda, j_runnable, j_table_definition -from deephaven.table import Table +from deephaven.jcompat import j_lambda, j_runnable +from deephaven.table import Table, TableDefinition, TableDefinitionLike from deephaven.update_graph import UpdateGraph _JTableDefinition = jpy.get_type("io.deephaven.engine.table.TableDefinition") @@ -75,7 +73,7 @@ def publish_failure(self, failure: Exception) -> None: def table_publisher( name: str, - col_defs: Union[Dict[str, DType], List[Column]], + col_defs: TableDefinitionLike, on_flush_callback: Optional[Callable[[TablePublisher], None]] = None, on_shutdown_callback: Optional[Callable[[], None]] = None, update_graph: Optional[UpdateGraph] = None, @@ -85,7 +83,7 @@ def table_publisher( Args: name (str): the name, used for logging - col_defs (Dict[str, DType]): the column definitions for the resulting blink table + col_defs (TableDefinitionLike): the table definition for the resulting blink table on_flush_callback (Optional[Callable[[TablePublisher], None]]): the on-flush callback, if present, is called once at the beginning of each update graph cycle. This is a pattern that allows publishers to add any data they may have been batching. 
Do note though, this blocks the update cycle from proceeding, so @@ -107,7 +105,7 @@ def adapt_callback(_table_publisher: jpy.JType): j_table_publisher = _JTablePublisher.of( name, - j_table_definition(col_defs), + TableDefinition(col_defs).j_table_definition, j_lambda(adapt_callback, _JConsumer, None) if on_flush_callback else None, j_runnable(on_shutdown_callback) if on_shutdown_callback else None, (update_graph or get_exec_ctx().update_graph).j_update_graph, diff --git a/py/server/deephaven/table.py b/py/server/deephaven/table.py index 26557db18a1..abb22e1031c 100644 --- a/py/server/deephaven/table.py +++ b/py/server/deephaven/table.py @@ -11,11 +11,13 @@ import inspect from enum import Enum from enum import auto +from functools import cached_property from typing import Any, Optional, Callable, Dict, Generator, Tuple, Literal -from typing import Sequence, List, Union, Protocol +from typing import Sequence, List, Union, Protocol, Mapping, Iterable import jpy import numpy as np +import sys from deephaven import DHError from deephaven import dtypes @@ -23,7 +25,7 @@ from deephaven._wrapper import JObjectWrapper from deephaven._wrapper import unwrap from deephaven.agg import Aggregation -from deephaven.column import Column, ColumnType +from deephaven.column import col_def, ColumnDefinition from deephaven.filters import Filter, and_, or_ from deephaven.jcompat import j_unary_operator, j_binary_operator, j_map_to_dict, j_hashmap from deephaven.jcompat import to_sequence, j_array_list @@ -407,19 +409,133 @@ def _sort_column(col, dir_): _JColumnName.of(col))) -def _td_to_columns(table_definition): - cols = [] - j_cols = table_definition.getColumnsArray() - for j_col in j_cols: - cols.append( - Column( - name=j_col.getName(), - data_type=dtypes.from_jtype(j_col.getDataType()), - component_type=dtypes.from_jtype(j_col.getComponentType()), - column_type=ColumnType(j_col.getColumnType()), +if sys.version_info >= (3, 10): + from typing import TypeAlias # novermin + + TableDefinitionLike: TypeAlias = Union[ + "TableDefinition", + Mapping[str, dtypes.DType], + Iterable[ColumnDefinition], + jpy.JType, + ] + """A Union representing objects that can be coerced into a TableDefinition.""" +else: + TableDefinitionLike = Union[ + "TableDefinition", + Mapping[str, dtypes.DType], + Iterable[ColumnDefinition], + jpy.JType, + ] + """A Union representing objects that can be coerced into a TableDefinition.""" + + +class TableDefinition(JObjectWrapper, Mapping): + """A Deephaven table definition, as a mapping from column name to ColumnDefinition.""" + + j_object_type = _JTableDefinition + + @staticmethod + def _to_j_table_definition(table_definition: TableDefinitionLike) -> jpy.JType: + if isinstance(table_definition, TableDefinition): + return table_definition.j_table_definition + if isinstance(table_definition, _JTableDefinition): + return table_definition + if isinstance(table_definition, Mapping): + for name in table_definition.keys(): + if not isinstance(name, str): + raise DHError( + f"Expected TableDefinitionLike Mapping to contain str keys, found type {type(name)}" + ) + for data_type in table_definition.values(): + if not isinstance(data_type, dtypes.DType): + raise DHError( + f"Expected TableDefinitionLike Mapping to contain DType values, found type {type(data_type)}" + ) + column_definitions = [ + col_def(name, data_type) for name, data_type in table_definition.items() + ] + elif isinstance(table_definition, Iterable): + for column_definition in table_definition: + if not isinstance(column_definition, 
ColumnDefinition): + raise DHError( + f"Expected TableDefinitionLike Iterable to contain ColumnDefinition values, found type {type(column_definition)}" + ) + column_definitions = table_definition + else: + raise DHError( + f"Unexpected TableDefinitionLike type: {type(table_definition)}" ) + return _JTableDefinition.of( + [col.j_column_definition for col in column_definitions] ) - return cols + + def __init__(self, table_definition: TableDefinitionLike): + """Construct a TableDefinition. + + Args: + table_definition (TableDefinitionLike): The table definition like object + + Returns: + A new TableDefinition + + Raises: + DHError + """ + self.j_table_definition = TableDefinition._to_j_table_definition( + table_definition + ) + + @property + def j_object(self) -> jpy.JType: + return self.j_table_definition + + @property + def table(self) -> Table: + """This table definition as a table.""" + return Table(_JTableTools.metaTable(self.j_table_definition)) + + def keys(self): + """The column names as a dictview.""" + return self._dict.keys() + + def items(self): + """The column name, column definition tuples as a dictview.""" + return self._dict.items() + + def values(self): + """The column definitions as a dictview.""" + return self._dict.values() + + @cached_property + def _dict(self) -> Dict[str, ColumnDefinition]: + return { + col.name: col + for col in [ + ColumnDefinition(j_col) + for j_col in self.j_table_definition.getColumnsArray() + ] + } + + def __getitem__(self, key) -> ColumnDefinition: + return self._dict[key] + + def __iter__(self): + return iter(self._dict) + + def __len__(self): + return len(self._dict) + + def __contains__(self, item): + return item in self._dict + + def __eq__(self, other): + return JObjectWrapper.__eq__(self, other) + + def __ne__(self, other): + return JObjectWrapper.__ne__(self, other) + + def __hash__(self): + return JObjectWrapper.__hash__(self) class Table(JObjectWrapper): @@ -435,11 +551,7 @@ def __init__(self, j_table: jpy.JType): self.j_table = jpy.cast(j_table, self.j_object_type) if self.j_table is None: raise DHError("j_table type is not io.deephaven.engine.table.Table") - self._definition = self.j_table.getDefinition() - self._schema = None - self._is_refreshing = None - self._update_graph = None - self._is_flat = None + self._definition = TableDefinition(self.j_table.getDefinition()) def __repr__(self): default_repr = super().__repr__() @@ -465,37 +577,37 @@ def size(self) -> int: @property def is_refreshing(self) -> bool: """Whether this table is refreshing.""" - if self._is_refreshing is None: - self._is_refreshing = self.j_table.isRefreshing() - return self._is_refreshing + return self.j_table.isRefreshing() @property def is_blink(self) -> bool: """Whether this table is a blink table.""" return _JBlinkTableTools.isBlink(self.j_table) - @property + @cached_property def update_graph(self) -> UpdateGraph: """The update graph of the table.""" - if self._update_graph is None: - self._update_graph = UpdateGraph(self.j_table.getUpdateGraph()) - return self._update_graph + return UpdateGraph(self.j_table.getUpdateGraph()) @property def is_flat(self) -> bool: """Whether this table is guaranteed to be flat, i.e. 
its row set will be from 0 to number of rows - 1.""" - if self._is_flat is None: - self._is_flat = self.j_table.isFlat() - return self._is_flat + return self.j_table.isFlat() @property - def columns(self) -> List[Column]: - """The column definitions of the table.""" - if self._schema: - return self._schema + def definition(self) -> TableDefinition: + """The table definition.""" + return self._definition + + @property + def column_names(self) -> List[str]: + """The column names of the table.""" + return list(self.definition.keys()) - self._schema = _td_to_columns(self._definition) - return self._schema + @property + def columns(self) -> List[ColumnDefinition]: + """The column definitions of the table.""" + return list(self.definition.values()) @property def meta_table(self) -> Table: @@ -2338,13 +2450,8 @@ def j_object(self) -> jpy.JType: def __init__(self, j_partitioned_table): self.j_partitioned_table = j_partitioned_table - self._schema = None + self._definition = None self._table = None - self._key_columns = None - self._unique_keys = None - self._constituent_column = None - self._constituent_changes_permitted = None - self._is_refreshing = None @classmethod def from_partitioned_table(cls, @@ -2352,18 +2459,18 @@ def from_partitioned_table(cls, key_cols: Union[str, List[str]] = None, unique_keys: bool = None, constituent_column: str = None, - constituent_table_columns: List[Column] = None, + constituent_table_columns: Optional[TableDefinitionLike] = None, constituent_changes_permitted: bool = None) -> PartitionedTable: """Creates a PartitionedTable from the provided underlying partitioned Table. - Note: key_cols, unique_keys, constituent_column, constituent_table_columns, + Note: key_cols, unique_keys, constituent_column, constituent_table_definition, constituent_changes_permitted must either be all None or all have values. When they are None, their values will be inferred as follows: | * key_cols: the names of all columns with a non-Table data type | * unique_keys: False | * constituent_column: the name of the first column with a Table data type - | * constituent_table_columns: the column definitions of the first cell (constituent table) in the constituent + | * constituent_table_definition: the table definitions of the first cell (constituent table) in the constituent column. Consequently, the constituent column can't be empty. 
| * constituent_changes_permitted: the value of table.is_refreshing @@ -2373,7 +2480,7 @@ def from_partitioned_table(cls, key_cols (Union[str, List[str]]): the key column name(s) of 'table' unique_keys (bool): whether the keys in 'table' are guaranteed to be unique constituent_column (str): the constituent column name in 'table' - constituent_table_columns (List[Column]): the column definitions of the constituent table + constituent_table_columns (Optional[TableDefinitionLike]): the table definitions of the constituent table constituent_changes_permitted (bool): whether the values of the constituent column can change Returns: @@ -2390,7 +2497,7 @@ def from_partitioned_table(cls, return PartitionedTable(j_partitioned_table=_JPartitionedTableFactory.of(table.j_table)) if all([arg is not None for arg in none_args]): - table_def = _JTableDefinition.of([col.j_column_definition for col in constituent_table_columns]) + table_def = TableDefinition(constituent_table_columns).j_table_definition j_partitioned_table = _JPartitionedTableFactory.of(table.j_table, j_array_list(to_sequence(key_cols)), unique_keys, @@ -2407,18 +2514,18 @@ def from_partitioned_table(cls, @classmethod def from_constituent_tables(cls, tables: List[Table], - constituent_table_columns: List[Column] = None) -> PartitionedTable: + constituent_table_columns: Optional[TableDefinitionLike] = None) -> PartitionedTable: """Creates a PartitionedTable with a single column named '__CONSTITUENT__' containing the provided constituent tables. The result PartitionedTable has no key columns, and both its unique_keys and constituent_changes_permitted - properties are set to False. When constituent_table_columns isn't provided, it will be set to the column + properties are set to False. When constituent_table_definition isn't provided, it will be set to the table definitions of the first table in the provided constituent tables. 
Args: tables (List[Table]): the constituent tables - constituent_table_columns (List[Column]): a list of column definitions compatible with all the constituent - tables, default is None + constituent_table_columns (Optional[TableDefinitionLike]): the table definition compatible with all the + constituent tables, default is None Returns: a PartitionedTable @@ -2430,37 +2537,31 @@ def from_constituent_tables(cls, if not constituent_table_columns: return PartitionedTable(j_partitioned_table=_JPartitionedTableFactory.ofTables(to_sequence(tables))) else: - table_def = _JTableDefinition.of([col.j_column_definition for col in constituent_table_columns]) + table_def = TableDefinition(constituent_table_columns).j_table_definition return PartitionedTable(j_partitioned_table=_JPartitionedTableFactory.ofTables(table_def, to_sequence(tables))) except Exception as e: raise DHError(e, "failed to create a PartitionedTable from constituent tables.") from e - @property + @cached_property def table(self) -> Table: """The underlying partitioned table.""" - if self._table is None: - self._table = Table(j_table=self.j_partitioned_table.table()) - return self._table + return Table(j_table=self.j_partitioned_table.table()) @property def update_graph(self) -> UpdateGraph: """The underlying partitioned table's update graph.""" return self.table.update_graph - @property + @cached_property def is_refreshing(self) -> bool: """Whether the underlying partitioned table is refreshing.""" - if self._is_refreshing is None: - self._is_refreshing = self.table.is_refreshing - return self._is_refreshing + return self.table.is_refreshing - @property + @cached_property def key_columns(self) -> List[str]: """The partition key column names.""" - if self._key_columns is None: - self._key_columns = list(self.j_partitioned_table.keyColumnNames().toArray()) - return self._key_columns + return list(self.j_partitioned_table.keyColumnNames().toArray()) def keys(self) -> Table: """Returns a Table containing all the keys of the underlying partitioned table.""" @@ -2469,32 +2570,31 @@ def keys(self) -> Table: else: return self.table.select_distinct(self.key_columns) - @property + @cached_property def unique_keys(self) -> bool: """Whether the keys in the underlying table must always be unique. If keys must be unique, one can expect that self.table.select_distinct(self.key_columns) and self.table.view(self.key_columns) operations always produce equivalent tables.""" - if self._unique_keys is None: - self._unique_keys = self.j_partitioned_table.uniqueKeys() - return self._unique_keys + return self.j_partitioned_table.uniqueKeys() - @property + @cached_property def constituent_column(self) -> str: """The name of the column containing constituent tables.""" - if self._constituent_column is None: - self._constituent_column = self.j_partitioned_table.constituentColumnName() - return self._constituent_column + return self.j_partitioned_table.constituentColumnName() + + @cached_property + def constituent_table_definition(self) -> TableDefinition: + """The table definitions for constituent tables. All constituent tables in a partitioned table have the + same table definitions.""" + return TableDefinition(self.j_partitioned_table.constituentDefinition()) @property - def constituent_table_columns(self) -> List[Column]: + def constituent_table_columns(self) -> List[ColumnDefinition]: """The column definitions for constituent tables. 
All constituent tables in a partitioned table have the same column definitions.""" - if not self._schema: - self._schema = _td_to_columns(self.j_partitioned_table.constituentDefinition()) + return list(self.constituent_table_definition.values()) - return self._schema - - @property + @cached_property def constituent_changes_permitted(self) -> bool: """Can the constituents of the underlying partitioned table change? Specifically, can the values of the constituent column change? @@ -2509,9 +2609,7 @@ def constituent_changes_permitted(self) -> bool: if the underlying partitioned table is refreshing. Also note that the underlying partitioned table must be refreshing if it contains any refreshing constituents. """ - if self._constituent_changes_permitted is None: - self._constituent_changes_permitted = self.j_partitioned_table.constituentChangesPermitted() - return self._constituent_changes_permitted + return self.j_partitioned_table.constituentChangesPermitted() def merge(self) -> Table: """Makes a new Table that contains all the rows from all the constituent tables. In the merged result, diff --git a/py/server/deephaven/table_factory.py b/py/server/deephaven/table_factory.py index 7efff4b534c..02316928573 100644 --- a/py/server/deephaven/table_factory.py +++ b/py/server/deephaven/table_factory.py @@ -5,7 +5,7 @@ """ This module provides various ways to make a Deephaven table. """ import datetime -from typing import Callable, List, Dict, Any, Union, Sequence, Tuple, Mapping +from typing import Callable, List, Dict, Any, Union, Sequence, Tuple, Mapping, Optional import jpy import numpy as np @@ -13,11 +13,11 @@ from deephaven import execution_context, DHError, time from deephaven._wrapper import JObjectWrapper -from deephaven.column import InputColumn, Column +from deephaven.column import InputColumn from deephaven.dtypes import DType, Duration, Instant from deephaven.execution_context import ExecutionContext from deephaven.jcompat import j_lambda, j_list_to_list, to_sequence -from deephaven.table import Table +from deephaven.table import Table, TableDefinition, TableDefinitionLike from deephaven.update_graph import auto_locking_ctx _JTableFactory = jpy.get_type("io.deephaven.engine.table.TableFactory") @@ -285,7 +285,7 @@ def value_names(self) -> List[str]: return j_list_to_list(self.j_input_table.getValueNames()) -def input_table(col_defs: Dict[str, DType] = None, init_table: Table = None, +def input_table(col_defs: Optional[TableDefinitionLike] = None, init_table: Table = None, key_cols: Union[str, Sequence[str]] = None) -> InputTable: """Creates an in-memory InputTable from either column definitions or an initial table. When key columns are provided, the InputTable will be keyed, otherwise it will be append-only. @@ -298,7 +298,7 @@ def input_table(col_defs: Dict[str, DType] = None, init_table: Table = None, The keyed input table has keys for each row and supports addition/deletion/modification of rows by the keys. 
Args: - col_defs (Dict[str, DType]): the column definitions + col_defs (Optional[TableDefinitionLike]): the table definition init_table (Table): the initial table key_cols (Union[str, Sequence[str]): the name(s) of the key column(s) @@ -316,8 +316,7 @@ def input_table(col_defs: Dict[str, DType] = None, init_table: Table = None, raise ValueError("both column definitions and init table are provided.") if col_defs: - j_arg_1 = _JTableDefinition.of( - [Column(name=n, data_type=t).j_column_definition for n, t in col_defs.items()]) + j_arg_1 = TableDefinition(col_defs).j_table_definition else: j_arg_1 = init_table.j_table diff --git a/py/server/tests/test_barrage.py b/py/server/tests/test_barrage.py index 28c0a49b6aa..7049d2566ff 100644 --- a/py/server/tests/test_barrage.py +++ b/py/server/tests/test_barrage.py @@ -76,7 +76,7 @@ def test_subscribe(self): session = barrage_session(host="localhost", port=10000, auth_type="Anonymous") t = session.subscribe(ticket=self.shared_ticket.bytes) self.assertEqual(t.size, 1000) - self.assertEqual(len(t.columns), 2) + self.assertEqual(len(t.definition), 2) sp = t.snapshot() self.assertEqual(sp.size, 1000) t1 = t.update("Z = X + Y") @@ -119,7 +119,7 @@ def test_snapshot(self): session = barrage_session(host="localhost", port=10000, auth_type="Anonymous") t = session.snapshot(self.shared_ticket.bytes) self.assertEqual(t.size, 1000) - self.assertEqual(len(t.columns), 2) + self.assertEqual(len(t.definition), 2) t1 = t.update("Z = X + Y") self.assertEqual(t1.size, 1000) t2 = session.snapshot(self.shared_ticket.bytes) diff --git a/py/server/tests/test_column.py b/py/server/tests/test_column.py index 57c94d53fc1..6531b6c3fa3 100644 --- a/py/server/tests/test_column.py +++ b/py/server/tests/test_column.py @@ -12,7 +12,7 @@ from deephaven import DHError, dtypes, new_table, time as dhtime from deephaven import empty_table from deephaven.column import byte_col, char_col, short_col, bool_col, int_col, long_col, float_col, double_col, \ - string_col, datetime_col, jobj_col, ColumnType + string_col, datetime_col, jobj_col, ColumnType, col_def from deephaven.constants import MAX_BYTE, MAX_SHORT, MAX_INT, MAX_LONG from deephaven.jcompat import j_array_list from tests.testbase import BaseTestCase @@ -136,7 +136,8 @@ def test_datetime_col(self): inst = dhtime.to_j_instant(round(time.time())) dt = datetime.datetime.now() _ = datetime_col(name="Datetime", data=[inst, dt, None]) - self.assertEqual(_.data_type, dtypes.Instant) + self.assertEqual(_._column_definition.name, "Datetime") + self.assertEqual(_._column_definition.data_type, dtypes.Instant) ts = pd.Timestamp(dt) np_dt = np.datetime64(dt) @@ -144,17 +145,46 @@ def test_datetime_col(self): # test if we can convert to numpy datetime64 array np.array([pd.Timestamp(dt).to_numpy() for dt in data], dtype=np.datetime64) _ = datetime_col(name="Datetime", data=data) - self.assertEqual(_.data_type, dtypes.Instant) + self.assertEqual(_._column_definition.name, "Datetime") + self.assertEqual(_._column_definition.data_type, dtypes.Instant) data = np.array(['1970-01-01T00:00:00.000-07:00', '2020-01-01T01:00:00.000+07:00']) np.array([pd.Timestamp(str(dt)).to_numpy() for dt in data], dtype=np.datetime64) _ = datetime_col(name="Datetime", data=data) - self.assertEqual(_.data_type, dtypes.Instant) + self.assertEqual(_._column_definition.name, "Datetime") + self.assertEqual(_._column_definition.data_type, dtypes.Instant) data = np.array([1, -1]) data = data.astype(np.int64) _ = datetime_col(name="Datetime", data=data) - 
self.assertEqual(_.data_type, dtypes.Instant) + self.assertEqual(_._column_definition.name, "Datetime") + self.assertEqual(_._column_definition.data_type, dtypes.Instant) + + def test_col_def_simple(self): + foo_def = col_def("Foo", dtypes.int32) + self.assertEquals(foo_def.name, "Foo") + self.assertEquals(foo_def.data_type, dtypes.int32) + self.assertEquals(foo_def.component_type, None) + self.assertEquals(foo_def.column_type, ColumnType.NORMAL) + + def test_col_def_array(self): + foo_def = col_def("Foo", dtypes.int32_array) + self.assertEquals(foo_def.name, "Foo") + self.assertEquals(foo_def.data_type, dtypes.int32_array) + self.assertEquals(foo_def.component_type, dtypes.int32) + self.assertEquals(foo_def.column_type, ColumnType.NORMAL) + + def test_col_def_partitioning(self): + foo_def = col_def("Foo", dtypes.string, column_type=ColumnType.PARTITIONING) + self.assertEquals(foo_def.name, "Foo") + self.assertEquals(foo_def.data_type, dtypes.string) + self.assertEquals(foo_def.component_type, None) + self.assertEquals(foo_def.column_type, ColumnType.PARTITIONING) + + def test_col_def_invalid_component_type(self): + with self.assertRaises(DHError): + col_def("Foo", dtypes.int32_array, component_type=dtypes.int64) + @dataclass class CustomClass: diff --git a/py/server/tests/test_csv.py b/py/server/tests/test_csv.py index 3de09e9a570..88e291092cd 100644 --- a/py/server/tests/test_csv.py +++ b/py/server/tests/test_csv.py @@ -20,8 +20,7 @@ def test_read_header(self): col_types = [dtypes.string, dtypes.long, dtypes.float64] table_header = {k: v for k, v in zip(col_names, col_types)} t = read_csv('tests/data/test_csv.csv', header=table_header) - t_col_names = [col.name for col in t.columns] - self.assertEqual(col_names, t_col_names) + self.assertEqual(col_names, t.column_names) def test_read_error_col_type(self): col_names = ["Strings", "Longs", "Floats"] @@ -44,9 +43,9 @@ def test_read_error_quote(self): def test_write(self): t = read_csv("tests/data/small_sample.csv") write_csv(t, "./test_write.csv") - t_cols = [col.name for col in t.columns] + t_cols = t.column_names t = read_csv("./test_write.csv") - self.assertEqual(t_cols, [col.name for col in t.columns]) + self.assertEqual(t_cols, t.column_names) col_names = ["Strings", "Longs", "Floats"] col_types = [dtypes.string, dtypes.long, dtypes.float64] @@ -54,7 +53,7 @@ def test_write(self): t = read_csv('tests/data/test_csv.csv', header=table_header) write_csv(t, "./test_write.csv", cols=col_names) t = read_csv('./test_write.csv') - self.assertEqual(col_names, [c.name for c in t.columns]) + self.assertEqual(col_names, t.column_names) import os os.remove("./test_write.csv") diff --git a/py/server/tests/test_data_index.py b/py/server/tests/test_data_index.py index 5b3aad01391..14b21407f7c 100644 --- a/py/server/tests/test_data_index.py +++ b/py/server/tests/test_data_index.py @@ -47,7 +47,7 @@ def test_keys(self): self.assertEqual(["X", "Y"], self.data_index.keys) def test_backing_table(self): - self.assertEqual(3, len(self.data_index.table.columns)) + self.assertEqual(3, len(self.data_index.table.definition)) self.assertEqual(10, self.data_index.table.size) di = data_index(self.data_index.table, self.data_index.keys[0:1]) self.assertEqual(1, len(di.keys)) diff --git a/py/server/tests/test_dbc.py b/py/server/tests/test_dbc.py index 31868028e08..fd3cba4fbf2 100644 --- a/py/server/tests/test_dbc.py +++ b/py/server/tests/test_dbc.py @@ -50,7 +50,7 @@ def test_read_sql_connectorx(self): query = "SELECT t_ts, t_id, t_instrument, t_exchange, t_price, 
t_size FROM CRYPTO_TRADES LIMIT 10" postgres_url = "postgresql://test:test@postgres:5432/test" dh_table = read_sql(conn=postgres_url, query=query) - self.assertEqual(len(dh_table.columns), 6) + self.assertEqual(len(dh_table.definition), 6) self.assertEqual(dh_table.size, 10) with self.assertRaises(DHError) as cm: @@ -63,13 +63,13 @@ def test_read_sql(self): with self.subTest("odbc"): connection_string = 'Driver={PostgreSQL};Server=postgres;Port=5432;Database=test;Uid=test;Pwd=test;' dh_table = read_sql(conn=connection_string, query=query, driver="odbc") - self.assertEqual(len(dh_table.columns), 6) + self.assertEqual(len(dh_table.definition), 6) self.assertEqual(dh_table.size, 10) with self.subTest("adbc"): uri = "postgresql://postgres:5432/test?user=test&password=test" dh_table = read_sql(conn=uri, query=query, driver="adbc") - self.assertEqual(len(dh_table.columns), 6) + self.assertEqual(len(dh_table.definition), 6) self.assertEqual(dh_table.size, 10) if turbodbc_installed(): @@ -79,7 +79,7 @@ def test_read_sql(self): connection_string = "Driver={PostgreSQL};Server=postgres;Port=5432;Database=test;Uid=test;Pwd=test;" with turbodbc.connect(connection_string=connection_string) as conn: dh_table = read_sql(conn=conn, query=query, driver="odbc") - self.assertEqual(len(dh_table.columns), 6) + self.assertEqual(len(dh_table.definition), 6) self.assertEqual(dh_table.size, 10) with self.subTest("adbc-connection"): @@ -87,7 +87,7 @@ def test_read_sql(self): uri = "postgresql://postgres:5432/test?user=test&password=test" with adbc_driver_postgresql.dbapi.connect(uri) as conn: dh_table = read_sql(conn=conn, query=query, driver="adbc") - self.assertEqual(len(dh_table.columns), 6) + self.assertEqual(len(dh_table.definition), 6) self.assertEqual(dh_table.size, 10) with self.assertRaises(DHError) as cm: diff --git a/py/server/tests/test_experiments.py b/py/server/tests/test_experiments.py index 6de28871a98..fd04cce0c65 100644 --- a/py/server/tests/test_experiments.py +++ b/py/server/tests/test_experiments.py @@ -31,13 +31,13 @@ def test_full_outer_join(self): rt = full_outer_join(t1, t2, on="a = c") self.assertTrue(rt.is_refreshing) self.wait_ticking_table_update(rt, row_count=100, timeout=5) - self.assertEqual(len(rt.columns), len(t1.columns) + len(t2.columns)) + self.assertEqual(len(rt.definition), len(t1.definition) + len(t2.definition)) with self.subTest("full outer join with no matching keys"): t1 = empty_table(2).update(["X = i", "a = i"]) rt = full_outer_join(self.test_table, t1, joins=["Y = a"]) self.assertEqual(rt.size, t1.size * self.test_table.size) - self.assertEqual(len(rt.columns), 1 + len(self.test_table.columns)) + self.assertEqual(len(rt.definition), 1 + len(self.test_table.definition)) with self.subTest("Conflicting column names"): with self.assertRaises(DHError) as cm: @@ -52,13 +52,13 @@ def test_left_outer_join(self): rt = left_outer_join(t1, t2, on="a = c") self.assertTrue(rt.is_refreshing) self.wait_ticking_table_update(rt, row_count=100, timeout=5) - self.assertEqual(len(rt.columns), len(t1.columns) + len(t2.columns)) + self.assertEqual(len(rt.definition), len(t1.definition) + len(t2.definition)) with self.subTest("left outer join with no matching keys"): t1 = empty_table(2).update(["X = i", "a = i"]) rt = left_outer_join(self.test_table, t1, joins=["Y = a"]) self.assertEqual(rt.size, t1.size * self.test_table.size) - self.assertEqual(len(rt.columns), 1 + len(self.test_table.columns)) + self.assertEqual(len(rt.definition), 1 + len(self.test_table.definition)) with 
self.subTest("Conflicting column names"): with self.assertRaises(DHError) as cm: diff --git a/py/server/tests/test_iceberg.py b/py/server/tests/test_iceberg.py index 62ba31e6636..8934299b74d 100644 --- a/py/server/tests/test_iceberg.py +++ b/py/server/tests/test_iceberg.py @@ -4,7 +4,7 @@ import jpy from deephaven import dtypes -from deephaven.column import Column, ColumnType +from deephaven.column import col_def, ColumnType from tests.testbase import BaseTestCase from deephaven.experimental import s3, iceberg @@ -60,12 +60,10 @@ def test_instruction_create_with_table_definition_dict(self): def test_instruction_create_with_table_definition_list(self): table_def=[ - Column( - "Partition", dtypes.int32, column_type=ColumnType.PARTITIONING - ), - Column("x", dtypes.int32), - Column("y", dtypes.double), - Column("z", dtypes.double), + col_def("Partition", dtypes.int32, column_type=ColumnType.PARTITIONING), + col_def("x", dtypes.int32), + col_def("y", dtypes.double), + col_def("z", dtypes.double), ] iceberg_instructions = iceberg.IcebergInstructions(table_definition=table_def) diff --git a/py/server/tests/test_numpy.py b/py/server/tests/test_numpy.py index 725e69602f1..1c935ed04f7 100644 --- a/py/server/tests/test_numpy.py +++ b/py/server/tests/test_numpy.py @@ -71,14 +71,14 @@ def tearDown(self) -> None: super().tearDown() def test_to_numpy(self): - for col in self.test_table.columns: - with self.subTest(f"test single column to numpy- {col.name}"): - np_array = to_numpy(self.test_table, [col.name]) + for col_name in self.test_table.definition: + with self.subTest(f"test single column to numpy- {col_name}"): + np_array = to_numpy(self.test_table, [col_name]) self.assertEqual((2, 1), np_array.shape) - np.array_equal(np_array, self.np_array_dict[col.name]) + np.array_equal(np_array, self.np_array_dict[col_name]) try: - to_numpy(self.test_table, [col.name for col in self.test_table.columns]) + to_numpy(self.test_table, self.test_table.column_names) except DHError as e: self.assertIn("same data type", e.root_cause) @@ -90,17 +90,17 @@ def test_to_numpy(self): float_col(name="Float3", data=[1111.01111, -1111.01111]), float_col(name="Float4", data=[11111.011111, -11111.011111])] tmp_table = new_table(cols=input_cols) - np_array = to_numpy(tmp_table, [col.name for col in tmp_table.columns]) + np_array = to_numpy(tmp_table, tmp_table.column_names) self.assertEqual((2, 5), np_array.shape) def test_to_numpy_remap(self): - for col in self.test_table.columns: - with self.subTest(f"test single column to numpy - {col.name}"): - np_array = to_numpy(self.test_table, [col.name]) + for col_name in self.test_table.definition: + with self.subTest(f"test single column to numpy - {col_name}"): + np_array = to_numpy(self.test_table, [col_name]) self.assertEqual((2, 1), np_array.shape) try: - to_numpy(self.test_table, [col.name for col in self.test_table.columns]) + to_numpy(self.test_table, self.test_table.column_names) except DHError as e: self.assertIn("same data type", e.root_cause) @@ -140,12 +140,12 @@ def test_to_table(self): float_col(name="Float3", data=[1111.01111, -1111.01111]), float_col(name="Float4", data=[11111.011111, -11111.011111])] tmp_table = new_table(cols=input_cols) - np_array = to_numpy(tmp_table, [col.name for col in tmp_table.columns]) - tmp_table2 = to_table(np_array, [col.name for col in tmp_table.columns]) + np_array = to_numpy(tmp_table, tmp_table.column_names) + tmp_table2 = to_table(np_array, tmp_table.column_names) self.assert_table_equals(tmp_table2, tmp_table) with 
self.assertRaises(DHError) as cm: - tmp_table3 = to_table(np_array[:, [0, 1, 3]], [col.name for col in tmp_table.columns]) + tmp_table3 = to_table(np_array[:, [0, 1, 3]], tmp_table.column_names) self.assertIn("doesn't match", cm.exception.root_cause) def get_resource_path(self, resource_path) -> str: diff --git a/py/server/tests/test_parquet.py b/py/server/tests/test_parquet.py index 1a5d3b3e31c..50c8cf6f68e 100644 --- a/py/server/tests/test_parquet.py +++ b/py/server/tests/test_parquet.py @@ -14,7 +14,7 @@ from deephaven import DHError, empty_table, dtypes, new_table from deephaven import arrow as dharrow -from deephaven.column import InputColumn, Column, ColumnType, string_col, int_col, char_col, long_col, short_col +from deephaven.column import InputColumn, ColumnType, col_def, string_col, int_col, char_col, long_col, short_col from deephaven.pandas import to_pandas, to_table from deephaven.parquet import (write, batch_write, read, delete, ColumnInstruction, ParquetFileLayout, write_partitioned) @@ -597,12 +597,10 @@ def test_read_kv_partitioned(self): actual = read( kv_dir, table_definition=[ - Column( - "Partition", dtypes.int32, column_type=ColumnType.PARTITIONING - ), - Column("x", dtypes.int32), - Column("y", dtypes.double), - Column("z", dtypes.double), + col_def("Partition", dtypes.int32, column_type=ColumnType.PARTITIONING), + col_def("x", dtypes.int32), + col_def("y", dtypes.double), + col_def("z", dtypes.double), ], file_layout=ParquetFileLayout.KV_PARTITIONED, ) @@ -655,7 +653,7 @@ def test_write_partitioned_data(self): shutil.rmtree(root_dir) def verify_table_from_disk(table): - self.assertTrue(len(table.columns)) + self.assertTrue(len(table.definition)) self.assertTrue(table.columns[0].name == "X") self.assertTrue(table.columns[0].column_type == ColumnType.PARTITIONING) self.assert_table_equals(table.select().sort(["X", "Y"]), source.sort(["X", "Y"])) @@ -696,9 +694,9 @@ def verify_file_names(): shutil.rmtree(root_dir) table_definition = [ - Column("X", dtypes.string, column_type=ColumnType.PARTITIONING), - Column("Y", dtypes.string), - Column("Number", dtypes.int32) + col_def("X", dtypes.string, column_type=ColumnType.PARTITIONING), + col_def("Y", dtypes.string), + col_def("Number", dtypes.int32) ] write_partitioned(source, table_definition=table_definition, destination_dir=root_dir, base_name=base_name, max_dictionary_keys=max_dictionary_keys) diff --git a/py/server/tests/test_partitioned_table.py b/py/server/tests/test_partitioned_table.py index 8da63d726a8..3059a6c40fb 100644 --- a/py/server/tests/test_partitioned_table.py +++ b/py/server/tests/test_partitioned_table.py @@ -65,6 +65,9 @@ def test_constituent_change_permitted(self): def test_constituent_table_columns(self): self.assertEqual(self.test_table.columns, self.partitioned_table.constituent_table_columns) + def test_constituent_table_definition(self): + self.assertEqual(self.test_table.definition, self.partitioned_table.constituent_table_definition) + def test_merge(self): t = self.partitioned_table.merge() self.assert_table_equals(t, self.test_table) @@ -188,7 +191,7 @@ def test_from_partitioned_table(self): key_cols="Y", unique_keys=True, constituent_column="aggPartition", - constituent_table_columns=test_table.columns, + constituent_table_columns=test_table.definition, constituent_changes_permitted=True, ) self.assertEqual(pt.key_columns, pt1.key_columns) @@ -201,7 +204,7 @@ def test_from_partitioned_table(self): key_cols="Y", unique_keys=True, constituent_column="Non-existing", - 
constituent_table_columns=test_table.columns, + constituent_table_columns=test_table.definition, constituent_changes_permitted=True, ) self.assertIn("no column named", str(cm.exception)) @@ -222,7 +225,7 @@ def test_from_constituent_tables(self): self.assertIn("IncompatibleTableDefinitionException", str(cm.exception)) with self.subTest("Compatible table definition"): - pt = PartitionedTable.from_constituent_tables([test_table, test_table1, test_table3], test_table.columns) + pt = PartitionedTable.from_constituent_tables([test_table, test_table1, test_table3], test_table.definition) def test_keys(self): keys_table = self.partitioned_table.keys() diff --git a/py/server/tests/test_pt_proxy.py b/py/server/tests/test_pt_proxy.py index 982a582fa6f..5cbe973dc0c 100644 --- a/py/server/tests/test_pt_proxy.py +++ b/py/server/tests/test_pt_proxy.py @@ -127,7 +127,7 @@ def test_USV(self): result_pt_proxy = op( self.pt_proxy, formulas=["a", "c", "Sum = a + b + c + d"]) for rct, ct in zip(result_pt_proxy.target.constituent_tables, self.pt_proxy.target.constituent_tables): - self.assertTrue(len(rct.columns) >= 3) + self.assertTrue(len(rct.definition) >= 3) self.assertLessEqual(rct.size, ct.size) def test_select_distinct(self): @@ -144,7 +144,7 @@ def test_natural_join(self): right_table = self.test_table.drop_columns(["b", "c"]).head(5) joined_pt_proxy = pt_proxy.natural_join(right_table, on="a", joins=["d", "e"]) for ct in joined_pt_proxy.target.constituent_tables: - self.assertEqual(len(ct.columns), 5) + self.assertEqual(len(ct.definition), 5) with self.subTest("Join with another Proxy"): with self.assertRaises(DHError) as cm: @@ -163,7 +163,7 @@ def test_natural_join(self): right_proxy = self.test_table.drop_columns(["b", "d"]).partition_by("c").proxy() joined_pt_proxy = pt_proxy.natural_join(right_proxy, on="a", joins="e") for ct in joined_pt_proxy.target.constituent_tables: - self.assertEqual(len(ct.columns), 4) + self.assertEqual(len(ct.definition), 4) def test_exact_join(self): with self.subTest("Join with a Table"): @@ -171,7 +171,7 @@ def test_exact_join(self): right_table = self.test_table.drop_columns(["b", "c"]).group_by('a') joined_pt_proxy = pt_proxy.exact_join(right_table, on="a", joins=["d", "e"]) for ct, jct in zip(pt_proxy.target.constituent_tables, joined_pt_proxy.target.constituent_tables): - self.assertEqual(len(jct.columns), 5) + self.assertEqual(len(jct.definition), 5) self.assertEqual(ct.size, jct.size) self.assertLessEqual(jct.size, right_table.size) @@ -180,7 +180,7 @@ def test_exact_join(self): right_proxy = self.test_table.drop_columns(["b", "d"]).partition_by("c").proxy() joined_pt_proxy = pt_proxy.exact_join(right_proxy, on="a", joins="e") for ct, jct in zip(pt_proxy.target.constituent_tables, joined_pt_proxy.target.constituent_tables): - self.assertEqual(len(jct.columns), 4) + self.assertEqual(len(jct.definition), 4) self.assertEqual(ct.size, jct.size) self.assertLessEqual(jct.size, right_table.size) @@ -247,7 +247,7 @@ def test_count_by(self): agg_pt_proxy = self.pt_proxy.count_by(col="cnt", by=["a"]) for gct, ct in zip(agg_pt_proxy.target.constituent_tables, self.pt_proxy.target.constituent_tables): self.assertLessEqual(gct.size, ct.size) - self.assertEqual(len(gct.columns), 2) + self.assertEqual(len(gct.definition), 2) def test_dedicated_agg(self): ops = [ @@ -268,7 +268,7 @@ def test_dedicated_agg(self): agg_pt_proxy = op(self.pt_proxy, by=["a", "b"]) for gct, ct in zip(agg_pt_proxy.target.constituent_tables, self.pt_proxy.target.constituent_tables): 
self.assertLessEqual(gct.size, ct.size) - self.assertEqual(len(gct.columns), len(ct.columns)) + self.assertEqual(len(gct.definition), len(ct.definition)) wops = [PartitionedTableProxy.weighted_avg_by, PartitionedTableProxy.weighted_sum_by, @@ -279,7 +279,7 @@ def test_dedicated_agg(self): agg_pt_proxy = wop(self.pt_proxy, wcol="e", by=["a", "b"]) for gct, ct in zip(agg_pt_proxy.target.constituent_tables, self.pt_proxy.target.constituent_tables): self.assertLessEqual(gct.size, ct.size) - self.assertEqual(len(gct.columns), len(ct.columns) - 1) + self.assertEqual(len(gct.definition), len(ct.definition) - 1) def test_agg_by(self): aggs = [ @@ -295,7 +295,7 @@ def test_agg_by(self): agg_pt_proxy = self.pt_proxy.agg_by(aggs=aggs, by=["a"]) for gct, ct in zip(agg_pt_proxy.target.constituent_tables, self.pt_proxy.target.constituent_tables): self.assertLessEqual(gct.size, ct.size) - self.assertEqual(len(gct.columns), 8) + self.assertEqual(len(gct.definition), 8) def test_agg_all_by(self): aggs = [ diff --git a/py/server/tests/test_table.py b/py/server/tests/test_table.py index 6b7ecf2168c..06e0ab2355a 100644 --- a/py/server/tests/test_table.py +++ b/py/server/tests/test_table.py @@ -14,7 +14,7 @@ from deephaven.html import to_html from deephaven.jcompat import j_hashmap from deephaven.pandas import to_pandas -from deephaven.table import Table, SearchDisplayMode, table_diff +from deephaven.table import Table, TableDefinition, SearchDisplayMode, table_diff from tests.testbase import BaseTestCase, table_equals @@ -84,9 +84,19 @@ def test_eq(self): t = self.test_table.where(["a > 500"]) self.assertNotEqual(t, self.test_table) + def test_definition(self): + expected = TableDefinition({ + "a": dtypes.int32, + "b": dtypes.int32, + "c": dtypes.int32, + "d": dtypes.int32, + "e": dtypes.int32 + }) + self.assertEquals(expected, self.test_table.definition) + def test_meta_table(self): t = self.test_table.meta_table - self.assertEqual(len(self.test_table.columns), t.size) + self.assertEqual(len(self.test_table.definition), t.size) def test_coalesce(self): t = self.test_table.update_view(["A = a * b"]) @@ -100,45 +110,45 @@ def test_flatten(self): self.assertTrue(ct.is_flat) def test_drop_columns(self): - column_names = [f.name for f in self.test_table.columns] + column_names = self.test_table.column_names result_table = self.test_table.drop_columns(cols=column_names[:-1]) - self.assertEqual(1, len(result_table.columns)) + self.assertEqual(1, len(result_table.definition)) result_table = self.test_table.drop_columns(cols=column_names[-1]) - self.assertEqual(1, len(self.test_table.columns) - len(result_table.columns)) + self.assertEqual(1, len(self.test_table.definition) - len(result_table.definition)) def test_move_columns(self): - column_names = [f.name for f in self.test_table.columns] + column_names = self.test_table.column_names cols_to_move = column_names[::2] with self.subTest("move-columns"): result_table = self.test_table.move_columns(1, cols_to_move) - result_cols = [f.name for f in result_table.columns] + result_cols = result_table.column_names self.assertEqual(cols_to_move, result_cols[1: len(cols_to_move) + 1]) with self.subTest("move-columns-up"): result_table = self.test_table.move_columns_up(cols_to_move) - result_cols = [f.name for f in result_table.columns] + result_cols = result_table.column_names self.assertEqual(cols_to_move, result_cols[: len(cols_to_move)]) with self.subTest("move-columns-down"): result_table = self.test_table.move_columns_down(cols_to_move) - result_cols = [f.name for f in 
result_table.columns] + result_cols = result_table.column_names self.assertEqual(cols_to_move, result_cols[-len(cols_to_move):]) cols_to_move = column_names[-1] with self.subTest("move-column"): result_table = self.test_table.move_columns(1, cols_to_move) - result_cols = [f.name for f in result_table.columns] + result_cols = result_table.column_names self.assertEqual([cols_to_move], result_cols[1: len(cols_to_move) + 1]) with self.subTest("move-column-up"): result_table = self.test_table.move_columns_up(cols_to_move) - result_cols = [f.name for f in result_table.columns] + result_cols = result_table.column_names self.assertEqual([cols_to_move], result_cols[: len(cols_to_move)]) with self.subTest("move-column-down"): result_table = self.test_table.move_columns_down(cols_to_move) - result_cols = [f.name for f in result_table.columns] + result_cols = result_table.column_names self.assertEqual([cols_to_move], result_cols[-len(cols_to_move):]) def test_rename_columns(self): @@ -147,10 +157,10 @@ def test_rename_columns(self): ] new_names = [cn.split("=")[0].strip() for cn in cols_to_rename] result_table = self.test_table.rename_columns(cols_to_rename) - result_cols = [f.name for f in result_table.columns] + result_cols = result_table.column_names self.assertEqual(new_names, result_cols[::2]) result_table = self.test_table.rename_columns(cols_to_rename[0]) - result_cols = [f.name for f in result_table.columns] + result_cols = result_table.column_names self.assertEqual(new_names[0], result_cols[::2][0]) def test_update_error(self): @@ -174,14 +184,14 @@ def test_USV(self): result_table = op( self.test_table, formulas=["a", "c", "Sum = a + b + c + d"]) self.assertIsNotNone(result_table) - self.assertTrue(len(result_table.columns) >= 3) + self.assertTrue(len(result_table.definition) >= 3) self.assertLessEqual(result_table.size, self.test_table.size) for op in ops: with self.subTest(op=op): result_table = op(self.test_table, formulas="Sum = a + b + c + d") self.assertIsNotNone(result_table) - self.assertTrue(len(result_table.columns) >= 1) + self.assertTrue(len(result_table.definition) >= 1) self.assertLessEqual(result_table.size, self.test_table.size) def test_select_distinct(self): @@ -430,10 +440,10 @@ def test_dedicated_agg(self): for wop in wops: with self.subTest(wop): result_table = wop(self.test_table, wcol='e', by=["a", "b"]) - self.assertEqual(len(result_table.columns), len(self.test_table.columns) - 1) + self.assertEqual(len(result_table.definition), len(self.test_table.definition) - 1) result_table = wop(self.test_table, wcol='e') - self.assertEqual(len(result_table.columns), len(self.test_table.columns) - 1) + self.assertEqual(len(result_table.definition), len(self.test_table.definition) - 1) def test_count_by(self): num_distinct_a = self.test_table.select_distinct(formulas=["a"]).size @@ -530,26 +540,26 @@ def test_snapshot_when(self): snapshot = self.test_table.snapshot_when(t) self.wait_ticking_table_update(snapshot, row_count=1, timeout=5) self.assertEqual(self.test_table.size, snapshot.size) - self.assertEqual(len(t.columns) + len(self.test_table.columns), len(snapshot.columns)) + self.assertEqual(len(t.definition) + len(self.test_table.definition), len(snapshot.definition)) with self.subTest("initial=True"): snapshot = self.test_table.snapshot_when(t, initial=True) self.assertEqual(self.test_table.size, snapshot.size) - self.assertEqual(len(t.columns) + len(self.test_table.columns), len(snapshot.columns)) + self.assertEqual(len(t.definition) + len(self.test_table.definition), 
len(snapshot.definition)) with self.subTest("stamp_cols=\"X\""): snapshot = self.test_table.snapshot_when(t, stamp_cols="X") - self.assertEqual(len(snapshot.columns), len(self.test_table.columns) + 1) + self.assertEqual(len(snapshot.definition), len(self.test_table.definition) + 1) with self.subTest("stamp_cols=[\"X\", \"Y\"]"): snapshot = self.test_table.snapshot_when(t, stamp_cols=["X", "Y"]) - self.assertEqual(len(snapshot.columns), len(self.test_table.columns) + 2) + self.assertEqual(len(snapshot.definition), len(self.test_table.definition) + 2) def test_snapshot_when_with_history(self): t = time_table("PT00:00:01") snapshot_hist = self.test_table.snapshot_when(t, history=True) self.wait_ticking_table_update(snapshot_hist, row_count=1, timeout=5) - self.assertEqual(1 + len(self.test_table.columns), len(snapshot_hist.columns)) + self.assertEqual(1 + len(self.test_table.definition), len(snapshot_hist.definition)) self.assertEqual(self.test_table.size, snapshot_hist.size) t = time_table("PT0.1S").update("X = i % 2 == 0 ? i : i - 1").sort("X").tail(10) @@ -1020,7 +1030,7 @@ def test_range_join(self): right_table = self.test_table.select_distinct().sort("b").drop_columns("e") result_table = left_table.range_join(right_table, on=["a = a", "c < b < e"], aggs=aggs) self.assertEqual(result_table.size, left_table.size) - self.assertEqual(len(result_table.columns), len(left_table.columns) + len(aggs)) + self.assertEqual(len(result_table.definition), len(left_table.definition) + len(aggs)) with self.assertRaises(DHError): time_table("PT00:00:00.001").update("a = i").range_join(right_table, on=["a = a", "a < b < c"], aggs=aggs) diff --git a/py/server/tests/test_table_definition.py b/py/server/tests/test_table_definition.py new file mode 100644 index 00000000000..d4d2cd34f86 --- /dev/null +++ b/py/server/tests/test_table_definition.py @@ -0,0 +1,271 @@ +# +# Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending +# +import unittest +from typing import Mapping +from deephaven import dtypes, new_table, DHError +from deephaven.table import TableDefinition +from deephaven.column import col_def, string_col, bool_col +from tests.testbase import BaseTestCase + + +class TableDefinitionTestCase(BaseTestCase): + def setUp(self): + super().setUp() + self.test_definition = TableDefinition( + { + "Bool": dtypes.bool_, + "Char": dtypes.char, + "Short": dtypes.short, + "Int": dtypes.int32, + "Long": dtypes.int64, + "Float": dtypes.float32, + "Double": dtypes.float64, + "String": dtypes.string, + "Instant": dtypes.Instant, + } + ) + + def tearDown(self) -> None: + self.test_definition = None + super().tearDown() + + def test_is_mapping(self): + self.assertTrue(isinstance(self.test_definition, Mapping)) + + def test_length(self): + self.assertEquals(9, len(self.test_definition)) + + def test_contains(self): + self.assertTrue("Bool" in self.test_definition) + self.assertTrue("Char" in self.test_definition) + self.assertTrue("Short" in self.test_definition) + self.assertTrue("Int" in self.test_definition) + self.assertTrue("Long" in self.test_definition) + self.assertTrue("Float" in self.test_definition) + self.assertTrue("Double" in self.test_definition) + self.assertTrue("String" in self.test_definition) + self.assertTrue("Instant" in self.test_definition) + self.assertFalse("FooBarBaz" in self.test_definition) + + def test_getitem(self): + self.assertEquals(col_def("Bool", dtypes.bool_), self.test_definition["Bool"]) + self.assertEquals(col_def("Char", dtypes.char), self.test_definition["Char"]) + 
self.assertEquals(col_def("Short", dtypes.short), self.test_definition["Short"]) + self.assertEquals(col_def("Int", dtypes.int32), self.test_definition["Int"]) + self.assertEquals(col_def("Long", dtypes.int64), self.test_definition["Long"]) + self.assertEquals( + col_def("Float", dtypes.float32), self.test_definition["Float"] + ) + self.assertEquals( + col_def("Double", dtypes.float64), + self.test_definition["Double"], + ) + self.assertEquals( + col_def("String", dtypes.string), self.test_definition["String"] + ) + self.assertEquals( + col_def("Instant", dtypes.Instant), + self.test_definition["Instant"], + ) + with self.assertRaises(KeyError): + self.test_definition["FooBarBaz"] + + def test_get(self): + self.assertEquals( + col_def("Bool", dtypes.bool_), self.test_definition.get("Bool") + ) + self.assertEquals( + col_def("Char", dtypes.char), self.test_definition.get("Char") + ) + self.assertEquals( + col_def("Short", dtypes.short), + self.test_definition.get("Short"), + ) + self.assertEquals(col_def("Int", dtypes.int32), self.test_definition.get("Int")) + self.assertEquals( + col_def("Long", dtypes.int64), self.test_definition.get("Long") + ) + self.assertEquals( + col_def("Float", dtypes.float32), + self.test_definition.get("Float"), + ) + self.assertEquals( + col_def("Double", dtypes.float64), + self.test_definition.get("Double"), + ) + self.assertEquals( + col_def("String", dtypes.string), + self.test_definition.get("String"), + ) + self.assertEquals( + col_def("Instant", dtypes.Instant), + self.test_definition.get("Instant"), + ) + self.assertEquals(None, self.test_definition.get("FooBarBaz")) + + def test_iter(self): + self.assertEquals( + [ + "Bool", + "Char", + "Short", + "Int", + "Long", + "Float", + "Double", + "String", + "Instant", + ], + list(iter(self.test_definition)), + ) + + def test_keys(self): + self.assertEquals( + [ + "Bool", + "Char", + "Short", + "Int", + "Long", + "Float", + "Double", + "String", + "Instant", + ], + list(self.test_definition.keys()), + ) + + def test_values(self): + self.assertEquals( + [ + col_def("Bool", dtypes.bool_), + col_def("Char", dtypes.char), + col_def("Short", dtypes.short), + col_def("Int", dtypes.int32), + col_def("Long", dtypes.int64), + col_def("Float", dtypes.float32), + col_def("Double", dtypes.float64), + col_def("String", dtypes.string), + col_def("Instant", dtypes.Instant), + ], + list(self.test_definition.values()), + ) + + def test_items(self): + self.assertEquals( + [ + ("Bool", col_def("Bool", dtypes.bool_)), + ("Char", col_def("Char", dtypes.char)), + ("Short", col_def("Short", dtypes.short)), + ("Int", col_def("Int", dtypes.int32)), + ("Long", col_def("Long", dtypes.int64)), + ("Float", col_def("Float", dtypes.float32)), + ("Double", col_def("Double", dtypes.float64)), + ("String", col_def("String", dtypes.string)), + ("Instant", col_def("Instant", dtypes.Instant)), + ], + list(self.test_definition.items()), + ) + + def test_equals_hash_and_from_columns(self): + expected_hash = hash(self.test_definition) + for actual in [ + # should be equal to the same exact object + self.test_definition, + # should be equal to a new python object, but same underlying java object + TableDefinition(self.test_definition), + # should be equal to a new python object and new underlying java object + TableDefinition(self.test_definition.values()), + ]: + self.assertEquals(actual, self.test_definition) + self.assertEquals(hash(actual), expected_hash) + + def test_meta_table(self): + expected = new_table( + [ + string_col( + "Name", + [ + 
"Bool", + "Char", + "Short", + "Int", + "Long", + "Float", + "Double", + "String", + "Instant", + ], + ), + string_col( + "DataType", + [ + "java.lang.Boolean", + "char", + "short", + "int", + "long", + "float", + "double", + "java.lang.String", + "java.time.Instant", + ], + ), + string_col("ColumnType", ["Normal"] * 9), + bool_col("IsPartitioning", [False] * 9), + ] + ) + + self.assert_table_equals(self.test_definition.table, expected) + + def test_from_TableDefinition(self): + self.assertEquals(TableDefinition(self.test_definition), self.test_definition) + + def test_from_JpyJType(self): + self.assertEquals( + TableDefinition(self.test_definition.j_table_definition), + self.test_definition, + ) + + def test_from_Mapping(self): + # This case is already tested, it's how self.test_definition is created + pass + + def test_from_Iterable(self): + self.assertEquals( + TableDefinition(self.test_definition.values()), self.test_definition + ) + self.assertEquals( + TableDefinition(list(self.test_definition.values())), self.test_definition + ) + + def test_from_unexpected_type(self): + with self.assertRaises(DHError): + TableDefinition(42) + + def test_bad_Mapping_key(self): + with self.assertRaises(DHError): + TableDefinition( + { + "Foo": dtypes.int32, + 42: dtypes.string, + } + ) + + def test_bad_Mapping_value(self): + with self.assertRaises(DHError): + TableDefinition( + { + "Foo": dtypes.int32, + "Bar": 42, + } + ) + + def test_bad_Iterable(self): + with self.assertRaises(DHError): + TableDefinition([col_def("Foo", dtypes.int32), 42]) + + +if __name__ == "__main__": + unittest.main() diff --git a/py/server/tests/test_table_factory.py b/py/server/tests/test_table_factory.py index 3b1cfe55062..fe66ba992ef 100644 --- a/py/server/tests/test_table_factory.py +++ b/py/server/tests/test_table_factory.py @@ -41,7 +41,7 @@ def tearDown(self) -> None: def test_empty_table(self): t = empty_table(10) - self.assertEqual(0, len(t.columns)) + self.assertEqual(0, len(t.definition)) def test_empty_table_error(self): with self.assertRaises(DHError) as cm: @@ -52,22 +52,22 @@ def test_empty_table_error(self): def test_time_table(self): t = time_table("PT00:00:01") - self.assertEqual(1, len(t.columns)) + self.assertEqual(1, len(t.definition)) self.assertTrue(t.is_refreshing) t = time_table("PT00:00:01", start_time="2021-11-06T13:21:00 ET") - self.assertEqual(1, len(t.columns)) + self.assertEqual(1, len(t.definition)) self.assertTrue(t.is_refreshing) self.assertEqual("2021-11-06T13:21:00.000000000 ET", _JDateTimeUtils.formatDateTime(t.j_table.getColumnSource("Timestamp").get(0), time.to_j_time_zone('ET'))) t = time_table(1000_000_000) - self.assertEqual(1, len(t.columns)) + self.assertEqual(1, len(t.definition)) self.assertTrue(t.is_refreshing) t = time_table(1000_1000_1000, start_time="2021-11-06T13:21:00 ET") - self.assertEqual(1, len(t.columns)) + self.assertEqual(1, len(t.definition)) self.assertTrue(t.is_refreshing) self.assertEqual("2021-11-06T13:21:00.000000000 ET", _JDateTimeUtils.formatDateTime(t.j_table.getColumnSource("Timestamp").get(0), @@ -75,12 +75,12 @@ def test_time_table(self): p = time.to_timedelta(time.to_j_duration("PT1s")) t = time_table(p) - self.assertEqual(1, len(t.columns)) + self.assertEqual(1, len(t.definition)) self.assertTrue(t.is_refreshing) st = time.to_datetime(time.to_j_instant("2021-11-06T13:21:00 ET")) t = time_table(p, start_time=st) - self.assertEqual(1, len(t.columns)) + self.assertEqual(1, len(t.definition)) self.assertTrue(t.is_refreshing) 
self.assertEqual("2021-11-06T13:21:00.000000000 ET", _JDateTimeUtils.formatDateTime(t.j_table.getColumnSource("Timestamp").get(0), @@ -88,7 +88,7 @@ def test_time_table(self): def test_time_table_blink(self): t = time_table("PT1s", blink_table=True) - self.assertEqual(1, len(t.columns)) + self.assertEqual(1, len(t.definition)) self.assertTrue(t.is_blink) def test_time_table_error(self): @@ -325,19 +325,18 @@ def test_input_table(self): ] t = new_table(cols=cols) self.assertEqual(t.size, 2) - col_defs = {c.name: c.data_type for c in t.columns} with self.subTest("from table definition"): - append_only_input_table = input_table(col_defs=col_defs) + append_only_input_table = input_table(col_defs=t.definition) self.assertEqual(append_only_input_table.key_names, []) - self.assertEqual(append_only_input_table.value_names, [col.name for col in cols]) + self.assertEqual(append_only_input_table.value_names, [col._column_definition.name for col in cols]) append_only_input_table.add(t) self.assertEqual(append_only_input_table.size, 2) append_only_input_table.add(t) self.assertEqual(append_only_input_table.size, 4) - keyed_input_table = input_table(col_defs=col_defs, key_cols="String") + keyed_input_table = input_table(col_defs=t.definition, key_cols="String") self.assertEqual(keyed_input_table.key_names, ["String"]) - self.assertEqual(keyed_input_table.value_names, [col.name for col in cols if col.name != "String"]) + self.assertEqual(keyed_input_table.value_names, [col._column_definition.name for col in cols if col._column_definition.name != "String"]) keyed_input_table.add(t) self.assertEqual(keyed_input_table.size, 2) keyed_input_table.add(t) @@ -346,14 +345,14 @@ def test_input_table(self): with self.subTest("from init table"): append_only_input_table = input_table(init_table=t) self.assertEqual(append_only_input_table.key_names, []) - self.assertEqual(append_only_input_table.value_names, [col.name for col in cols]) + self.assertEqual(append_only_input_table.value_names, [col._column_definition.name for col in cols]) self.assertEqual(append_only_input_table.size, 2) append_only_input_table.add(t) self.assertEqual(append_only_input_table.size, 4) keyed_input_table = input_table(init_table=t, key_cols="String") self.assertEqual(keyed_input_table.key_names, ["String"]) - self.assertEqual(keyed_input_table.value_names, [col.name for col in cols if col.name != "String"]) + self.assertEqual(keyed_input_table.value_names, [col._column_definition.name for col in cols if col._column_definition.name != "String"]) self.assertEqual(keyed_input_table.size, 2) keyed_input_table.add(t) self.assertEqual(keyed_input_table.size, 2) @@ -368,7 +367,7 @@ def test_input_table(self): keyed_input_table = input_table(init_table=t, key_cols=["String", "Double"]) self.assertEqual(keyed_input_table.key_names, ["String", "Double"]) - self.assertEqual(keyed_input_table.value_names, [col.name for col in cols if col.name != "String" and col.name != "Double"]) + self.assertEqual(keyed_input_table.value_names, [col._column_definition.name for col in cols if col._column_definition.name != "String" and col._column_definition.name != "Double"]) self.assertEqual(keyed_input_table.size, 2) keyed_input_table.delete(t.select(["String", "Double"])) self.assertEqual(keyed_input_table.size, 0) @@ -449,7 +448,7 @@ def test_input_table_empty_data(self): with cm: t = time_table("PT1s", blink_table=True) - it = input_table({c.name: c.data_type for c in t.columns}, key_cols="Timestamp") + it = input_table(t.definition, key_cols="Timestamp") 
it.add(t) self.assertEqual(it.size, 0) it.delete(t) @@ -467,8 +466,7 @@ def test_j_input_wrapping(self): string_col(name="String", data=["foo", "bar"]), ] t = new_table(cols=cols) - col_defs = {c.name: c.data_type for c in t.columns} - append_only_input_table = input_table(col_defs=col_defs) + append_only_input_table = input_table(col_defs=t.definition) it = _wrapper.wrap_j_object(append_only_input_table.j_table) self.assertTrue(isinstance(it, InputTable)) diff --git a/py/server/tests/test_table_iterator.py b/py/server/tests/test_table_iterator.py index 465ba913453..0bb617fa6f5 100644 --- a/py/server/tests/test_table_iterator.py +++ b/py/server/tests/test_table_iterator.py @@ -22,7 +22,7 @@ def test_iteration_in_chunks(self): test_table = read_csv("tests/data/test_table.csv") total_read_size = 0 for d in test_table.iter_chunk_dict(chunk_size=10): - self.assertEqual(len(d), len(test_table.columns)) + self.assertEqual(len(d), len(test_table.definition)) for col in test_table.columns: self.assertIn(col.name, d) self.assertEqual(d[col.name].dtype, col.data_type.np_type) @@ -36,7 +36,7 @@ def test_iteration_in_chunks(self): test_table.await_update() total_read_size = 0 for d in test_table.iter_chunk_dict(chunk_size=100): - self.assertEqual(len(d), len(test_table.columns)) + self.assertEqual(len(d), len(test_table.definition)) for col in test_table.columns: self.assertIn(col.name, d) self.assertEqual(d[col.name].dtype, col.data_type.np_type) @@ -65,7 +65,7 @@ def test_iteration_in_rows(self): test_table = read_csv("tests/data/test_table.csv") total_read_size = 0 for d in test_table.iter_dict(): - self.assertEqual(len(d), len(test_table.columns)) + self.assertEqual(len(d), len(test_table.definition)) for col in test_table.columns: self.assertIn(col.name, d) self.assertTrue(np.can_cast(col.data_type.np_type, np.dtype(type(d[col.name])))) @@ -77,7 +77,7 @@ def test_iteration_in_rows(self): test_table.await_update() total_read_size = 0 for d in test_table.iter_dict(): - self.assertEqual(len(d), len(test_table.columns)) + self.assertEqual(len(d), len(test_table.definition)) for col in test_table.columns: self.assertIn(col.name, d) v_type = type(d[col.name]) @@ -108,7 +108,7 @@ def test_direct_call_chunks(self): test_table = read_csv("tests/data/test_table.csv") t_iter = test_table.iter_chunk_dict(chunk_size=10) for d in t_iter: - self.assertEqual(len(d), len(test_table.columns)) + self.assertEqual(len(d), len(test_table.definition)) for col in test_table.columns: self.assertIn(col.name, d) self.assertEqual(d[col.name].dtype, col.data_type.np_type) @@ -159,7 +159,7 @@ def test_direct_call_rows(self): test_table = read_csv("tests/data/test_table.csv") t_iter = test_table.iter_dict() for d in t_iter: - self.assertEqual(len(d), len(test_table.columns)) + self.assertEqual(len(d), len(test_table.definition)) for col in test_table.columns: self.assertIn(col.name, d) self.assertTrue(np.can_cast(col.data_type.np_type, np.dtype(type(d[col.name])))) @@ -232,7 +232,7 @@ class CustomClass: with self.subTest("Chunks"): for d in test_table.iter_chunk_dict(chunk_size=10): - self.assertEqual(len(d), len(test_table.columns)) + self.assertEqual(len(d), len(test_table.definition)) for col in test_table.columns: self.assertIn(col.name, d) self.assertEqual(dtypes.from_np_dtype(d[col.name].dtype).np_type, col.data_type.np_type) @@ -240,7 +240,7 @@ class CustomClass: with self.subTest("Rows"): for d in test_table.iter_dict(): - self.assertEqual(len(d), len(test_table.columns)) + self.assertEqual(len(d), 
len(test_table.definition)) for col in test_table.columns: self.assertIn(col.name, d) v_type = type(d[col.name]) @@ -258,7 +258,7 @@ def test_iteration_in_chunks_tuple(self): test_table = read_csv("tests/data/test_table.csv") total_read_size = 0 for d in test_table.iter_chunk_tuple(chunk_size=10): - self.assertEqual(len(d), len(test_table.columns)) + self.assertEqual(len(d), len(test_table.definition)) for i, col in enumerate(test_table.columns): self.assertEqual(col.name, d._fields[i]) self.assertEqual(d[i].dtype, col.data_type.np_type) @@ -272,7 +272,7 @@ def test_iteration_in_chunks_tuple(self): test_table.await_update() total_read_size = 0 for d in test_table.iter_chunk_tuple(chunk_size=100): - self.assertEqual(len(d), len(test_table.columns)) + self.assertEqual(len(d), len(test_table.definition)) for i, col in enumerate(test_table.columns): self.assertEqual(col.name, d._fields[i]) self.assertEqual(d[i].dtype, col.data_type.np_type) @@ -301,7 +301,7 @@ def test_iteration_in_rows_tuple(self): test_table = read_csv("tests/data/test_table.csv") total_read_size = 0 for d in test_table.iter_tuple(): - self.assertEqual(len(d), len(test_table.columns)) + self.assertEqual(len(d), len(test_table.definition)) for i, col in enumerate(test_table.columns): self.assertEqual(col.name, d._fields[i]) self.assertTrue(np.can_cast(col.data_type.np_type, np.dtype(type(d[i])))) @@ -313,7 +313,7 @@ def test_iteration_in_rows_tuple(self): test_table.await_update() total_read_size = 0 for d in test_table.iter_tuple(): - self.assertEqual(len(d), len(test_table.columns)) + self.assertEqual(len(d), len(test_table.definition)) for i, col in enumerate(test_table.columns): self.assertEqual(col.name, d._fields[i]) v_type = type(d[i]) diff --git a/py/server/tests/test_table_listener.py b/py/server/tests/test_table_listener.py index db570b77414..48915a9c277 100644 --- a/py/server/tests/test_table_listener.py +++ b/py/server/tests/test_table_listener.py @@ -104,7 +104,7 @@ def verify_data_changes(self, changes, cols: Union[str, List[str]]): for change in changes: self.assertTrue(isinstance(change, dict)) if not cols: - cols = [col.name for col in self.test_table.columns] + cols = self.test_table.column_names for col in cols: self.assertIn(col, change.keys()) self.assertTrue(isinstance(change[col], numpy.ndarray)) @@ -274,8 +274,7 @@ def test_listener_func_with_deps(self): ] t = new_table(cols=cols) self.assertEqual(t.size, 2) - col_defs = {c.name: c.data_type for c in t.columns} - dep_table = input_table(col_defs=col_defs) + dep_table = input_table(col_defs=t.definition) def listener_func(update, is_replay): table_update_recorder.record(update, is_replay) diff --git a/py/server/tests/test_updateby.py b/py/server/tests/test_updateby.py index e4ecbc2aae2..e58ce542539 100644 --- a/py/server/tests/test_updateby.py +++ b/py/server/tests/test_updateby.py @@ -177,7 +177,7 @@ def test_em(self): for t in (self.static_table, self.ticking_table): rt = t.update_by(ops=op, by="b") self.assertTrue(rt.is_refreshing is t.is_refreshing) - self.assertEqual(len(rt.columns), 1 + len(t.columns)) + self.assertEqual(len(rt.definition), 1 + len(t.definition)) with update_graph.exclusive_lock(self.test_update_graph): self.assertEqual(rt.size, t.size) @@ -192,7 +192,7 @@ def test_em_proxy(self): rt_proxy = pt_proxy.update_by(op, by="e") for ct, rct in zip(pt_proxy.target.constituent_tables, rt_proxy.target.constituent_tables): self.assertTrue(rct.is_refreshing is ct.is_refreshing) - self.assertEqual(len(rct.columns), 1 + len(ct.columns)) + 
self.assertEqual(len(rct.definition), 1 + len(ct.definition)) with update_graph.exclusive_lock(self.test_update_graph): self.assertEqual(ct.size, rct.size) @@ -202,7 +202,7 @@ def test_simple_ops(self): for t in (self.static_table, self.ticking_table): rt = t.update_by(ops=op, by="e") self.assertTrue(rt.is_refreshing is t.is_refreshing) - self.assertEqual(len(rt.columns), 2 + len(t.columns)) + self.assertEqual(len(rt.definition), 2 + len(t.definition)) with update_graph.exclusive_lock(self.test_update_graph): self.assertEqual(rt.size, t.size) @@ -230,7 +230,7 @@ def test_rolling_ops(self): for t in (self.static_table, self.ticking_table): rt = t.update_by(ops=op, by="c") self.assertTrue(rt.is_refreshing is t.is_refreshing) - self.assertEqual(len(rt.columns), 2 + len(t.columns)) + self.assertEqual(len(rt.definition), 2 + len(t.definition)) with update_graph.exclusive_lock(self.test_update_graph): self.assertEqual(rt.size, t.size) @@ -245,7 +245,7 @@ def test_rolling_ops_proxy(self): rt_proxy = pt_proxy.update_by(op, by="c") for ct, rct in zip(pt_proxy.target.constituent_tables, rt_proxy.target.constituent_tables): self.assertTrue(rct.is_refreshing is ct.is_refreshing) - self.assertEqual(len(rct.columns), 2 + len(ct.columns)) + self.assertEqual(len(rct.definition), 2 + len(ct.definition)) with update_graph.exclusive_lock(self.test_update_graph): self.assertEqual(ct.size, rct.size) @@ -260,7 +260,7 @@ def test_multiple_ops(self): for t in (self.static_table, self.ticking_table): rt = t.update_by(ops=multiple_ops, by="c") self.assertTrue(rt.is_refreshing is t.is_refreshing) - self.assertEqual(len(rt.columns), 10 + len(t.columns)) + self.assertEqual(len(rt.definition), 10 + len(t.definition)) with update_graph.exclusive_lock(self.test_update_graph): self.assertEqual(rt.size, t.size) diff --git a/py/server/tests/test_vectorization.py b/py/server/tests/test_vectorization.py index ebac32aff93..ab227d02cc5 100644 --- a/py/server/tests/test_vectorization.py +++ b/py/server/tests/test_vectorization.py @@ -234,7 +234,7 @@ def my_sum(*args): source = new_table([int_col(c, [0, 1, 2, 3, 4, 5, 6]) for c in cols]) result = source.update(f"X = my_sum({','.join(cols)})") - self.assertEqual(len(cols) + 1, len(result.columns)) + self.assertEqual(len(cols) + 1, len(result.definition)) self.assertEqual(_udf.vectorized_count, 0) def test_enclosed_by_parentheses(self):