
Derive logprob for exp2, log2, log10, log1p, expm1, log1mexp, log1pexp (softplus), and sigmoid transformations #6826

Merged
merged 27 commits into pymc-devs:main from LukeLB:issue_#6631_chapter3_the_end
Sep 7, 2023

Conversation

@LukeLB (Contributor) commented Jul 13, 2023

This builds upon the previous pull requests, #6664 and #6775, and completes the work of #6631.

I have attempted to rewrite the logp graph only for log1p, expm1, and log1pexp (softplus). My reasoning is that in these cases, the inputs are transformed directly without affecting the backward or log_jac_det transforms. For all other cases, I have made changes to the existing transform classes.
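
Roughly, the graph-rewrite approach looks like this (a minimal sketch, not the exact code in this PR; it assumes PyTensor's node_rewriter decorator and the scalar Log1p op, and omits the registration into PyMC's measurable rewrite database):

import pytensor.tensor as pt
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import Log1p
from pytensor.tensor.elemwise import Elemwise


@node_rewriter([Elemwise])
def local_log1p_to_log(fgraph, node):
    # Replace log1p(x) with the equivalent log(1 + x) so that the existing
    # Log/Exp transform machinery can derive the logprob (and its jacobian).
    if not isinstance(node.op.scalar_op, Log1p):
        return None
    [inp] = node.inputs
    return [pt.log(1 + inp)]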

Although some tests are still failing (see below), I'm uncertain about the reasons for the first two failures. However, it seems that the final failure is due to a floating point error, and I'm unsure how to resolve it.

FAILED test_transforms.py::test_transformed_logprob[at_dist6-dist_params6-<lambda>-size6] - AssertionError: 
E           AssertionError: 
E           Arrays are not almost equal to 4 decimals
E           
E           Mismatched elements: 1 / 1 (100%)
E           Max absolute difference: 0.03384969
E           Max relative difference: 0.00578904
E            x: array(-5.8811)
E            y: array(-5.8472)
FAILED test_transforms.py::test_mixture_transform - assert False
E       assert False
E        +  where False = equal_computations([Sum{axes=None}.0], [Sum{axes=None}.0])
FAILED test_transforms.py::test_check_jac_det[transform9] - AssertionError: 
E           AssertionError: 
E           Not equal to tolerance rtol=1e-07, atol=0
E           
E           Mismatched elements: 1 / 2 (50%)
E           Max absolute difference: 6.59892563e-09
E           Max relative difference: 1.69410928e-07
E            x: array([ 1.05966 , -0.038952])
E            y: array([ 1.05966 , -0.038952])

Currently, there are no specific tests written for the transforms log1p, log1mexp, log1pexp (softplus), and sigmoid. I would appreciate assistance in developing tests for these transforms.
...

Checklist

Major / Breaking Changes

  • Changes to existing Exp and Log Transform classes

New features

  • The ability to derive logprob for the following transforms: exp2, log2, log10, log1p, expm1, log1mexp, log1pexp (softplus), and sigmoid.

📚 Documentation preview 📚: https://pymc--6826.org.readthedocs.build/en/6826/

@codecov (bot) commented Jul 13, 2023

Codecov Report

Merging #6826 (47f08ea) into main (ddd1d4b) will increase coverage by 0.10%.
Report is 8 commits behind head on main.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##             main    #6826      +/-   ##
==========================================
+ Coverage   92.05%   92.16%   +0.10%     
==========================================
  Files          96      100       +4     
  Lines       16448    16877     +429     
==========================================
+ Hits        15142    15554     +412     
- Misses       1306     1323      +17     
Files Changed Coverage Δ
pymc/logprob/transforms.py 94.82% <100.00%> (+0.40%) ⬆️

... and 13 files with indirect coverage changes

@ricardoV94 (Member) left a comment

Thanks a lot for picking this up. My suggestions are all about being a bit lazier (less code to maintain and test), but the idea is totally right!

@ricardoV94 (Member) commented Jul 14, 2023

My reasoning is that in these cases, the inputs are transformed directly without affecting the backward or log_jac_det transforms. For all other cases, I have made changes to the existing transform classes.

You don't need to worry about the jacobian when rewriting into any equivalent form, not just those cases. For instance, for stuff like log2(c) -> log(c) / log(2), PyMC will have to introduce the right jacobian when deriving the probability of what it sees as chained scale and log transformations. It doesn't matter what the expression was before.
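
As a concrete illustration (a quick sanity check, assuming the rewrites in this PR are in place and using pymc's logp helper):

import numpy as np
import pytensor.tensor as pt

from pymc import logp

base_rv = pt.random.normal(name="base_rv")
vv = pt.scalar("vv")

# log(base_rv) / log(2) is seen as a log transform chained with a scale
# transform; PyMC inserts the corresponding jacobian terms on its own.
logp_log2 = logp(pt.log2(base_rv), vv)
logp_chained = logp(pt.log(base_rv) / pt.log(2), vv)

np.testing.assert_allclose(
    logp_log2.eval({vv: 0.3}),
    logp_chained.eval({vv: 0.3}),
)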

pymc/logprob/transforms.py
[inp] = node.inputs

if isinstance(node.op.scalar_op, Exp2):
    return [pt.power(2, inp)]
@ricardoV94 (Member) commented Jul 16, 2023

I don't think we support this one, only powers with fixed exponent and variable base

@ricardoV94 (Member) commented Jul 16, 2023

You can convert to exp(ln(2)*x) instead, which PyMC will know how to handle
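
In code, that suggestion would look roughly like this, continuing the snippet above (a sketch, not the final PR code; pt is pytensor.tensor):

if isinstance(node.op.scalar_op, Exp2):
    # exp2(x) == exp(log(2) * x), which PyMC handles as a scale
    # transform chained with an exp transform
    return [pt.exp(pt.log(2) * inp)]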

@ricardoV94 (Member) commented Jul 16, 2023

We should actually use that for any power(const, x) -> exp(log(const) * x) which we currently don't support. But maybe that's better left for another PR?

It requires checking we are interested in x and not const.

@LukeLB (Contributor, Author) commented

Ah yeah I had a feeling that one may be an issue, I'll make the change. I'll add a new function to generalise this power(const, x) -> exp(log(const) * x) functionality.

@LukeLB (Contributor, Author) commented Jul 17, 2023

So I've just pushed this new functionality. Do we care that it won't work for const <= 0?

@ricardoV94 (Member) commented Aug 7, 2023

Let's open a separate issue / PR for this. I thought it over a bit and I think we can definitely support a couple of cases.

power(const, x), for any const > 0 and any x
power(const, x), for any const and discrete x  (we can play with `log(abs(neg_const))` and x's parity)

For the first case we don't have to constrain ourselves to actual "constants"; we can add a symbolic assert that const > 0.

The second requires us to implement transforms for discrete variables, which would probably need #6360 first, so we can focus on the first case, which is also probably more useful anyway.

We just have to make sure not to rewrite stuff like power(x, const) accidentally as those are implemented via our PowerTransform. This can be done by checking which of the inputs has a path to unvalued random variables.
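
A rough sketch of the first case (hypothetical helper name, using pytensor's Assert op for the symbolic check):

import pytensor.tensor as pt
from pytensor.raise_op import Assert

assert_positive_base = Assert("base of power(base, x) must be positive")


def power_base_to_exp(const, x):
    # Rewrite power(const, x) -> exp(log(const) * x); const does not need to be
    # a literal constant, only (symbolically) positive.
    const = assert_positive_base(const, pt.all(const > 0))
    return pt.exp(pt.log(const) * x)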

@ricardoV94 (Member) commented

Rebasing from main should unstick the tests.

@review-notebook-app (bot) commented

Check out this pull request on ReviewNB to see visual diffs and provide feedback on Jupyter Notebooks.

@ricardoV94 (Member) commented

I imagine you did pull --merge instead of pull --rebase? This makes GitHub show unrelated commits in the PR :/

https://stackoverflow.com/a/71074635

@LukeLB (Contributor, Author) commented Aug 4, 2023

Whoops, sorry about that, I will fix it.

@LukeLB LukeLB force-pushed the issue_#6631_chapter3_the_end branch from 0f50a55 to 0cfd1f2 Compare August 5, 2023 11:05
@LukeLB (Contributor, Author) commented Aug 5, 2023

It seems that the code coverage check is failing because of the new node rewriting functions. Should we be testing them? Or are they fine because they rely on existing functionality that's already tested?

@ricardoV94 (Member) commented Aug 7, 2023

It seems that the code coverage check is failing because of the new node rewriting functions. Should we be testing them? Or are they fine because they rely on existing functionality that's already tested?

We should test them. We can check that the output logp for the new cases we support is the same as for what we think they are equivalent to, using the equal_computations helper. Something like:

import pytensor.tensor as pt
from pytensor.graph.basic import equal_computations

from pymc import logp


def test():
    base_rv = pt.random.normal(name="base_rv")
    vv = pt.scalar("vv")

    logp_test = logp(pt.log1p(base_rv), vv)
    logp_ref = logp(pt.log(1 + base_rv), vv)

    assert equal_computations([logp_test], [logp_ref])

@ricardoV94 ricardoV94 changed the title [WIP] Derive logprob for exp2, log2, log10, log1p, expm1, log1mexp, log1pexp (softplus), and sigmoid transformations Derive logprob for exp2, log2, log10, log1p, expm1, log1mexp, log1pexp (softplus), and sigmoid transformations Aug 7, 2023
@ricardoV94 (Member) commented

@LukeLB shall we push the PR to the finish line? Let me know if you don't have the time right now

@LukeLB (Contributor, Author) commented Aug 28, 2023

Hey, really sorry about the lack of communication, I've been on holiday. I'm going to have a look at this this week.

@ricardoV94 (Member) commented

No worries, I hope you had a good holiday!

@LukeLB (Contributor, Author) commented Aug 29, 2023

Ahhh, I've done the same thing as before and somehow merged unrelated commits! Will fix it on my side.

@LukeLB (Contributor, Author) commented Aug 29, 2023

While I fix that, I am having problems with two of the tests. For the test

TRANSFORMATIONS = {
    "log1p": (pt.log1p, lambda x: pt.log(1 + x)),
    "softplus": (pt.softplus, lambda x: pt.log(1 + pt.exp(x))),
    "log1mexp": (pt.log1mexp, lambda x: pt.log(1 - pt.exp(pt.neg(x)))),
    "log2": (pt.log2, lambda x: pt.log(x) / pt.log(2)),
    "log10": (pt.log10, lambda x: pt.log(x) / pt.log(10)),
    "exp2": (pt.exp2, lambda x: pt.exp(pt.log(2) * x)),
    "expm1": (pt.expm1, lambda x: pt.exp(x) - 1),
    "sigmoid": (pt.sigmoid, lambda x: 1 / (1 + pt.exp(-x))),
}


@pytest.mark.parametrize("transform", TRANSFORMATIONS.keys())
def test_special_log_exp_transforms(transform):
    base_rv = pt.random.normal(name="base_rv")
    vv = pt.scalar("vv")

    transform_func, ref_func = TRANSFORMATIONS[transform]
    transformed_rv = transform_func(base_rv)
    ref_transformed_rv = ref_func(base_rv)

    logp_test = logp(transformed_rv, vv)
    logp_ref = logp(ref_transformed_rv, vv)

    assert equal_computations([logp_test], [logp_ref])

when the transform is log2 or log10, the test fails on equal_computations; I'm not sure what is causing that...

@ricardoV94 (Member) commented Sep 4, 2023

when the transform is log2 or log10, the test fails on equal_computations; I'm not sure what is causing that...

Try to look at pytensor.dprint(logp_test) and pytensor.dprint(logp_ref) and see if anything obvious comes across. Even something like x + 2 vs 2 + x would make that function return False.
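
For example (using the logp_test and logp_ref variables from the test above):

import pytensor

# equal_computations is purely structural, so any differing constant or
# argument order will show up when the two graphs are printed side by side.
pytensor.dprint(logp_test)
pytensor.dprint(logp_ref)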

@LukeLB (Contributor, Author) commented Sep 4, 2023

@ricardoV94 OK, I think I've got to the bottom of what is causing the failure. It looks like a floating point precision difference in one of the nodes of the graph between the test case and the reference case.
[screenshot of the pytensor.dprint comparison omitted]
However I'm not sure I can control that in the original lambda function that I've written. Any ideas?

@ricardoV94 (Member) commented

However I'm not sure I can control that in the original lambda function that I've written. Any ideas?

Hmm. Let's do a logp evaluation for that one (separate test) and check for output closeness?
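
Such a check could look something like this (a sketch, reusing logp_test, logp_ref, and vv from the test above; the evaluation points are arbitrary):

import numpy as np

# Compare the two logp graphs numerically instead of structurally.
for test_val in (0.25, 0.5, 0.75):
    np.testing.assert_allclose(
        logp_test.eval({vv: test_val}),
        logp_ref.eval({vv: test_val}),
    )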

@LukeLB LukeLB force-pushed the issue_#6631_chapter3_the_end branch from 17123ac to a912b20 Compare September 4, 2023 21:05
@ricardoV94 (Member) commented Sep 5, 2023

The Files changed tab is showing some accidental overwriting of previous changes? https://github.com/pymc-devs/pymc/pull/6826/files

@ricardoV94 ricardoV94 merged commit 39a2975 into pymc-devs:main Sep 7, 2023
21 checks passed
@ricardoV94 (Member) commented

Awesome work @LukeLB. Do you want to pursue the power exponent next?

@LukeLB (Contributor, Author) commented Sep 7, 2023

Thanks @ricardoV94, I feel like I learnt a lot on this one!

Do you want to pursue the power exponent next?

Yep, I'll start looking into that; if I run into any trouble I'll drop you a message on Slack.

@twiecki (Member) commented Sep 14, 2023

Congrats @LukeLB! This is a major and non-trivial contribution.
