RFC: tracing variables within JointDistributions #1842

Open · chrism0dwk opened this issue Sep 25, 2024 · 0 comments

chrism0dwk (Contributor) commented Sep 25, 2024

Background

Sometimes it would be useful to trace intermediate values computed as part of a JointDistribution* model. The current solution is to use tfd.Deterministic. As an example, suppose we would like to trace the mean of a simple linear regression on a single feature; a user might write:

import tensorflow_probability as tfp

tfd = tfp.distributions

feature = 0.5  # observed covariate

@tfd.JointDistributionCoroutine
def model():
    intercept = yield tfd.Normal(loc=0.0, scale=1.0, name="intercept")
    slope = yield tfd.Normal(loc=0.0, scale=1.0, name="slope")

    # The deterministic intermediate must itself be yielded to appear in samples
    mean = yield tfd.Deterministic(intercept + slope * feature, name="mean")

    yield tfd.Normal(loc=mean, scale=1.0, name="response")

A call to model.sample() will return a named tuple including the value of mean, conditional on intercept and slope (and feature). However, if we wish to compute the log probability density of the model given response, intercept, and slope, we also have to pass a value for mean into model.log_prob. This value of mean must be consistent with intercept and slope, which requires the user to duplicate the expression for mean outside the model object, e.g.

intercept = 0.1
slope = 0.2
feature = 0.5
response = 0.21

# Since `mean` is deterministic, we should not have to re-compute it outside of `model`
mean = intercept + slope * feature
lp = model.log_prob(intercept=intercept, slope=slope, mean=mean, response=response)

This is not only wasteful in terms of keystrokes, but also error-prone if the model changes.
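(For completeness, the duplication can be reduced today by factoring the shared expression into a helper function, though this only hides the problem rather than solving it. The sketch below assumes the same toy model as above; linear_mean is a hypothetical helper, not part of TFP.)

# A partial workaround (sketch only): keep the deterministic expression in one
# place so the model and the log_prob call cannot drift apart.
def linear_mean(intercept, slope, feature):
    return intercept + slope * feature

@tfd.JointDistributionCoroutine
def model():
    intercept = yield tfd.Normal(loc=0.0, scale=1.0, name="intercept")
    slope = yield tfd.Normal(loc=0.0, scale=1.0, name="slope")
    mean = yield tfd.Deterministic(linear_mean(intercept, slope, feature), name="mean")
    yield tfd.Normal(loc=mean, scale=1.0, name="response")

lp = model.log_prob(intercept=intercept, slope=slope,
                    mean=linear_mean(intercept, slope, feature),
                    response=response)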

Suggested solution

A potential solution would be to include a class similar to JointDistribution.Root, called JointDistribution.Trace, which would flag an expression for tracing in the forward generating process (i.e. model.sample()) but exclude the associated variable from the CDF/CMF- and PDF/PMF-related methods. Thus we could write:

@tfd.JointDistributionCoroutine
def model():
    intercept = yield tfd.Normal(loc=0.0, scale=1.0, name="intercept")
    slope = yield tfd.Normal(loc=0.0, scale=1.0, name="slope")

    # Proposed marker: recorded by sample(), ignored by log_prob() and friends
    mean = yield Trace(intercept + slope * feature, name="mean")

    yield tfd.Normal(loc=mean, scale=1.0, name="response")

draw = model.sample(seed=[0, 0])

# `mean` is simply ignored
model.log_prob(draw)

# `mean` does not have to be supplied
model.log_prob(intercept=draw.intercept, slope=draw.slope, response=draw.response)
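For what it's worth, the marker itself could presumably be as lightweight as the existing JointDistributionCoroutine.Root wrapper; the sketch below is purely illustrative (the Trace class does not exist in TFP), and the real work would be in teaching the sample() and log_prob() machinery to recognise it:

import collections

# Purely illustrative: a Trace marker mirroring the shape of
# JointDistributionCoroutine.Root, but carrying a value and a name rather
# than a distribution. How sample()/log_prob() consume it is the open
# design question in this RFC.
class Trace(collections.namedtuple("Trace", ["value", "name"])):
    """Marks an intermediate value: record it in sample(), skip it in log_prob()."""
    __slots__ = ()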

Does this seem like a feasible addition? (I may have some resource to devote to it)
