Skip to content

Commit

Permalink
Make operations update inputs again
Browse files Browse the repository at this point in the history
  • Loading branch information
pfebrer committed Jun 4, 2024
1 parent 1362083 commit ee1abf5
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ name = "nodify"
description = "Supercharge your functional application with a powerful node system."
readme = "README.md"
license = {file = "LICENSE"}
version = "0.0.10"
version = "0.0.11"

dependencies = []

Expand Down
10 changes: 5 additions & 5 deletions src/nodify/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,15 +654,15 @@ def recursive_update_inputs(

def _update(node):
if cls is None or isinstance(self, cls):
node.update_inputs(**inputs)

sig = node.__signature__
update_inputs = {}
# Update the inputs of the node
for k in self.inputs:
if k in inputs:
for k in inputs:
if k in sig.parameters:
update_inputs[k] = inputs[k]

self.update_inputs(**update_inputs)
if len(update_inputs) > 0:
node.update_inputs(**update_inputs)

traverse_tree_backward([self], _update)

Expand Down
8 changes: 8 additions & 0 deletions src/nodify/syntax_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ class BinaryOperationNode(Node):
"or": "|",
}

def __call__(self, *args, **kwargs):
self.recursive_update_inputs(*args, **kwargs)
return self.get()

@staticmethod
def function(left: Any, op: _BynaryOp, right: Any):
return getattr(operator, op)(left, right)
Expand All @@ -222,6 +226,10 @@ class UnaryOperationNode(Node):
"pos": "+",
}

def __call__(self, *args, **kwargs):
self.recursive_update_inputs(*args, **kwargs)
return self.get()

@staticmethod
def function(op: _UnaryOp, operand: Any):
return getattr(operator, op)(operand)
Expand Down
15 changes: 14 additions & 1 deletion src/nodify/tests/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,11 +338,24 @@ def my_node(**some_kwargs):
assert node2._input_nodes["some_kwargs[a]"] is node3


def test_ufunc(sum_node):
def test_op(sum_node):
node = sum_node(1, 3)

assert node.get() == 4

node2 = node + 6

assert node2.get() == 10


def test_op_update(sum_node):

node = sum_node(1)

result = node + 6

assert result(input2=3) == 10

result = abs(node)

assert result(input2=-3) == 2
10 changes: 10 additions & 0 deletions src/nodify/tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ def triple_sum(a: int, b: int, c: int) -> int:
return first_sum + c

triple_sum._sum_key = "BinaryOperationNode"
elif request.param == "node_tree":
...

# def triple_sum(a, b, c):
# first_sum = a + b
# return first_sum + c

# node = triple_sum(ConstantNode(2), ConstantNode(3), ConstantNode(5))

# triple_sum = Workflow.from_node_tree(node, )

return triple_sum

Expand Down

0 comments on commit ee1abf5

Please sign in to comment.