From 953d0f2e6748eafc6b99d668a3ef51f26f346a16 Mon Sep 17 00:00:00 2001 From: Jianfeng Mao <4297243+jmao-denver@users.noreply.github.com> Date: Thu, 16 Nov 2023 20:02:14 -0700 Subject: [PATCH] Apply query scope ctx on formula aggregation (#4839) * Apply query scope ctx on formula aggregation * Add more tests --- py/server/deephaven/agg.py | 4 ++ py/server/deephaven/table.py | 66 ++++++++++++++++++++++---------- py/server/tests/test_pt_proxy.py | 18 +++++++++ py/server/tests/test_table.py | 46 +++++++++++++++++++++- 4 files changed, 112 insertions(+), 22 deletions(-) diff --git a/py/server/deephaven/agg.py b/py/server/deephaven/agg.py index 7e4c4293f73..09e985b5330 100644 --- a/py/server/deephaven/agg.py +++ b/py/server/deephaven/agg.py @@ -44,6 +44,10 @@ def j_agg_spec(self): raise DHError(message="unsupported aggregation operation.") return self._j_agg_spec + @property + def is_formula(self): + return isinstance(self._j_agg_spec, jpy.get_type("io.deephaven.api.agg.spec.AggSpecFormula")) + def sum_(cols: Union[str, List[str]] = None) -> Aggregation: """Creates a Sum aggregation. diff --git a/py/server/deephaven/table.py b/py/server/deephaven/table.py index 1a55c58f823..e46348667b9 100644 --- a/py/server/deephaven/table.py +++ b/py/server/deephaven/table.py @@ -563,6 +563,15 @@ def _query_scope_ctx(): yield +def _query_scope_agg_ctx(aggs: Sequence[Aggregation]) -> contextlib.AbstractContextManager: + has_agg_formula = any([agg.is_formula for agg in aggs]) + if has_agg_formula: + cm = _query_scope_ctx() + else: + cm = contextlib.nullcontext() + return cm + + class SortDirection(Enum): """An enum defining the sorting orders.""" DESCENDING = auto() @@ -1961,13 +1970,16 @@ def agg_by(self, aggs: Union[Aggregation, Sequence[Aggregation]], by: Union[str, if not by and initial_groups: raise ValueError("missing group-by column names when initial_groups is provided.") j_agg_list = j_array_list([agg.j_aggregation for agg in aggs]) - if not by: - return Table(j_table=self.j_table.aggBy(j_agg_list, preserve_empty)) - else: - j_column_name_list = j_array_list([_JColumnName.of(col) for col in by]) - initial_groups = unwrap(initial_groups) - return Table( - j_table=self.j_table.aggBy(j_agg_list, preserve_empty, initial_groups, j_column_name_list)) + + cm = _query_scope_agg_ctx(aggs) + with cm: + if not by: + return Table(j_table=self.j_table.aggBy(j_agg_list, preserve_empty)) + else: + j_column_name_list = j_array_list([_JColumnName.of(col) for col in by]) + initial_groups = unwrap(initial_groups) + return Table( + j_table=self.j_table.aggBy(j_agg_list, preserve_empty, initial_groups, j_column_name_list)) except Exception as e: raise DHError(e, "table agg_by operation failed.") from e @@ -2004,8 +2016,11 @@ def partitioned_agg_by(self, aggs: Union[Aggregation, Sequence[Aggregation]], by = to_sequence(by) j_agg_list = j_array_list([agg.j_aggregation for agg in aggs]) initial_groups = unwrap(initial_groups) - return PartitionedTable( - j_partitioned_table=self.j_table.partitionedAggBy(j_agg_list, preserve_empty, initial_groups, *by)) + + cm = _query_scope_agg_ctx(aggs) + with cm: + return PartitionedTable( + j_partitioned_table=self.j_table.partitionedAggBy(j_agg_list, preserve_empty, initial_groups, *by)) except Exception as e: raise DHError(e, "table partitioned_agg_by operation failed.") from e @@ -2028,7 +2043,9 @@ def agg_all_by(self, agg: Aggregation, by: Union[str, Sequence[str]] = None) -> """ try: by = to_sequence(by) - return Table(j_table=self.j_table.aggAllBy(agg.j_agg_spec, *by)) + cm = _query_scope_agg_ctx([agg]) + with cm: + return Table(j_table=self.j_table.aggAllBy(agg.j_agg_spec, *by)) except Exception as e: raise DHError(e, "table agg_all_by operation failed.") from e @@ -2276,12 +2293,15 @@ def rollup(self, aggs: Union[Aggregation, Sequence[Aggregation]], by: Union[str, aggs = to_sequence(aggs) by = to_sequence(by) j_agg_list = j_array_list([agg.j_aggregation for agg in aggs]) - if not by: - return RollupTable(j_rollup_table=self.j_table.rollup(j_agg_list, include_constituents), aggs=aggs, - include_constituents=include_constituents, by=by) - else: - return RollupTable(j_rollup_table=self.j_table.rollup(j_agg_list, include_constituents, by), - aggs=aggs, include_constituents=include_constituents, by=by) + + cm = _query_scope_agg_ctx(aggs) + with cm: + if not by: + return RollupTable(j_rollup_table=self.j_table.rollup(j_agg_list, include_constituents), aggs=aggs, + include_constituents=include_constituents, by=by) + else: + return RollupTable(j_rollup_table=self.j_table.rollup(j_agg_list, include_constituents, by), + aggs=aggs, include_constituents=include_constituents, by=by) except Exception as e: raise DHError(e, "table rollup operation failed.") from e @@ -3299,8 +3319,11 @@ def agg_by(self, aggs: Union[Aggregation, Sequence[Aggregation]], aggs = to_sequence(aggs) by = to_sequence(by) j_agg_list = j_array_list([agg.j_aggregation for agg in aggs]) - with auto_locking_ctx(self): - return PartitionedTableProxy(j_pt_proxy=self.j_pt_proxy.aggBy(j_agg_list, *by)) + + cm = _query_scope_agg_ctx(aggs) + with cm: + with auto_locking_ctx(self): + return PartitionedTableProxy(j_pt_proxy=self.j_pt_proxy.aggBy(j_agg_list, *by)) except Exception as e: raise DHError(e, "agg_by operation on the PartitionedTableProxy failed.") from e @@ -3324,8 +3347,11 @@ def agg_all_by(self, agg: Aggregation, by: Union[str, Sequence[str]] = None) -> """ try: by = to_sequence(by) - with auto_locking_ctx(self): - return PartitionedTableProxy(j_pt_proxy=self.j_pt_proxy.aggAllBy(agg.j_agg_spec, *by)) + + cm = _query_scope_agg_ctx([agg]) + with cm: + with auto_locking_ctx(self): + return PartitionedTableProxy(j_pt_proxy=self.j_pt_proxy.aggAllBy(agg.j_agg_spec, *by)) except Exception as e: raise DHError(e, "agg_all_by operation on the PartitionedTableProxy failed.") from e diff --git a/py/server/tests/test_pt_proxy.py b/py/server/tests/test_pt_proxy.py index fcb802f1a76..99c7414aa55 100644 --- a/py/server/tests/test_pt_proxy.py +++ b/py/server/tests/test_pt_proxy.py @@ -345,6 +345,24 @@ def local_fn() -> str: self.assertIsNotNone(inner_func("param str")) + @unittest.skip("https://github.com/deephaven/deephaven-core/issues/4847") + def test_agg_formula_scope(self): + with self.subTest("agg_by_formula"): + def agg_by_formula(): + def my_fn(vals): + import deephaven.dtypes as dht + return dht.array(dht.double, [i + 2 for i in vals]) + + t = empty_table(1000).update_view(["A=i%2", "B=A+3"]) + pt_proxy = t.partition_by("A").proxy() + rlt_pt_proxy = pt_proxy.agg_by([formula("(double[])my_fn(each)", formula_param='each', cols=['C=B']), + median("B")], + by='A') + return rlt_pt_proxy + + ptp = agg_by_formula() + self.assertIsNotNone(ptp) + def global_fn() -> str: return "global str" diff --git a/py/server/tests/test_table.py b/py/server/tests/test_table.py index 816c4e74f99..334d9415a3c 100644 --- a/py/server/tests/test_table.py +++ b/py/server/tests/test_table.py @@ -943,7 +943,7 @@ def make_pairs_3(tid, a, b): def test_callable_attrs_in_query(self): input_cols = [ - datetime_col(name="DTCol", data=[1,10000000]), + datetime_col(name="DTCol", data=[1, 10000000]), ] test_table = new_table(cols=input_cols) rt = test_table.update("Year = (int)year(DTCol, timeZone(`ET`))") @@ -1025,7 +1025,7 @@ def test_agg_with_options(self): unique(cols=["ua = a", "ub = b"], include_nulls=True, non_unique_sentinel=-1), count_distinct(cols=["csa = a", "csb = b"], count_nulls=True), distinct(cols=["da = a", "db = b"], include_nulls=True), - ] + ] rt = test_table.agg_by(aggs=aggs, by=["c"]) self.assertEqual(rt.size, test_table.select_distinct(["c"]).size) @@ -1063,6 +1063,48 @@ def test_agg_count_and_partition_error(self): t.agg_by(aggs=[partition(["A"])], by=["B"]) self.assertIn("string value", str(cm.exception)) + def test_agg_formula_scope(self): + with self.subTest("agg_by_formula"): + def agg_by_formula(): + def my_fn(vals): + import deephaven.dtypes as dht + return dht.array(dht.double, [i + 2 for i in vals]) + + t = empty_table(1000).update_view(["A=i%2", "B=A+3"]) + t = t.agg_by([formula("(double[])my_fn(each)", formula_param='each', cols=['C=B']), median("B")], + by='A') + return t + + t = agg_by_formula() + self.assertIsNotNone(t) + + with self.subTest("agg_all_by_formula"): + def agg_all_by_formula(): + def my_fn(vals): + import deephaven.dtypes as dht + return dht.array(dht.double, [i + 2 for i in vals]) + + t = empty_table(1000).update_view(["A=i%2", "B=A+3"]) + t = t.agg_all_by(formula("(double[])my_fn(each)", formula_param='each', cols=['C=B']), by='A') + return t + + t = agg_all_by_formula() + self.assertIsNotNone(t) + + with self.subTest("partitioned_by_formula"): + def partitioned_by_formula(): + def my_fn(vals): + import deephaven.dtypes as dht + return dht.array(dht.double, [i + 2 for i in vals]) + + t = empty_table(10).update(["grp_id=(int)(i/5)", "var=(int)i", "weights=(double)1.0/(i+1)"]) + t = t.partitioned_agg_by(aggs=formula("(double[])my_fn(each)", formula_param='each', + cols=['C=weights']), by="grp_id") + return t + + t = partitioned_by_formula() + self.assertIsNotNone(t) + if __name__ == "__main__": unittest.main()