Support logp derivation of power exponent #6896

Closed
ricardoV94 opened this issue Sep 7, 2023 · 8 comments · Fixed by #6962

Comments

@ricardoV94
Member

- `power(const, x)`, for any const > 0 and any x
- `power(const, x)`, for any const and discrete x (we can play with `log(abs(neg_const))` and x's parity)

For the first case we don't have to constrain ourselves to actual "constants"; we can add a symbolic assert that const > 0.

The second case requires us to implement transforms for discrete variables, which would probably need #6360 first, so we can focus on the first case, which is also probably more useful anyway.

We just have to make sure not to accidentally rewrite stuff like `power(x, const)`, as those are implemented via our PowerTransform. This can be done by checking which of the inputs has a path to unvalued random variables.

Originally posted by @ricardoV94 in #6826 (comment)
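
For reference, the identity the first case relies on can be checked with plain NumPy (nothing PyMC-specific here):

```python
import numpy as np

# For const > 0, power(const, x) is just an exp/scale transform of x:
#   const ** x == exp(x * log(const))
const, x = 2.5, -1.3
assert np.isclose(const ** x, np.exp(x * np.log(const)))
```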

@LukeLB
Contributor

LukeLB commented Oct 12, 2023

@ricardoV94 I've got the first case working locally now, but I don't think it's exactly what you asked for, as I don't understand what you mean by

> we don't have to constrain ourselves to actual "constants"; we can add a symbolic assert that const > 0

Could you provide an example?

@ricardoV94
Member Author

ricardoV94 commented Oct 16, 2023

When you have a graph like:

```python
import pymc as pm
import pytensor.tensor as pt
from pymc.logprob.basic import conditional_logp

base_rv = pm.Poisson.dist(1)
x_raw_rv = pm.Normal.dist()
x_rv = pt.power(base_rv, x_raw_rv)

x_rv.name = "x"
base_rv.name = "base"
base_vv = base_rv.clone()
x_vv = x_rv.clone()

conditional_logp({base_rv: base_vv, x_rv: x_vv})
```

In that case base_rv is not a "constant", but x_rv is conditionally independent, and the logp should be fine as long as base_vv is never given a negative value. We can add a CheckParameterValue which will raise when this happens.
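
A minimal sketch of that guard, assuming `CheckParameterValue` is imported from `pymc.logprob.utils` (where exactly it gets attached in the rewrite is worked out further down the thread):

```python
import pytensor.tensor as pt
from pymc.logprob.utils import CheckParameterValue

# Wrap the value of the base so that evaluating the derived logp with a
# non-positive base raises instead of silently returning a wrong density.
checked_base = CheckParameterValue("base > 0")(base_vv, pt.all(pt.gt(base_vv, 0)))
```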

@LukeLB
Contributor

LukeLB commented Oct 16, 2023

Am I using this correctly?

```python
import pytensor.tensor as pt
from pytensor.graph.rewriting.basic import node_rewriter
from pymc.logprob.utils import CheckParameterValue


@node_rewriter([pow, CheckParameterValue])
def measurable_power_exponent_to_exp(fgraph, node):
    base, exponent = node.inputs

    # don't rewrite when the exponent is discrete
    if exponent.type.dtype.startswith("int"):
        return None

    return [pt.exp(pt.log(base) * exponent)]
```

Because while this works fine

```python
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pymc.logprob.basic import conditional_logp

base_rv = pm.Poisson.dist([1, 1])
x_raw_rv = pm.Normal.dist()
x_rv = pt.power(base_rv, x_raw_rv)

x_rv.name = "x"
base_rv.name = "base"
base_vv = base_rv.clone()
x_vv = x_rv.clone()

res = conditional_logp({base_rv: base_vv, x_rv: x_vv})
res_combined = pt.sum([factor for factor in res.values()])
logp_vals_fn = pytensor.function([base_vv, x_vv], res_combined)

logp_vals_fn(np.array([2, 2]), np.array([2, 2]))
```

>>> array(-6.87743995)

If I understand you properly, I would have thought this shouldn't work, since base_rv can take on negative values, but the logp does evaluate

```python
from pymc.logprob.basic import conditional_logp

base_rv = pm.Normal.dist([1, 1])
x_raw_rv = pm.Normal.dist()
x_rv = pt.power(base_rv, x_raw_rv)

x_rv.name = "x"
base_rv.name = "base"
base_vv = base_rv.clone()
x_vv = x_rv.clone()

res = conditional_logp({base_rv: base_vv, x_rv: x_vv})
res_combined = pt.sum([factor for factor in res.values()])
logp_vals_fn = pytensor.function([base_vv, x_vv], res_combined)

logp_vals_fn(np.array([2, 2]), np.array([2, 2]))
```

>>> array(-6.32902265)

@ricardoV94
Member Author

ricardoV94 commented Oct 16, 2023

The rewrite should only be used when the exponent (not the base) is being measured.

You can get this info from somewhere in fgraph.preserve_rv_mappings.

Otherwise you would be breaking logp inference for stuff like base_rv ** (-1) where the base is the thing we want to measure.
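
Roughly, the check meant by "has a path to unvalued random variables" could look like the sketch below; in the real rewrite the set of valued RVs would come from `fgraph.preserve_rv_mappings`, so the `valued_rvs` argument here is purely illustrative:

```python
from pytensor.graph.basic import ancestors
from pytensor.tensor.random.op import RandomVariable


def depends_on_unvalued_rv(var, valued_rvs):
    """Illustrative helper: does `var` have an unvalued RandomVariable among its ancestors?"""
    return any(
        a.owner is not None
        and isinstance(a.owner.op, RandomVariable)
        and a not in valued_rvs
        for a in ancestors([var])
    )


# Only rewrite pow(base, exponent) when the exponent is the measured input:
# if the *base* leads back to an unvalued RV (e.g. base_rv ** -1), leave the
# node alone so the existing PowerTransform logic can handle it.
```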

> If I understand you properly, I would have thought this shouldn't work, since base_rv can take on negative values, but the logp does evaluate

It's fine as long as it does something sensible when negative values are passed (hence my suggestion of wrapping the exponent in a CheckParameterValue).

What does it evaluate to now when you pass a negative base? I guess you get a nan. You don't need to sum the two factors, just check out the one for the exponent RV.
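
For example, reusing `res`, `base_vv` and `x_vv` from the Normal-base snippet above (and assuming `conditional_logp`'s result dict is keyed by the value variables, as the `res.values()` usage suggests):

```python
# Compile only the factor for x (the exponent RV) instead of the summed logp
x_factor = res[x_vv]
x_logp_fn = pytensor.function([base_vv, x_vv], x_factor)
x_logp_fn(np.array([-2, -2]), np.array([2, 2]))
# with a negative base this should give nan (or raise, once the check is added)
```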

@LukeLB
Contributor

LukeLB commented Oct 18, 2023

> The rewrite should only be used when the exponent (not the base) is being measured.

My changes haven't broken power(x, const), as I have also made a change to find_measurable_transforms to account for power(const, x). This may be easier for me to show in a PR so you can see all the changes I have made.

> It's fine as long as it does something sensible when negative values are passed (hence my suggestion of wrapping the exponent in a CheckParameterValue).

Ah, so I was being thick and using CheckParameterValue wrong. I have changed it to:

```python
import pytensor.tensor as pt
from pytensor.graph.rewriting.basic import node_rewriter
from pymc.logprob.utils import CheckParameterValue


@node_rewriter([pow])
def measurable_power_exponent_to_exp(fgraph, node):
    base, exponent = node.inputs

    # don't rewrite when the exponent is discrete
    if exponent.type.dtype.startswith("int"):
        return None

    # raise at evaluation time if the base is not strictly positive
    base = CheckParameterValue("base > 0")(base, pt.all(pt.gt(base, 0.0)))

    return [pt.exp(pt.log(base) * exponent)]
```

When the base is a negative constant:

```python
import pytensor
import pytensor.tensor as pt
from pymc import logp

x_rv = pt.pow(-1, pt.random.normal())
x_vv = x_rv.clone()

x_logp_fn = pytensor.function([x_vv], pt.sum(logp(x_rv, x_vv)))
x_logp_fn(0.1)
```

`ParameterValueError: base > 0` is raised, as expected.

When the base is a random variable that can take negative values:

```python
base_rv = pm.Normal.dist([2])
x_raw_rv = pm.Normal.dist()
x_rv = pt.power(base_rv, x_raw_rv)

x_rv.name = "x"
base_rv.name = "base"
base_vv = base_rv.clone()
x_vv = x_rv.clone()

res = conditional_logp({base_rv: base_vv, x_rv: x_vv})
factors = list(res.values())
logp_vals_fn = pytensor.function([base_vv, x_vv], factors[1])  # factor for x, the exponent RV

logp_vals_fn(np.array([2]), np.array([2]))
```

>>> array([-1.74557279])

In this case it evaluates fine. Is this what you would expect, or does the CheckParameterValue logic need to change?

@ricardoV94
Member Author

> In this case it evaluates fine. Is this what you would expect, or does the CheckParameterValue logic need to change?

That's what I would expect, because you provided a positive value for base_vv. It should fail if you call `logp_vals_fn(np.array([-2]), np.array([2]))`.

@ricardoV94
Member Author

> My changes haven't broken power(x, const), as I have also made a change to find_measurable_transforms to account for power(const, x). This may be easier for me to show in a PR so you can see all the changes I have made.

Great, I'll take a look when you open the PR

@LukeLB
Contributor

LukeLB commented Oct 18, 2023

> In this case it evaluates fine. Is this what you would expect, or does the CheckParameterValue logic need to change?

> That's what I would expect, because you provided a positive value for base_vv. It should fail if you call `logp_vals_fn(np.array([-2]), np.array([2]))`.

Yep, when I do that it fails as expected! I'll open a PR.
