From ee1abf5637ce8e5b746d9c1e451eceb9b98a9e17 Mon Sep 17 00:00:00 2001 From: Pol Febrer Date: Tue, 4 Jun 2024 11:33:08 +0200 Subject: [PATCH] Make operations update inputs again --- pyproject.toml | 2 +- src/nodify/node.py | 10 +++++----- src/nodify/syntax_nodes.py | 8 ++++++++ src/nodify/tests/test_node.py | 15 ++++++++++++++- src/nodify/tests/test_workflow.py | 10 ++++++++++ 5 files changed, 38 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c116ac5..29bdcbb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ name = "nodify" description = "Supercharge your functional application with a powerful node system." readme = "README.md" license = {file = "LICENSE"} -version = "0.0.10" +version = "0.0.11" dependencies = [] diff --git a/src/nodify/node.py b/src/nodify/node.py index 77b4e1f..da10851 100644 --- a/src/nodify/node.py +++ b/src/nodify/node.py @@ -654,15 +654,15 @@ def recursive_update_inputs( def _update(node): if cls is None or isinstance(self, cls): - node.update_inputs(**inputs) + sig = node.__signature__ update_inputs = {} - # Update the inputs of the node - for k in self.inputs: - if k in inputs: + for k in inputs: + if k in sig.parameters: update_inputs[k] = inputs[k] - self.update_inputs(**update_inputs) + if len(update_inputs) > 0: + node.update_inputs(**update_inputs) traverse_tree_backward([self], _update) diff --git a/src/nodify/syntax_nodes.py b/src/nodify/syntax_nodes.py index ceac0f7..ebe2053 100644 --- a/src/nodify/syntax_nodes.py +++ b/src/nodify/syntax_nodes.py @@ -202,6 +202,10 @@ class BinaryOperationNode(Node): "or": "|", } + def __call__(self, *args, **kwargs): + self.recursive_update_inputs(*args, **kwargs) + return self.get() + @staticmethod def function(left: Any, op: _BynaryOp, right: Any): return getattr(operator, op)(left, right) @@ -222,6 +226,10 @@ class UnaryOperationNode(Node): "pos": "+", } + def __call__(self, *args, **kwargs): + self.recursive_update_inputs(*args, **kwargs) + return self.get() + @staticmethod def function(op: _UnaryOp, operand: Any): return getattr(operator, op)(operand) diff --git a/src/nodify/tests/test_node.py b/src/nodify/tests/test_node.py index c6a4a4d..a0f745f 100644 --- a/src/nodify/tests/test_node.py +++ b/src/nodify/tests/test_node.py @@ -338,7 +338,7 @@ def my_node(**some_kwargs): assert node2._input_nodes["some_kwargs[a]"] is node3 -def test_ufunc(sum_node): +def test_op(sum_node): node = sum_node(1, 3) assert node.get() == 4 @@ -346,3 +346,16 @@ def test_ufunc(sum_node): node2 = node + 6 assert node2.get() == 10 + + +def test_op_update(sum_node): + + node = sum_node(1) + + result = node + 6 + + assert result(input2=3) == 10 + + result = abs(node) + + assert result(input2=-3) == 2 diff --git a/src/nodify/tests/test_workflow.py b/src/nodify/tests/test_workflow.py index a79a0e8..ec001ab 100644 --- a/src/nodify/tests/test_workflow.py +++ b/src/nodify/tests/test_workflow.py @@ -67,6 +67,16 @@ def triple_sum(a: int, b: int, c: int) -> int: return first_sum + c triple_sum._sum_key = "BinaryOperationNode" + elif request.param == "node_tree": + ... + + # def triple_sum(a, b, c): + # first_sum = a + b + # return first_sum + c + + # node = triple_sum(ConstantNode(2), ConstantNode(3), ConstantNode(5)) + + # triple_sum = Workflow.from_node_tree(node, ) return triple_sum