Skip to content

Commit

Permalink
Validate arg type for sort/partition_by (#5652)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmao-denver committed Jun 22, 2024
1 parent c735716 commit ffec8b7
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 8 deletions.
22 changes: 15 additions & 7 deletions py/server/deephaven/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,9 +1102,12 @@ def sort(self, order_by: Union[str, Sequence[str]],
order_by = to_sequence(order_by)
if not order:
order = (SortDirection.ASCENDING,) * len(order_by)
order = to_sequence(order)
if len(order_by) != len(order):
raise DHError(message="The number of sort columns must be the same as the number of sort directions.")
else:
order = to_sequence(order)
if any([o not in (SortDirection.ASCENDING, SortDirection.DESCENDING) for o in order]):
raise DHError(message="The sort direction must be either 'ASCENDING' or 'DESCENDING'.")
if len(order_by) != len(order):
raise DHError(message="The number of sort columns must be the same as the number of sort directions.")

sort_columns = [_sort_column(col, dir_) for col, dir_ in zip(order_by, order)]
j_sc_list = j_array_list(sort_columns)
Expand Down Expand Up @@ -2008,6 +2011,9 @@ def partition_by(self, by: Union[str, Sequence[str]], drop_keys: bool = False) -
DHError
"""
try:
if not isinstance(drop_keys, bool):
raise DHError(message="drop_keys must be a boolean value.")

by = to_sequence(by)
return PartitionedTable(j_partitioned_table=self.j_table.partitionBy(drop_keys, *by))
except Exception as e:
Expand Down Expand Up @@ -2737,12 +2743,14 @@ def sort(self, order_by: Union[str, Sequence[str]],
DHError
"""
try:
order_by = to_sequence(order_by)
if not order:
order = (SortDirection.ASCENDING,) * len(order_by)
order = to_sequence(order)
if len(order_by) != len(order):
raise ValueError("The number of sort columns must be the same as the number of sort directions.")
else:
order = to_sequence(order)
if any([o not in (SortDirection.ASCENDING, SortDirection.DESCENDING) for o in order]):
raise DHError(message="The sort direction must be either 'ASCENDING' or 'DESCENDING'.")
if len(order_by) != len(order):
raise DHError(message="The number of sort columns must be the same as the number of sort directions.")

sort_columns = [_sort_column(col, dir_) for col, dir_ in zip(order_by, order)]
j_sc_list = j_array_list(sort_columns)
Expand Down
2 changes: 1 addition & 1 deletion py/server/tests/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ def verify_table_from_disk(table):
self.assertTrue(len(table.columns))
self.assertTrue(table.columns[0].name == "X")
self.assertTrue(table.columns[0].column_type == ColumnType.PARTITIONING)
self.assert_table_equals(table.select().sort("X", "Y"), source.sort("X", "Y"))
self.assert_table_equals(table.select().sort(["X", "Y"]), source.sort(["X", "Y"]))

def verify_file_names():
partition_dir_path = os.path.join(root_dir, 'X=Aa')
Expand Down
10 changes: 10 additions & 0 deletions py/server/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,16 @@ def my_fn(vals):
t = partitioned_by_formula()
self.assertIsNotNone(t)

def test_arg_validation(self):
t = empty_table(1).update(["A=i", "B=i", "C=i"])
with self.assertRaises(DHError) as cm:
t.sort("A", "B")
self.assertIn("The sort direction must be", str(cm.exception))

with self.assertRaises(DHError) as cm:
t.partition_by("A", "B")
self.assertIn("drop_keys must be", str(cm.exception))


if __name__ == "__main__":
unittest.main()

0 comments on commit ffec8b7

Please sign in to comment.