-
-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Ricardo Vieira <[email protected]>
- Loading branch information
1 parent
18901ea
commit 9412b39
Showing
1 changed file
with
329 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,329 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import pymc as pm\n", | ||
"import numpy as np\n", | ||
"import pytensor.tensor as pt\n", | ||
"import pytensor\n", | ||
"from pymc.logprob.basic import logp" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"x_rv = pt.expm1(pt.random.normal())\n", | ||
"\n", | ||
"x_vv = x_rv.clone()\n", | ||
"\n", | ||
"x_logp_fn = pytensor.function([x_vv], pt.sum(logp(x_rv, x_vv)))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import pytensor.tensor as pt\n", | ||
"import pytensor" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"Elemwise(scalar_op=scalar_softplus,inplace_pattern=<frozendict {}>)" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"pt.log1pexp" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"0.6931471805599453" | ||
] | ||
}, | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"def log(p = None):\n", | ||
" return np.log(1+p)\n", | ||
"log(True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"0.6931471805599453" | ||
] | ||
}, | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"np.log(2)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"0.030929803620161386" | ||
] | ||
}, | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"np.log(np.cosh(0.25))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"y = np.exp(x)-1\n", | ||
"y + 1 = np.exp(x)\n", | ||
"np.log()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"2.993222846126381" | ||
] | ||
}, | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"np.arccosh(10)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/var/folders/yw/hvmlf6sd7yv3040wn8str6540000gn/T/ipykernel_5322/294505405.py:1: RuntimeWarning: invalid value encountered in log\n", | ||
" np.log(-2)\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"nan" | ||
] | ||
}, | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"\n", | ||
"# get assertion error because we are caulting jac det directly therefore skipping this line\n", | ||
"# return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian)\n", | ||
"# so is returned as nan rather than -inf as is expected\n", | ||
"# hence switch statement in the test\n", | ||
"\n", | ||
"# also in have changed line 416 to \n", | ||
"# return pt.switch(pt.isnan(input_logprob + jacobian), -np.inf, input_logprob + jacobian)\n", | ||
"# as cosh transform was returtuning nan due to input_logprob being nan\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 35, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import scipy.stats as stats\n", | ||
"\n", | ||
"def test_cosh_transform(test_val):\n", | ||
" x_rv = pt.abs(pt.random.normal())\n", | ||
" y_rv = pt.random.lognormal(sigma=1)\n", | ||
"\n", | ||
" x_vv = x_rv.clone()\n", | ||
" y_vv = y_rv.clone()\n", | ||
" x_transformed = pm.math.cosh(x_vv)\n", | ||
" y_transformed = pm.math.cosh(y_vv)\n", | ||
"\n", | ||
" x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_transformed}, sum=False))\n", | ||
" y_logp_fn = pytensor.function([y_vv], joint_logprob({y_rv: y_transformed}, sum=False))\n", | ||
"\n", | ||
" # Calculate the log probability using SciPy's lognorm distribution\n", | ||
" scipy_logp = stats.lognorm.logpdf(test_val, 1)\n", | ||
"\n", | ||
" assert np.allclose(x_logp_fn(test_val), y_logp_fn(test_val))\n", | ||
" assert np.allclose(y_logp_fn(test_val), scipy_logp)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import pytensor\n", | ||
"import pytensor.tensor as at\n", | ||
"\n", | ||
"k = at.iscalar(\"k\")\n", | ||
"A = at.vector(\"A\")\n", | ||
"\n", | ||
"# Symbolic description of the result\n", | ||
"result, updates = pytensor.scan(fn=lambda prior_result, A: prior_result * A,\n", | ||
" outputs_info=at.ones_like(A),\n", | ||
" non_sequences=A,\n", | ||
" n_steps=k)\n", | ||
"\n", | ||
"# We only care about A**k, but scan has provided us with A**1 through A**k.\n", | ||
"# Discard the values that we don't care about. Scan is smart enough to\n", | ||
"# notice this and not waste memory saving them.\n", | ||
"final_result = result[-1]\n", | ||
"\n", | ||
"# compiled function that returns A**k\n", | ||
"power = pytensor.function(inputs=[A,k], outputs=final_result, updates=updates)\n", | ||
"\n", | ||
"print(power(range(10),2))\n", | ||
"print(power(range(10),4))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def erfcxinv(y, tol = 1e-10, max_iter=100):\n", | ||
" # Compute x using the appropriate formula for each value of y\n", | ||
" x = at.switch(y <= 1, 1. / (y * at.sqrt(np.pi)), -at.sqrt(at.log(y)))\n", | ||
" # for n in range(1, 8):\n", | ||
" # x = x - (at.erfcx(x) - y) / (2 * x * at.erfcx(x) - 2 / at.sqrt(np.pi))\n", | ||
" iter_count = 0\n", | ||
" while iter_count < max_iter:\n", | ||
" iter_count += 1\n", | ||
" fx = at.erfcx(x) - y\n", | ||
" fpx = 2 * x * at.erfcx(x) - 2 / at.sqrt(np.pi)\n", | ||
" delta_x = fx / fpx\n", | ||
" x = x - delta_x\n", | ||
" if (at.abs(delta_x) < tol).all():\n", | ||
" break\n", | ||
" return x \n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"Elemwise{sub,no_inplace}.0" | ||
] | ||
}, | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"erfcxinv(np.array([0.5, 0.5]))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "pymc-dev-py39", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.0" | ||
}, | ||
"vscode": { | ||
"interpreter": { | ||
"hash": "bcf883f15876119fe9e03568c8cccc90efad6f040930a65a87ca95173e2637cf" | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |