Skip to content

Commit

Permalink
Add more tests and error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
jmao-denver committed Jul 10, 2024
1 parent 8cb33ca commit ef536b7
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import org.jpy.PyObject;

import java.util.Arrays;
import java.util.Objects;
import java.util.stream.Stream;

/**
* A Deephaven merged listener which fires when any of its bound listener recorders has updates and all of its
Expand Down Expand Up @@ -49,15 +51,25 @@ public static PythonMergedListenerAdapter create(
@Nullable NotificationQueue.Dependency[] dependencies,
@Nullable String listenerDescription,
@NotNull PyObject pyObjectIn) {
final UpdateGraph updateGraph = ExecutionContext.getContext().getUpdateGraph();

if (!Arrays.stream(recorders).allMatch(t -> t.getParent().getUpdateGraph() == updateGraph)) {
throw new IllegalArgumentException("All recorders must be from the same update graph");
if (recorders.length < 2) {
throw new IllegalArgumentException("At least two recorders must be provided");
}
// TODO: Uncomment this check if confirmed that the alternative way of checking is better
// final UpdateGraph updateGraph = ExecutionContext.getContext().getUpdateGraph();
// if (!Arrays.stream(recorders).allMatch(t -> t.getParent().getUpdateGraph() == updateGraph)) {
// throw new IllegalArgumentException("All recorders must be from the same update graph");
// }
//
// if (!Arrays.stream(dependencies).allMatch(t -> t.getUpdateGraph() == updateGraph)) {
// throw new IllegalArgumentException("All dependencies must be from the same update graph");
// }

if (!Arrays.stream(dependencies).allMatch(t -> t.getUpdateGraph() == updateGraph)) {
throw new IllegalArgumentException("All dependencies must be from the same update graph");
}
final NotificationQueue.Dependency[] allItems =
Stream.concat(Arrays.stream(recorders), Arrays.stream(dependencies))
.filter(Objects::nonNull)
.toArray(NotificationQueue.Dependency[]::new);

final UpdateGraph updateGraph = allItems[0].getUpdateGraph(allItems);

try (final SafeCloseable ignored = ExecutionContext.getContext().withUpdateGraph(updateGraph).open()) {
return new PythonMergedListenerAdapter(recorders, dependencies, listenerDescription, pyObjectIn);
Expand Down
23 changes: 14 additions & 9 deletions py/server/deephaven/table_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,9 @@ def __init__(self, listener_recorders: Sequence[ListenerRecorder], listener: Uni
the listener is safe, it is not recommended because reading or operating on the result tables of those
operations may not be safe. It is best to perform the operations on the dependent tables beforehand,
and then add the result tables as dependencies to the listener so that they can be safely read in it.
Raises:
DHError
"""
if len(listener_recorders) < 2:
raise DHError(message="MergedListener must have at least two listener recorders.")
Expand All @@ -550,20 +553,23 @@ def __init__(self, listener_recorders: Sequence[ListenerRecorder], listener: Uni

if isinstance(listener, MergedListener):
listener.listener_recorders = listener_recorders
self.merged_listener_adapter = _JPythonMergedListenerAdapter.create(
to_sequence(self.listener_recorders),
to_sequence(self.dependencies),
description,
listener)
self.started = False

try:
self.merged_listener_adapter = _JPythonMergedListenerAdapter.create(
to_sequence(self.listener_recorders),
to_sequence(self.dependencies),
description,
listener)
self.started = False
except Exception as e:
raise DHError(e, "failed to create a merged listener adapter.") from e


def start(self) -> None:
"""Start the listener."""
if self.started:
raise RuntimeError("Attempting to start an already started listener..")
raise RuntimeError("Attempting to start an already started merged listener..")

# TODO - move to the Java side?
with update_graph.shared_lock(self.listener_recorders[0].table.update_graph):
for lr in self.listener_recorders:
lr.table.j_table.addUpdateListener(lr.j_listener_recorder)
Expand All @@ -574,7 +580,6 @@ def stop(self) -> None:
if not self.started:
return

# TODO - move to the Java side?
with update_graph.shared_lock(self.listener_recorders[0].table.update_graph):
for lr in self.listener_recorders:
lr.table.j_table.removeUpdateListener(lr.j_listener_recorder)
Expand Down
42 changes: 38 additions & 4 deletions py/server/tests/test_table_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy
import jpy

from deephaven import time_table, new_table, input_table, DHError
from deephaven import time_table, new_table, input_table, DHError, empty_table
from deephaven.column import bool_col, string_col
from deephaven.experimental import time_window
from deephaven.jcompat import to_sequence
Expand Down Expand Up @@ -361,6 +361,10 @@ def process(self) -> None:
mlh.stop()
self.assertGreaterEqual(len(tur.replays), 6)

with self.subTest("Error input"):
et = empty_table(1)
with self.assertRaises(DHError):
mlh = MergedListenerHandle([ListenerRecorder(t) for t in [t1, t2, t3, et]], tml)

def test_merged_listener_func(self):
t1 = time_table("PT1s").update(["X=i % 11"])
Expand Down Expand Up @@ -394,11 +398,41 @@ def test_ml_func() -> None:
mlh.stop()
self.assertGreaterEqual(len(tur.replays), 6)

with self.subTest("Error input"):
et = empty_table(1)
with self.assertRaises(DHError):
mlh = merged_listen([ListenerRecorder(t) for t in [t1, t2, t3, et]], test_ml_func)

def test_merged_listener_with_deps(self):
...
t1 = time_table("PT1s").update(["X=i % 11"])
t2 = time_table("PT2s").update(["Y=i % 8"])
t3 = time_table("PT3s").update(["Z=i % 5"])

def test_merged_listener_with_deps_error(self):
...
dep_table = time_table("PT00:00:05").update("X = i % 11")
ec = get_exec_ctx()

tur = TableUpdateRecorder()
j_arrays = []
class TestMergedListener(MergedListener):
def process(self) -> None:
for i, listener in enumerate(self.listener_recorders):
if self.listener_recorders[i].table_update():
tur.record(self.listener_recorders[i].table_update())

with ec:
t = dep_table.view(["Y = i % 8"])
j_arrays.append(_JColumnVectors.of(t.j_table, "Y").copyToArray())

tml = TestMergedListener()
mlh = MergedListenerHandle(listener_recorders=[ListenerRecorder(t) for t in [t1, t2, t3]], listener=tml, dependencies=dep_table)
mlh.start()
ensure_ugp_cycles(tur, cycles=3)
mlh.stop()
mlh.start()
ensure_ugp_cycles(tur, cycles=6)
mlh.stop()
self.assertGreaterEqual(len(tur.replays), 6)
self.assertTrue(len(j_arrays) > 0 and all([len(ja) > 0 for ja in j_arrays]))


if __name__ == "__main__":
Expand Down

0 comments on commit ef536b7

Please sign in to comment.