Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement easier access to and manipulation of var inputs #202

Merged
merged 21 commits into from
Oct 18, 2024

Conversation

jobrachem
Copy link
Contributor

@jobrachem jobrachem commented Aug 17, 2024

This PR contains only a few lines of changed code, but I think it can drastically improve quality of life while building and manipulating Liesel models.

Problem statement

I have found the work with variable inputs to be often quite cumbersome. I'll show what I mean in examples.

import liesel.model as lsl
import tensorflow_probability.substrates.jax.distributions as tfd
import jax.numpy as jnp

Example 1: Accessing parameters of a distribution

Let's say I want to access the loc and scale of a variable's distribution, starting from the variable.

def create_var1():
    loc = lsl.Var(0.0, name="loc")
    scale = lsl.Var(1.0, name="scale")
    y = lsl.obs(jnp.zeros(10), lsl.Dist(tfd.Normal, loc=loc, scale=scale), name="y")
    return y

y1 = create_var1()

The Var.dist_node's kwinputs attribute does not give me that access. Instead, it returns a dict of VarValue nodes. That is node helpful.

>>> y1.dist_node.kwinputs
mappingproxy({'loc': VarValue(name="loc_var_value"),
              'scale': VarValue(name="scale_var_value")})

To actually get to the loc var, I have to call:

>>> y1.dist_node.kwinputs["loc"].var
Var(name="loc")

I think this is unnecessarily cumbersome. I have to actually know a lot about Liesel's internals and/or dig in the source code to find what I need. Every time I need to perform this action, I have to look up what to do and extensively test my code in order to be sure that I get it right. It is quite annoying.

Example 2: Accessing inputs of a calculator

A very similar pattern holds when you use calculators:

def create_var2():
    x = lsl.Var(0.0, name="x")
    y = lsl.Var(lsl.Calc(lambda x: x + 1, x), name="y")
    return y

y2 = create_var2()
>>> y2.value_node.inputs
(VarValue(name="x_var_value"),)
>>> y2.value_node.inputs[0].var
Var(name="x")

Proposed Solution

I implemented Node.__getitem__ and Var.__getitem__ as a remedy.

The above tasks can now be solved like this:

>>> y1.dist_node["loc"] # access by arg name, here equivalent to y1.dist_node[0]
Var(name="loc")
>>> y2[0] # access by index, equivalent to y2.value_node[0]
Var(name="x")

Some details

  1. The basic implementation is done in Node.__getitem__.
    a) If it receives an integer, it will essentially look up the searched item in Node.all_input_nodes(). This will find all inputs, including positional and keyword inputs.
    b) If it receives a string, it will essentially look up the searched item in Node.kwinputs. This will of course find only inputs that are actually keyword inputs.
  2. Var.__getitem__ will defer to its value node. To access inputs to the distribution, users can turn to Var.dist_node.__getitem__.

Setitem

The implementation also provides the possibility to replace inputs via Node.__setitem__. Example:

# before
>>> y2[0]
Var(name="x")
>>> y2.value
1.0

# change input
>>> y2[0] = lsl.Var(3.0, name="new_input")

# after
>>> y2[0]
Var(name="new_input")
>>> y2.value
4.0

The same works for Dist. The implementation is a thin quality-of-life wrapper around the existing Node.set_inputs() method.

Documentation

This is a draft PR for a first discussion. Even if it remains unchanged, documentation has to be added if it ends up being merged.

@jobrachem jobrachem self-assigned this Aug 26, 2024
@jobrachem
Copy link
Contributor Author

Discussion Notes

  • There's doubt on whether lsl.Var.__getitem__ should be able to look up keyword arguments as well as positional arguments when using an integer.
  • There's doubt on whether lsl.Var.__getitem__ should be implemented - a case can be made for just sticking to lsl.Var.value_node.__getitem__.

@jobrachem
Copy link
Contributor Author

jobrachem commented Sep 10, 2024

With the current update I added the following changes:

  • Removed the indexing functionality from Var
  • Restricted integer indexing functionality to access only positional inputs.
  • Made the underlying helpers private

This means we now have these possibilities:

Accessing inputs

Access named value inputs to a calculator

a = Var(2.0, name="a")
y = Var(Calc(lambda x: x + 1.0, x=a))
y.value_node["x"]

Access positional value inputs to a calculator

a = Var(2.0, name="a")
y = Var(Calc(lambda x: x + 1.0, a))
y.value_node[0]

Access named value inputs to a dist

a = Var(2.0, name="a")
y = Var(1.0, Dist(tfp.distributions.Normal, loc=a, scale=1.0))
y.dist_node["loc"]

Access positional value inputs to a dist

a = Var(2.0, name="a")
y = Var(1.0, Dist(tfp.distributions.Normal, a, scale=1.0))
y.dist_node[0]

Swapping out inputs

Swap out inputs to a calculator

Note:

  • Works by integer indexing for positional arguments.
  • Can only be used for updating existing inputs
  • Relies on existing functionality in Node.set_inputs, but makes it easier to switch out one particular input. Node.set_inputs requires users to provide all positional and keyword inputs, even if only one input is supposed to be updated.
a = Var(2.0, name="a")
y = Var(Calc(lambda x: x + 1.0, x=a))

b = Var(3.0, name="b")
y.value_node["x"] = b

Swap out inputs to a dist

a = Var(2.0, name="a")
y = Var(1.0, Dist(tfp.distributions.Normal, loc=a, scale=1.0))

b = Var(3.0, name="b")
y.dist_node["loc"] = b

What would be left to do

If this PR moves forward, this is left to do:

  • Update documentation

@jobrachem
Copy link
Contributor Author

Cases to cover in tests:

  • What happens when using del?
  • What happens when trying to assign None?
  • What happens with variable-args-input calculator functions? e.g. sum(*args)
  • What happens when we have functions with default arguments and then add or remove an input node that refers to this argument?

@jobrachem
Copy link
Contributor Author

jobrachem commented Sep 11, 2024

What happens when using __del__?

loc = Var(0.0)
scale = Var(1.0)
x = Calc(lambda loc, scale: loc * scale, loc, scale)
del x[0]

tries to call

x.__delitem__(0)

which results in an AttributeError, since we do not implement Node.__delitem__. I would say this is expected behavior, because the intention is to provide users with a way to swap out inputs, not to remove inputs without replacement.

What happens when assigning None?

loc = Var(0.0)
scale = Var(1.0)
x = Calc(lambda loc, scale: loc * scale, loc, scale)
x[0] = None

is equivalent to

x[0] = Value(None)

it does not immediately result in an error, since we allow nodes to have a value of None. I think this is expected behavior.

What happens with variable-args-input calculator functions? e.g. sum(*args)

The number of inputs gets fixed during node initialization and cannot be changed later through Node.__setitem__. I think this is expected behavior. The equivalent is true for variable keyword-inputs functions.

def sum_(*args):
    return sum(args)

x = Calc(sum_, 1.0, 2.0)

assert x.value == pytest.approx(3.0)

x[0] = 2.0
x.update()
assert x.value == pytest.approx(4.0)

with pytest.raises(IndexError):
    x[2] = 3.0

What happens when we have functions with default arguments?

Similar to the above: The inputs get fixed during node initialization and cannot be changed later. I think this is expected behavior.

def sum_(a, b=3.0):
    return a + b

x = Calc(sum_, 1.0)
assert x.value == pytest.approx(4.0)

x[0] = 2.0
x.update()
assert x.value == pytest.approx(5.0)

with pytest.raises(IndexError):
    x[1] = 0.0

with pytest.raises(KeyError):
    x["b"] = 0.0

@wiep
Copy link
Contributor

wiep commented Sep 20, 2024

@jobrachem, should this already be reviewed or discussed next week?

@jobrachem
Copy link
Contributor Author

@wiep I need to add documentation before this can be reviewed.

Also, the test actions currently time out, which is not nice for the review process. The timeout seems to be unrelated to our code changes, but the amount of additional time that the tests seem to use is quite drastic.

@jobrachem
Copy link
Contributor Author

See #213 regarding the tests timing out.

@jobrachem jobrachem marked this pull request as ready for review September 25, 2024 09:25
@jobrachem
Copy link
Contributor Author

@wiep @GianmarcoCallegher the merge conflicts are resolved and the tests for circular graph behavior updated :) Ready for review!

@jobrachem
Copy link
Contributor Author

@wiep @GianmarcoCallegher gentle reminder :) Would be meaningful for me to have this soon, since I want to use it in the current semester in teaching.

Copy link
Contributor

@GianmarcoCallegher GianmarcoCallegher left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me. Just a little question

@GianmarcoCallegher
Copy link
Contributor

It's ready to be merged for me

Copy link
Contributor

@wiep wiep left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ready to be merged. Just two comments. One one wording in documentation and the second on more expressive names for tests. Feel free to ignore them if you disagree.

tests/model/test_node.py Outdated Show resolved Hide resolved
liesel/model/nodes.py Outdated Show resolved Hide resolved
@jobrachem jobrachem merged commit a467eec into main Oct 18, 2024
4 checks passed
@jobrachem jobrachem deleted the access_inputs branch October 18, 2024 07:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants