Skip to content

Commit

Permalink
Apply query scope ctx on formula aggregation (#4839)
Browse files Browse the repository at this point in the history
* Apply query scope ctx on formula aggregation

* Add more tests
  • Loading branch information
jmao-denver authored Nov 17, 2023
1 parent 6b1df15 commit 953d0f2
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 22 deletions.
4 changes: 4 additions & 0 deletions py/server/deephaven/agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def j_agg_spec(self):
raise DHError(message="unsupported aggregation operation.")
return self._j_agg_spec

@property
def is_formula(self):
return isinstance(self._j_agg_spec, jpy.get_type("io.deephaven.api.agg.spec.AggSpecFormula"))


def sum_(cols: Union[str, List[str]] = None) -> Aggregation:
"""Creates a Sum aggregation.
Expand Down
66 changes: 46 additions & 20 deletions py/server/deephaven/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,15 @@ def _query_scope_ctx():
yield


def _query_scope_agg_ctx(aggs: Sequence[Aggregation]) -> contextlib.AbstractContextManager:
has_agg_formula = any([agg.is_formula for agg in aggs])
if has_agg_formula:
cm = _query_scope_ctx()
else:
cm = contextlib.nullcontext()
return cm


class SortDirection(Enum):
"""An enum defining the sorting orders."""
DESCENDING = auto()
Expand Down Expand Up @@ -1961,13 +1970,16 @@ def agg_by(self, aggs: Union[Aggregation, Sequence[Aggregation]], by: Union[str,
if not by and initial_groups:
raise ValueError("missing group-by column names when initial_groups is provided.")
j_agg_list = j_array_list([agg.j_aggregation for agg in aggs])
if not by:
return Table(j_table=self.j_table.aggBy(j_agg_list, preserve_empty))
else:
j_column_name_list = j_array_list([_JColumnName.of(col) for col in by])
initial_groups = unwrap(initial_groups)
return Table(
j_table=self.j_table.aggBy(j_agg_list, preserve_empty, initial_groups, j_column_name_list))

cm = _query_scope_agg_ctx(aggs)
with cm:
if not by:
return Table(j_table=self.j_table.aggBy(j_agg_list, preserve_empty))
else:
j_column_name_list = j_array_list([_JColumnName.of(col) for col in by])
initial_groups = unwrap(initial_groups)
return Table(
j_table=self.j_table.aggBy(j_agg_list, preserve_empty, initial_groups, j_column_name_list))
except Exception as e:
raise DHError(e, "table agg_by operation failed.") from e

Expand Down Expand Up @@ -2004,8 +2016,11 @@ def partitioned_agg_by(self, aggs: Union[Aggregation, Sequence[Aggregation]],
by = to_sequence(by)
j_agg_list = j_array_list([agg.j_aggregation for agg in aggs])
initial_groups = unwrap(initial_groups)
return PartitionedTable(
j_partitioned_table=self.j_table.partitionedAggBy(j_agg_list, preserve_empty, initial_groups, *by))

cm = _query_scope_agg_ctx(aggs)
with cm:
return PartitionedTable(
j_partitioned_table=self.j_table.partitionedAggBy(j_agg_list, preserve_empty, initial_groups, *by))
except Exception as e:
raise DHError(e, "table partitioned_agg_by operation failed.") from e

Expand All @@ -2028,7 +2043,9 @@ def agg_all_by(self, agg: Aggregation, by: Union[str, Sequence[str]] = None) ->
"""
try:
by = to_sequence(by)
return Table(j_table=self.j_table.aggAllBy(agg.j_agg_spec, *by))
cm = _query_scope_agg_ctx([agg])
with cm:
return Table(j_table=self.j_table.aggAllBy(agg.j_agg_spec, *by))
except Exception as e:
raise DHError(e, "table agg_all_by operation failed.") from e

Expand Down Expand Up @@ -2276,12 +2293,15 @@ def rollup(self, aggs: Union[Aggregation, Sequence[Aggregation]], by: Union[str,
aggs = to_sequence(aggs)
by = to_sequence(by)
j_agg_list = j_array_list([agg.j_aggregation for agg in aggs])
if not by:
return RollupTable(j_rollup_table=self.j_table.rollup(j_agg_list, include_constituents), aggs=aggs,
include_constituents=include_constituents, by=by)
else:
return RollupTable(j_rollup_table=self.j_table.rollup(j_agg_list, include_constituents, by),
aggs=aggs, include_constituents=include_constituents, by=by)

cm = _query_scope_agg_ctx(aggs)
with cm:
if not by:
return RollupTable(j_rollup_table=self.j_table.rollup(j_agg_list, include_constituents), aggs=aggs,
include_constituents=include_constituents, by=by)
else:
return RollupTable(j_rollup_table=self.j_table.rollup(j_agg_list, include_constituents, by),
aggs=aggs, include_constituents=include_constituents, by=by)
except Exception as e:
raise DHError(e, "table rollup operation failed.") from e

Expand Down Expand Up @@ -3299,8 +3319,11 @@ def agg_by(self, aggs: Union[Aggregation, Sequence[Aggregation]],
aggs = to_sequence(aggs)
by = to_sequence(by)
j_agg_list = j_array_list([agg.j_aggregation for agg in aggs])
with auto_locking_ctx(self):
return PartitionedTableProxy(j_pt_proxy=self.j_pt_proxy.aggBy(j_agg_list, *by))

cm = _query_scope_agg_ctx(aggs)
with cm:
with auto_locking_ctx(self):
return PartitionedTableProxy(j_pt_proxy=self.j_pt_proxy.aggBy(j_agg_list, *by))
except Exception as e:
raise DHError(e, "agg_by operation on the PartitionedTableProxy failed.") from e

Expand All @@ -3324,8 +3347,11 @@ def agg_all_by(self, agg: Aggregation, by: Union[str, Sequence[str]] = None) ->
"""
try:
by = to_sequence(by)
with auto_locking_ctx(self):
return PartitionedTableProxy(j_pt_proxy=self.j_pt_proxy.aggAllBy(agg.j_agg_spec, *by))

cm = _query_scope_agg_ctx([agg])
with cm:
with auto_locking_ctx(self):
return PartitionedTableProxy(j_pt_proxy=self.j_pt_proxy.aggAllBy(agg.j_agg_spec, *by))
except Exception as e:
raise DHError(e, "agg_all_by operation on the PartitionedTableProxy failed.") from e

Expand Down
18 changes: 18 additions & 0 deletions py/server/tests/test_pt_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,24 @@ def local_fn() -> str:

self.assertIsNotNone(inner_func("param str"))

@unittest.skip("https://github.com/deephaven/deephaven-core/issues/4847")
def test_agg_formula_scope(self):
with self.subTest("agg_by_formula"):
def agg_by_formula():
def my_fn(vals):
import deephaven.dtypes as dht
return dht.array(dht.double, [i + 2 for i in vals])

t = empty_table(1000).update_view(["A=i%2", "B=A+3"])
pt_proxy = t.partition_by("A").proxy()
rlt_pt_proxy = pt_proxy.agg_by([formula("(double[])my_fn(each)", formula_param='each', cols=['C=B']),
median("B")],
by='A')
return rlt_pt_proxy

ptp = agg_by_formula()
self.assertIsNotNone(ptp)


def global_fn() -> str:
return "global str"
Expand Down
46 changes: 44 additions & 2 deletions py/server/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,7 @@ def make_pairs_3(tid, a, b):

def test_callable_attrs_in_query(self):
input_cols = [
datetime_col(name="DTCol", data=[1,10000000]),
datetime_col(name="DTCol", data=[1, 10000000]),
]
test_table = new_table(cols=input_cols)
rt = test_table.update("Year = (int)year(DTCol, timeZone(`ET`))")
Expand Down Expand Up @@ -1025,7 +1025,7 @@ def test_agg_with_options(self):
unique(cols=["ua = a", "ub = b"], include_nulls=True, non_unique_sentinel=-1),
count_distinct(cols=["csa = a", "csb = b"], count_nulls=True),
distinct(cols=["da = a", "db = b"], include_nulls=True),
]
]
rt = test_table.agg_by(aggs=aggs, by=["c"])
self.assertEqual(rt.size, test_table.select_distinct(["c"]).size)

Expand Down Expand Up @@ -1063,6 +1063,48 @@ def test_agg_count_and_partition_error(self):
t.agg_by(aggs=[partition(["A"])], by=["B"])
self.assertIn("string value", str(cm.exception))

def test_agg_formula_scope(self):
with self.subTest("agg_by_formula"):
def agg_by_formula():
def my_fn(vals):
import deephaven.dtypes as dht
return dht.array(dht.double, [i + 2 for i in vals])

t = empty_table(1000).update_view(["A=i%2", "B=A+3"])
t = t.agg_by([formula("(double[])my_fn(each)", formula_param='each', cols=['C=B']), median("B")],
by='A')
return t

t = agg_by_formula()
self.assertIsNotNone(t)

with self.subTest("agg_all_by_formula"):
def agg_all_by_formula():
def my_fn(vals):
import deephaven.dtypes as dht
return dht.array(dht.double, [i + 2 for i in vals])

t = empty_table(1000).update_view(["A=i%2", "B=A+3"])
t = t.agg_all_by(formula("(double[])my_fn(each)", formula_param='each', cols=['C=B']), by='A')
return t

t = agg_all_by_formula()
self.assertIsNotNone(t)

with self.subTest("partitioned_by_formula"):
def partitioned_by_formula():
def my_fn(vals):
import deephaven.dtypes as dht
return dht.array(dht.double, [i + 2 for i in vals])

t = empty_table(10).update(["grp_id=(int)(i/5)", "var=(int)i", "weights=(double)1.0/(i+1)"])
t = t.partitioned_agg_by(aggs=formula("(double[])my_fn(each)", formula_param='each',
cols=['C=weights']), by="grp_id")
return t

t = partitioned_by_formula()
self.assertIsNotNone(t)


if __name__ == "__main__":
unittest.main()

0 comments on commit 953d0f2

Please sign in to comment.