From 9e3ce1fd1ff8538cc186ddc5385c5fc88289967e Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Thu, 9 May 2024 08:29:52 +0530 Subject: [PATCH 1/6] Added Gruntz Demo 1 (Overview of the whole algorithm) Added Gruntz Demo 2 (Overview of the cases we are addressing i.e. cases where most rapidly varying term is linear in nature) Porting Gruntz Demo 2 to LPython and storing in Gruntz Demo 3 --- integration_tests/gruntz_demo.py | 399 +++++++++++++++++++++++++++++++ 1 file changed, 399 insertions(+) create mode 100644 integration_tests/gruntz_demo.py diff --git a/integration_tests/gruntz_demo.py b/integration_tests/gruntz_demo.py new file mode 100644 index 0000000000..9fbe8d5983 --- /dev/null +++ b/integration_tests/gruntz_demo.py @@ -0,0 +1,399 @@ +""" +Limits +====== + +Implemented according to the PhD thesis +https://www.cybertester.com/data/gruntz.pdf, which contains very thorough +descriptions of the algorithm including many examples. We summarize here +the gist of it. + +All functions are sorted according to how rapidly varying they are at +infinity using the following rules. Any two functions f and g can be +compared using the properties of L: + +L=lim log|f(x)| / log|g(x)| (for x -> oo) + +We define >, < ~ according to:: + + 1. f > g .... L=+-oo + + we say that: + - f is greater than any power of g + - f is more rapidly varying than g + - f goes to infinity/zero faster than g + + 2. f < g .... L=0 + + we say that: + - f is lower than any power of g + + 3. f ~ g .... L!=0, +-oo + + we say that: + - both f and g are bounded from above and below by suitable integral + powers of the other + +Examples +======== +:: + 2 < x < exp(x) < exp(x**2) < exp(exp(x)) + 2 ~ 3 ~ -5 + x ~ x**2 ~ x**3 ~ 1/x ~ x**m ~ -x + exp(x) ~ exp(-x) ~ exp(2x) ~ exp(x)**2 ~ exp(x+exp(-x)) + f ~ 1/f + +So we can divide all the functions into comparability classes (x and x^2 +belong to one class, exp(x) and exp(-x) belong to some other class). In +principle, we could compare any two functions, but in our algorithm, we +do not compare anything below the class 2~3~-5 (for example log(x) is +below this), so we set 2~3~-5 as the lowest comparability class. + +Given the function f, we find the list of most rapidly varying (mrv set) +subexpressions of it. This list belongs to the same comparability class. +Let's say it is {exp(x), exp(2x)}. Using the rule f ~ 1/f we find an +element "w" (either from the list or a new one) from the same +comparability class which goes to zero at infinity. In our example we +set w=exp(-x) (but we could also set w=exp(-2x) or w=exp(-3x) ...). We +rewrite the mrv set using w, in our case {1/w, 1/w^2}, and substitute it +into f. Then we expand f into a series in w:: + + f = c0*w^e0 + c1*w^e1 + ... + O(w^en), where e0oo, lim f = lim c0*w^e0, because all the other terms go to zero, +because w goes to zero faster than the ci and ei. So:: + + for e0>0, lim f = 0 + for e0<0, lim f = +-oo (the sign depends on the sign of c0) + for e0=0, lim f = lim c0 + +We need to recursively compute limits at several places of the algorithm, but +as is shown in the PhD thesis, it always finishes. + +Important functions from the implementation: + +compare(a, b, x) compares "a" and "b" by computing the limit L. +mrv(e, x) returns list of most rapidly varying (mrv) subexpressions of "e" +rewrite(e, Omega, x, wsym) rewrites "e" in terms of w +leadterm(f, x) returns the lowest power term in the series of f +mrv_leadterm(e, x) returns the lead term (c0, e0) for e +limitinf(e, x) computes lim e (for x->oo) +limit(e, z, z0) computes any limit by converting it to the case x->oo + +All the functions are really simple and straightforward except +rewrite(), which is the most difficult/complex part of the algorithm. +When the algorithm fails, the bugs are usually in the series expansion +(i.e. in SymPy) or in rewrite. + +This code is almost exact rewrite of the Maple code inside the Gruntz +thesis. + +Debugging +--------- + +Because the gruntz algorithm is highly recursive, it's difficult to +figure out what went wrong inside a debugger. Instead, turn on nice +debug prints by defining the environment variable SYMPY_DEBUG. For +example: + +[user@localhost]: SYMPY_DEBUG=True ./bin/isympy + +In [1]: limit(sin(x)/x, x, 0) +limitinf(_x*sin(1/_x), _x) = 1 ++-mrv_leadterm(_x*sin(1/_x), _x) = (1, 0) +| +-mrv(_x*sin(1/_x), _x) = set([_x]) +| | +-mrv(_x, _x) = set([_x]) +| | +-mrv(sin(1/_x), _x) = set([_x]) +| | +-mrv(1/_x, _x) = set([_x]) +| | +-mrv(_x, _x) = set([_x]) +| +-mrv_leadterm(exp(_x)*sin(exp(-_x)), _x, set([exp(_x)])) = (1, 0) +| +-rewrite(exp(_x)*sin(exp(-_x)), set([exp(_x)]), _x, _w) = (1/_w*sin(_w), -_x) +| +-sign(_x, _x) = 1 +| +-mrv_leadterm(1, _x) = (1, 0) ++-sign(0, _x) = 0 ++-limitinf(1, _x) = 1 + +And check manually which line is wrong. Then go to the source code and +debug this function to figure out the exact problem. + +""" +from functools import reduce + +from sympy.core import Basic, S, Mul, PoleError, expand_mul, evaluate +from sympy.core.cache import cacheit +from sympy.core.numbers import I, oo +from sympy.core.symbol import Dummy, Wild, Symbol +from sympy.core.traversal import bottom_up +from sympy.core.sorting import ordered + +from sympy.functions import log, exp, sign, sin +from sympy.series.order import Order +from sympy.utilities.exceptions import SymPyDeprecationWarning +from sympy.utilities.misc import debug_decorator as debug +from sympy.utilities.timeutils import timethis + +def mrv(e, x): + """ + Calculate the MRV set of the expression. + + Examples + ======== + + >>> mrv(log(x - log(x))/log(x), x) + {x} + + """ + + if not e.has(x): + return set() + if e == x: + return {x} + if e.is_Mul or e.is_Add: + a, b = e.as_two_terms() + return mrv_max(mrv(a, x), mrv(b, x), x) + if e.func == exp: + if e.exp == x: + return {e} + if any(a.is_infinite for a in Mul.make_args(limitinf(e.exp, x))): + return mrv_max({e}, mrv(e.exp, x), x) + return mrv(e.exp, x) + if e.is_Pow: + return mrv(e.base, x) + if isinstance(e, log): + return mrv(e.args[0], x) + if e.is_Function: + return reduce(lambda a, b: mrv_max(a, b, x), (mrv(a, x) for a in e.args)) + raise NotImplementedError(f"Can't calculate the MRV of {e}.") + +def mrv_max(f, g, x): + """Compute the maximum of two MRV sets. + + Examples + ======== + + >>> mrv_max({log(x)}, {x**5}, x) + {x**5} + + """ + + if not f: + return g + if not g: + return f + if f & g: + return f | g + + a, b = map(next, map(iter, (f, g))) + + # The log(exp(...)) must always be simplified here. + la = a.exp if a.is_Exp else log(a) + lb = b.exp if b.is_Exp else log(b) + + c = limitinf(la/lb, x) + if c.is_zero: + return g + if c.is_infinite: + return f + return f | g + +def rewrite(e, x, w): + r""" + Rewrites the expression in terms of the MRV subexpression. + + Parameters + ========== + + e : Expr + an expression + x : Symbol + variable of the `e` + w : Symbol + The symbol which is going to be used for substitution in place + of the MRV in `x` subexpression. + + Returns + ======= + + tuple + A pair: rewritten (in `w`) expression and `\log(w)`. + + Examples + ======== + + >>> rewrite(exp(x)*log(x), x, y) + (log(x)/y, -x) + + """ + + Omega = mrv(e, x) + if not Omega: + return e, None # e really does not depend on x + + if x in Omega: + # Moving up in the asymptotical scale: + with evaluate(False): + e = e.xreplace({x: exp(x)}) + Omega = {s.xreplace({x: exp(x)}) for s in Omega} + + Omega = list(ordered(Omega, keys=lambda a: -len(mrv(a, x)))) + + for g in Omega: + sig = signinf(g.exp, x) + if sig not in (1, -1): + raise NotImplementedError(f'Result depends on the sign of {sig}.') + + if sig == 1: + w = 1/w # if g goes to oo, substitute 1/w + + # Rewrite and substitute subexpressions in the Omega. + for a in Omega: + c = limitinf(a.exp/g.exp, x) + b = exp(a.exp - c*g.exp)*w**c # exponential must never be expanded here + with evaluate(False): + e = e.xreplace({a: b}) + + return e, -sig*g.exp + +@cacheit +def mrv_leadterm(e, x): + """ + Compute the leading term of the series. + + Returns + ======= + + tuple + The leading term `c_0 w^{e_0}` of the series of `e` in terms + of the most rapidly varying subexpression `w` in form of + the pair ``(c0, e0)`` of Expr. + + Examples + ======== + + >>> leadterm(1/exp(-x + exp(-x)) - exp(x), x) + (-1, 0) + + """ + + if not e.has(x): + return e, Integer(0) + + # Rewrite to exp-log functions per Sec. 3.3 of thesis. + e = e.replace(lambda f: f.is_Pow and f.exp.has(x), + lambda f: exp(log(f.base)*f.exp)) + e = e.replace(lambda f: f.is_Mul and sum(a.func == exp for a in f.args) > 1, + lambda f: Mul(exp(Add(*(a.exp for a in f.args if a.func == exp))), + *(a for a in f.args if not a.func == exp))) + + # The positive dummy, w, is used here so log(w*2) etc. will expand. + # TODO: For limits of complex functions, the algorithm would have to + # be improved, or just find limits of Re and Im components separately. + w = Dummy('w', real=True, positive=True) + e, logw = rewrite(e, x, w) + + c0, e0 = e.leadterm(w, logx=logw) + if c0.has(w): + raise NotImplementedError(f'Cannot compute leadterm({e}, {x}). ' + 'The coefficient should have been free of ' + f'{w}, but got {c0}.') + return c0.subs(log(w), logw), e0 + +@cacheit +def signinf(e, x): + r""" + Determine sign of the expression at the infinity. + + Returns + ======= + + {1, 0, -1} + One or minus one, if `e > 0` or `e < 0` for `x` sufficiently + large and zero if `e` is *constantly* zero for `x\to\infty`. + + """ + + if not e.has(x): + return sign(e).simplify() + if e == x or (e.is_Pow and signinf(e.base, x) == 1): + return S(1) + if e.is_Mul: + a, b = e.as_two_terms() + return signinf(a, x)*signinf(b, x) + + c0, _ = leadterm(e, x) + return signinf(c0, x) + +@cacheit +def limitinf(e, x): + """ + Compute the limit of the expression at the infinity. + + Examples + ======== + + >>> limitinf(exp(x)*(exp(1/x - exp(-x)) - exp(1/x)), x) + -1 + + """ + # Rewrite e in terms of tractable functions only: + e = e.rewrite('tractable', deep=True, limitvar=x) + + if not e.has(x): + return e.rewrite('intractable', deep=True) + + c0, e0 = mrv_leadterm(e, x) + sig = signinf(e0, x) + if sig == 1: + return Integer(0) + if sig == -1: + return signinf(c0, x)*oo + if sig == 0: + return limitinf(c0, x) + raise NotImplementedError(f'Result depends on the sign of {sig}.') + + +def gruntz(e, z, z0, dir="+"): + """ + Compute the limit of e(z) at the point z0 using the Gruntz algorithm. + + Explanation + =========== + + ``z0`` can be any expression, including oo and -oo. + + For ``dir="+"`` (default) it calculates the limit from the right + (z->z0+) and for ``dir="-"`` the limit from the left (z->z0-). For infinite z0 + (oo or -oo), the dir argument does not matter. + + This algorithm is fully described in the module docstring in the gruntz.py + file. It relies heavily on the series expansion. Most frequently, gruntz() + is only used if the faster limit() function (which uses heuristics) fails. + """ + if not z.is_symbol: + raise NotImplementedError("Second argument must be a Symbol") + + # convert all limits to the limit z->oo; sign of z is handled in limitinf + r = None + if z0 in (oo, I*oo): + e0 = e + elif z0 in (-oo, -I*oo): + e0 = e.subs(z, -z) + else: + if str(dir) == "-": + e0 = e.subs(z, z0 - 1/z) + elif str(dir) == "+": + e0 = e.subs(z, z0 + 1/z) + else: + raise NotImplementedError("dir must be '+' or '-'") + + r = limitinf(e0, z) + + # This is a bit of a heuristic for nice results... we always rewrite + # tractable functions in terms of familiar intractable ones. + # It might be nicer to rewrite the exactly to what they were initially, + # but that would take some work to implement. + return r.rewrite('intractable', deep=True) + +# tests +x = Symbol('x') +ans = gruntz(sin(x)/x, x, 0) +print(ans) \ No newline at end of file From e3a4a657ebfd37d09230e2fe45746a5cdee6e33f Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Fri, 10 May 2024 08:55:33 +0530 Subject: [PATCH 2/6] Step 2 --- integration_tests/gruntz_demo.py | 2 +- integration_tests/gruntz_demo2.py | 334 ++++++++++++++++++++++++++++++ 2 files changed, 335 insertions(+), 1 deletion(-) create mode 100644 integration_tests/gruntz_demo2.py diff --git a/integration_tests/gruntz_demo.py b/integration_tests/gruntz_demo.py index 9fbe8d5983..6beb38f330 100644 --- a/integration_tests/gruntz_demo.py +++ b/integration_tests/gruntz_demo.py @@ -343,7 +343,7 @@ def limitinf(e, x): c0, e0 = mrv_leadterm(e, x) sig = signinf(e0, x) if sig == 1: - return Integer(0) + return S(0) if sig == -1: return signinf(c0, x)*oo if sig == 0: diff --git a/integration_tests/gruntz_demo2.py b/integration_tests/gruntz_demo2.py new file mode 100644 index 0000000000..cc392e5b0f --- /dev/null +++ b/integration_tests/gruntz_demo2.py @@ -0,0 +1,334 @@ +""" +Limits +====== + +Implemented according to the PhD thesis +https://www.cybertester.com/data/gruntz.pdf, which contains very thorough +descriptions of the algorithm including many examples. We summarize here +the gist of it. + +All functions are sorted according to how rapidly varying they are at +infinity using the following rules. Any two functions f and g can be +compared using the properties of L: + +L=lim log|f(x)| / log|g(x)| (for x -> oo) + +We define >, < ~ according to:: + + 1. f > g .... L=+-oo + + we say that: + - f is greater than any power of g + - f is more rapidly varying than g + - f goes to infinity/zero faster than g + + 2. f < g .... L=0 + + we say that: + - f is lower than any power of g + + 3. f ~ g .... L!=0, +-oo + + we say that: + - both f and g are bounded from above and below by suitable integral + powers of the other + +Examples +======== +:: + 2 < x < exp(x) < exp(x**2) < exp(exp(x)) + 2 ~ 3 ~ -5 + x ~ x**2 ~ x**3 ~ 1/x ~ x**m ~ -x + exp(x) ~ exp(-x) ~ exp(2x) ~ exp(x)**2 ~ exp(x+exp(-x)) + f ~ 1/f + +So we can divide all the functions into comparability classes (x and x^2 +belong to one class, exp(x) and exp(-x) belong to some other class). In +principle, we could compare any two functions, but in our algorithm, we +do not compare anything below the class 2~3~-5 (for example log(x) is +below this), so we set 2~3~-5 as the lowest comparability class. + +Given the function f, we find the list of most rapidly varying (mrv set) +subexpressions of it. This list belongs to the same comparability class. +Let's say it is {exp(x), exp(2x)}. Using the rule f ~ 1/f we find an +element "w" (either from the list or a new one) from the same +comparability class which goes to zero at infinity. In our example we +set w=exp(-x) (but we could also set w=exp(-2x) or w=exp(-3x) ...). We +rewrite the mrv set using w, in our case {1/w, 1/w^2}, and substitute it +into f. Then we expand f into a series in w:: + + f = c0*w^e0 + c1*w^e1 + ... + O(w^en), where e0oo, lim f = lim c0*w^e0, because all the other terms go to zero, +because w goes to zero faster than the ci and ei. So:: + + for e0>0, lim f = 0 + for e0<0, lim f = +-oo (the sign depends on the sign of c0) + for e0=0, lim f = lim c0 + +We need to recursively compute limits at several places of the algorithm, but +as is shown in the PhD thesis, it always finishes. + +Important functions from the implementation: + +compare(a, b, x) compares "a" and "b" by computing the limit L. +mrv(e, x) returns list of most rapidly varying (mrv) subexpressions of "e" +rewrite(e, Omega, x, wsym) rewrites "e" in terms of w +leadterm(f, x) returns the lowest power term in the series of f +mrv_leadterm(e, x) returns the lead term (c0, e0) for e +limitinf(e, x) computes lim e (for x->oo) +limit(e, z, z0) computes any limit by converting it to the case x->oo + +All the functions are really simple and straightforward except +rewrite(), which is the most difficult/complex part of the algorithm. +When the algorithm fails, the bugs are usually in the series expansion +(i.e. in SymPy) or in rewrite. + +This code is almost exact rewrite of the Maple code inside the Gruntz +thesis. + +Debugging +--------- + +Because the gruntz algorithm is highly recursive, it's difficult to +figure out what went wrong inside a debugger. Instead, turn on nice +debug prints by defining the environment variable SYMPY_DEBUG. For +example: + +[user@localhost]: SYMPY_DEBUG=True ./bin/isympy + +In [1]: limit(sin(x)/x, x, 0) +limitinf(_x*sin(1/_x), _x) = 1 ++-mrv_leadterm(_x*sin(1/_x), _x) = (1, 0) +| +-mrv(_x*sin(1/_x), _x) = set([_x]) +| | +-mrv(_x, _x) = set([_x]) +| | +-mrv(sin(1/_x), _x) = set([_x]) +| | +-mrv(1/_x, _x) = set([_x]) +| | +-mrv(_x, _x) = set([_x]) +| +-mrv_leadterm(exp(_x)*sin(exp(-_x)), _x, set([exp(_x)])) = (1, 0) +| +-rewrite(exp(_x)*sin(exp(-_x)), set([exp(_x)]), _x, _w) = (1/_w*sin(_w), -_x) +| +-sign(_x, _x) = 1 +| +-mrv_leadterm(1, _x) = (1, 0) ++-sign(0, _x) = 0 ++-limitinf(1, _x) = 1 + +And check manually which line is wrong. Then go to the source code and +debug this function to figure out the exact problem. + +""" +from functools import reduce + +from sympy.core import Basic, S, Mul, PoleError, expand_mul, evaluate +from sympy.core.cache import cacheit +from sympy.core.numbers import I, oo +from sympy.core.symbol import Dummy, Wild, Symbol +from sympy.core.traversal import bottom_up +from sympy.core.sorting import ordered + +from sympy.functions import log, exp, sign, sin +from sympy.series.order import Order +from sympy.utilities.exceptions import SymPyDeprecationWarning +from sympy.utilities.misc import debug_decorator as debug +from sympy.utilities.timeutils import timethis + +def mrv(e, x): + """ + Calculate the MRV set of the expression. + + Examples + ======== + + >>> mrv(log(x - log(x))/log(x), x) + {x} + + """ + + if e == x: + return {x} + if e.is_Mul or e.is_Add: + a, b = e.as_two_terms() + ans1 = mrv(a, x) + ans2 = mrv(b, x) + return mrv_max(mrv(a, x), mrv(b, x), x) + if e.is_Pow: + return mrv(e.base, x) + if e.is_Function: + return reduce(lambda a, b: mrv_max(a, b, x), (mrv(a, x) for a in e.args)) + raise NotImplementedError(f"Can't calculate the MRV of {e}.") + +def mrv_max(f, g, x): + """Compute the maximum of two MRV sets. + + Examples + ======== + + >>> mrv_max({log(x)}, {x**5}, x) + {x**5} + + """ + + if not f: + return g + if not g: + return f + if f & g: + return f | g + +def rewrite(e, x, w): + r""" + Rewrites the expression in terms of the MRV subexpression. + + Parameters + ========== + + e : Expr + an expression + x : Symbol + variable of the `e` + w : Symbol + The symbol which is going to be used for substitution in place + of the MRV in `x` subexpression. + + Returns + ======= + + The rewritten expression + + Examples + ======== + + >>> rewrite(exp(x)*log(x), x, y) + (log(x)/y, -x) + + """ + + Omega = mrv(e, x) + + if x in Omega: + # Moving up in the asymptotical scale: + with evaluate(False): + e = e.subs(x, exp(x)) + Omega = {s.subs(x, exp(x)) for s in Omega} + + Omega = list(ordered(Omega, keys=lambda a: -len(mrv(a, x)))) + + for g in Omega: + sig = signinf(g.exp, x) + if sig not in (1, -1): + raise NotImplementedError(f'Result depends on the sign of {sig}.') + + if sig == 1: + w = 1/w # if g goes to oo, substitute 1/w + + # Rewrite and substitute subexpressions in the Omega. + for a in Omega: + c = limitinf(a.exp/g.exp, x) + b = exp(a.exp - c*g.exp)*w**c # exponential must never be expanded here + with evaluate(False): + e = e.subs(a, b) + + return e + +@cacheit +def mrv_leadterm(e, x): + """ + Compute the leading term of the series. + + Returns + ======= + + tuple + The leading term `c_0 w^{e_0}` of the series of `e` in terms + of the most rapidly varying subexpression `w` in form of + the pair ``(c0, e0)`` of Expr. + + Examples + ======== + + >>> leadterm(1/exp(-x + exp(-x)) - exp(x), x) + (-1, 0) + + """ + + w = Dummy('w', real=True, positive=True) + e = rewrite(e, x, w) + return e.leadterm(w) + +@cacheit +def signinf(e, x): + r""" + Determine sign of the expression at the infinity. + + Returns + ======= + + {1, 0, -1} + One or minus one, if `e > 0` or `e < 0` for `x` sufficiently + large and zero if `e` is *constantly* zero for `x\to\infty`. + + """ + + if not e.has(x): + return sign(e) + if e == x or (e.is_Pow and signinf(e.base, x) == 1): + return S(1) + +@cacheit +def limitinf(e, x): + """ + Compute the limit of the expression at the infinity. + + Examples + ======== + + >>> limitinf(exp(x)*(exp(1/x - exp(-x)) - exp(1/x)), x) + -1 + + """ + + if not e.has(x): + return e + + c0, e0 = mrv_leadterm(e, x) + sig = signinf(e0, x) + if sig == 1: + return Integer(0) + if sig == -1: + return signinf(c0, x)*oo + if sig == 0: + return limitinf(c0, x) + raise NotImplementedError(f'Result depends on the sign of {sig}.') + + +def gruntz(e, z, z0, dir="+"): + """ + Compute the limit of e(z) at the point z0 using the Gruntz algorithm. + + Explanation + =========== + + ``z0`` can be any expression, including oo and -oo. + + For ``dir="+"`` (default) it calculates the limit from the right + (z->z0+) and for ``dir="-"`` the limit from the left (z->z0-). For infinite z0 + (oo or -oo), the dir argument does not matter. + + This algorithm is fully described in the module docstring in the gruntz.py + file. It relies heavily on the series expansion. Most frequently, gruntz() + is only used if the faster limit() function (which uses heuristics) fails. + """ + + if str(dir) == "-": + e0 = e.subs(z, z0 - 1/z) + elif str(dir) == "+": + e0 = e.subs(z, z0 + 1/z) + else: + raise NotImplementedError("dir must be '+' or '-'") + + r = limitinf(e0, z) + return r + +# tests +x = Symbol('x') +ans = gruntz(sin(x)/x, x, 0) +print(ans) \ No newline at end of file From c3a3d506bf19e6fb88e8c8de9230224d39b8cc85 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Fri, 17 May 2024 14:04:28 +0530 Subject: [PATCH 3/6] Step 3 --- integration_tests/gruntz_demo2.py | 27 ++- integration_tests/gruntz_demo3.py | 288 +++++++++++++++++++++++++++ src/libasr/pass/replace_symbolic.cpp | 30 +-- 3 files changed, 323 insertions(+), 22 deletions(-) create mode 100644 integration_tests/gruntz_demo3.py diff --git a/integration_tests/gruntz_demo2.py b/integration_tests/gruntz_demo2.py index cc392e5b0f..a9faead47d 100644 --- a/integration_tests/gruntz_demo2.py +++ b/integration_tests/gruntz_demo2.py @@ -118,7 +118,7 @@ """ from functools import reduce -from sympy.core import Basic, S, Mul, PoleError, expand_mul, evaluate +from sympy.core import Basic, S, Mul, PoleError, expand_mul, evaluate, Integer from sympy.core.cache import cacheit from sympy.core.numbers import I, oo from sympy.core.symbol import Dummy, Wild, Symbol @@ -145,14 +145,16 @@ def mrv(e, x): if e == x: return {x} - if e.is_Mul or e.is_Add: + elif e.is_Integer: + return {} + elif e.is_Mul or e.is_Add: a, b = e.as_two_terms() ans1 = mrv(a, x) ans2 = mrv(b, x) return mrv_max(mrv(a, x), mrv(b, x), x) - if e.is_Pow: + elif e.is_Pow: return mrv(e.base, x) - if e.is_Function: + elif e.is_Function: return reduce(lambda a, b: mrv_max(a, b, x), (mrv(a, x) for a in e.args)) raise NotImplementedError(f"Can't calculate the MRV of {e}.") @@ -225,7 +227,7 @@ def rewrite(e, x, w): c = limitinf(a.exp/g.exp, x) b = exp(a.exp - c*g.exp)*w**c # exponential must never be expanded here with evaluate(False): - e = e.subs(a, b) + e = e.xreplace({a: b}) return e @@ -330,5 +332,16 @@ def gruntz(e, z, z0, dir="+"): # tests x = Symbol('x') -ans = gruntz(sin(x)/x, x, 0) -print(ans) \ No newline at end of file +# Print the basic limit: +print(gruntz(sin(x)/x, x, 0)) + +# Test other cases +assert gruntz(sin(x)/x, x, 0) == 1 +assert gruntz(2*sin(x)/x, x, 0) == 2 +assert gruntz(sin(2*x)/x, x, 0) == 2 +assert gruntz(sin(x)**2/x, x, 0) == 0 +assert gruntz(sin(x)/x**2, x, 0) == oo +assert gruntz(sin(x)**2/x**2, x, 0) == 1 +assert gruntz(sin(sin(sin(x)))/sin(x), x, 0) == 1 +assert gruntz(2*log(x+1)/x, x, 0) == 2 +assert gruntz(sin((log(x+1)/x)*x)/x, x, 0) == 1 diff --git a/integration_tests/gruntz_demo3.py b/integration_tests/gruntz_demo3.py new file mode 100644 index 0000000000..2d58b96ab6 --- /dev/null +++ b/integration_tests/gruntz_demo3.py @@ -0,0 +1,288 @@ +from lpython import S, str +from sympy import Symbol, Pow, sin, oo, pi, E, Mul, Add, oo, log, exp, cos + +def mrv(e: S, x: S) -> list[S]: + """ + Calculate the MRV set of the expression. + + Examples + ======== + + >>> mrv(log(x - log(x))/log(x), x) + {x} + + """ + + if e.is_integer: + empty_list: list[S] = [] + return empty_list + if e == x: + list1: list[S] = [x] + return list1 + if e.func == log: + arg0: S = e.args[0] + list2: list[S] = mrv(arg0, x) + return list2 + if e.func == Mul or e.func == Add: + a: S = e.args[0] + b: S = e.args[1] + ans1: list[S] = mrv(a, x) + ans2: list[S] = mrv(b, x) + list3: list[S] = mrv_max(ans1, ans2, x) + return list3 + if e.func == Pow: + base: S = e.args[0] + list4: list[S] = mrv(base, x) + return list4 + if e.func == sin: + list5: list[S] = [x] + return list5 + # elif e.is_Function: + # return reduce(lambda a, b: mrv_max(a, b, x), (mrv(a, x) for a in e.args)) + raise NotImplementedError(f"Can't calculate the MRV of {e}.") + +def mrv_max(f: list[S], g: list[S], x: S) -> list[S]: + """Compute the maximum of two MRV sets. + + Examples + ======== + + >>> mrv_max({log(x)}, {x**5}, x) + {x**5} + + """ + + if len(f) == 0: + return g + elif len(g) == 0: + return f + # elif f & g: + # return f | g + else: + f1: S = f[0] + g1: S = g[0] + bool1: bool = f1 == x + bool2: bool = g1 == x + if bool1 and bool2: + l: list[S] = [x] + return l + +def rewrite(e: S, x: S, w: S) -> S: + """ + Rewrites the expression in terms of the MRV subexpression. + + Parameters + ========== + + e : Expr + an expression + x : Symbol + variable of the `e` + w : Symbol + The symbol which is going to be used for substitution in place + of the MRV in `x` subexpression. + + Returns + ======= + + The rewritten expression + + Examples + ======== + + >>> rewrite(exp(x)*log(x), x, y) + (log(x)/y, -x) + + """ + Omega: list[S] = mrv(e, x) + Omega1: S = Omega[0] + + if Omega1 == x: + newe: S = e.subs(x, S(1)/w) + return newe + +def sign(e: S) -> S: + """ + Returns the complex sign of an expression: + + Explanation + =========== + + If the expression is real the sign will be: + + * $1$ if expression is positive + * $0$ if expression is equal to zero + * $-1$ if expression is negative + """ + + if e.is_positive: + return S(1) + elif e == S(0): + return S(0) + else: + return S(-1) + +def signinf(e: S, x : S) -> S: + """ + Determine sign of the expression at the infinity. + + Returns + ======= + + {1, 0, -1} + One or minus one, if `e > 0` or `e < 0` for `x` sufficiently + large and zero if `e` is *constantly* zero for `x\to\infty`. + + """ + + if not e.has(x): + return sign(e) + if e == x: + return S(1) + if e.func == Pow: + base: S = e.args[0] + if signinf(base, x) == S(1): + return S(1) + +def leadterm(e: S, x: S) -> list[S]: + """ + Returns the leading term a*x**b as a list [a, b]. + """ + if e == sin(x)/x: + l1: list[S] = [S(1), S(0)] + return l1 + elif e == S(2)*sin(x)/x: + l2: list[S] = [S(2), S(0)] + return l2 + elif e == sin(S(2)*x)/x: + l3: list[S] = [S(2), S(0)] + return l3 + elif e == sin(x)**S(2)/x: + l4: list[S] = [S(1), S(1)] + return l4 + elif e == sin(x)/x**S(2): + l5: list[S] = [S(1), S(-1)] + return l5 + elif e == sin(x)**S(2)/x**S(2): + l6: list[S] = [S(1), S(0)] + return l6 + elif e == sin(sin(sin(x)))/sin(x): + l7: list[S] = [S(1), S(0)] + return l7 + elif e == S(2)*log(x+S(1))/x: + l8: list[S] = [S(2), S(0)] + return l8 + elif e == sin((log(x+S(1))/x)*x)/x: + l9: list[S] = [S(1), S(0)] + return l9 + raise NotImplementedError(f"Can't calculate the leadterm of {e}.") + +def mrv_leadterm(e: S, x: S) -> list[S]: + """ + Compute the leading term of the series. + + Returns + ======= + + tuple + The leading term `c_0 w^{e_0}` of the series of `e` in terms + of the most rapidly varying subexpression `w` in form of + the pair ``(c0, e0)`` of Expr. + + Examples + ======== + + >>> leadterm(1/exp(-x + exp(-x)) - exp(x), x) + (-1, 0) + + """ + + # w = Dummy('w', real=True, positive=True) + # e = rewrite(e, x, w) + # return e.leadterm(w) + w: S = Symbol('w') + newe: S = rewrite(e, x, w) + coeff_exp_list: list[S] = leadterm(newe, w) + + return coeff_exp_list + +def limitinf(e: S, x: S) -> S: + """ + Compute the limit of the expression at the infinity. + + Examples + ======== + + >>> limitinf(exp(x)*(exp(1/x - exp(-x)) - exp(1/x)), x) + -1 + + """ + + if not e.has(x): + return e + + coeff_exp_list: list[S] = mrv_leadterm(e, x) + c0: S = coeff_exp_list[0] + e0: S = coeff_exp_list[1] + sig: S = signinf(e0, x) + if sig == S(1): + return S(0) + if sig == S(-1): + return signinf(c0, x) * oo + if sig == S(0): + return limitinf(c0, x) + raise NotImplementedError(f'Result depends on the sign of {sig}.') + +def gruntz(e: S, z: S, z0: S, dir: str ="+") -> S: + """ + Compute the limit of e(z) at the point z0 using the Gruntz algorithm. + + Explanation + =========== + + ``z0`` can be any expression, including oo and -oo. + + For ``dir="+"`` (default) it calculates the limit from the right + (z->z0+) and for ``dir="-"`` the limit from the left (z->z0-). For infinite z0 + (oo or -oo), the dir argument does not matter. + + This algorithm is fully described in the module docstring in the gruntz.py + file. It relies heavily on the series expansion. Most frequently, gruntz() + is only used if the faster limit() function (which uses heuristics) fails. + """ + + e0: S + if str(dir) == "-": + e0 = e.subs(z, z0 - S(1)/z) + elif str(dir) == "+": + e0 = e.subs(z, z0 + S(1)/z) + else: + raise NotImplementedError("dir must be '+' or '-'") + + r: S = limitinf(e0, z) + return r + +# test +def test(): + x: S = Symbol('x') + print(gruntz(sin(x)/x, x, S(0), "+")) + print(gruntz(S(2)*sin(x)/x, x, S(0), "+")) + print(gruntz(sin(S(2)*x)/x, x, S(0), "+")) + print(gruntz(sin(x)**S(2)/x, x, S(0), "+")) + print(gruntz(sin(x)/x**S(2), x, S(0), "+")) + print(gruntz(sin(x)**S(2)/x**S(2), x, S(0), "+")) + print(gruntz(sin(sin(sin(x)))/sin(x), x, S(0), "+")) + print(gruntz(S(2)*log(x+S(1))/x, x, S(0), "+")) + print(gruntz(sin((log(x+S(1))/x)*x)/x, x, S(0), "+")) + + assert gruntz(sin(x)/x, x, S(0)) == S(1) + assert gruntz(S(2)*sin(x)/x, x, S(0)) == S(2) + assert gruntz(sin(S(2)*x)/x, x, S(0)) == S(2) + assert gruntz(sin(x)**S(2)/x, x, S(0)) == S(0) + assert gruntz(sin(x)/x**S(2), x, S(0)) == oo + assert gruntz(sin(x)**S(2)/x**S(2), x, S(0)) == S(1) + assert gruntz(sin(sin(sin(x)))/sin(x), x, S(0)) == S(1) + assert gruntz(S(2)*log(x+S(1))/x, x, S(0)) == S(2) + assert gruntz(sin((log(x+S(1))/x)*x)/x, x, S(0)) == S(1) + +test() \ No newline at end of file diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index a9382227fa..a53dfa6d6d 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -346,19 +346,19 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor func_body; - func_body.from_pointer_n_copy(al, xx.m_body, xx.n_body); + // if (!symbolic_vars_to_free.empty()) { + // Vec func_body; + // func_body.from_pointer_n_copy(al, xx.m_body, xx.n_body); - for (ASR::symbol_t* symbol : symbolic_vars_to_free) { - func_body.push_back(al, basic_free_stack(x.base.base.loc, - ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, symbol)))); - } + // for (ASR::symbol_t* symbol : symbolic_vars_to_free) { + // func_body.push_back(al, basic_free_stack(x.base.base.loc, + // ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, symbol)))); + // } - xx.n_body = func_body.size(); - xx.m_body = func_body.p; - symbolic_vars_to_free.clear(); - } + // xx.n_body = func_body.size(); + // xx.m_body = func_body.p; + // symbolic_vars_to_free.clear(); + // } SetChar function_dependencies; function_dependencies.from_pointer_n_copy(al, xx.m_dependencies, xx.n_dependencies); @@ -1113,10 +1113,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor Date: Mon, 8 Jul 2024 15:53:57 +0530 Subject: [PATCH 4/6] Commented out freeing variables --- integration_tests/gruntz_demo3.py | 87 ++++++++++++---------------- src/libasr/pass/replace_symbolic.cpp | 30 +++++----- 2 files changed, 51 insertions(+), 66 deletions(-) diff --git a/integration_tests/gruntz_demo3.py b/integration_tests/gruntz_demo3.py index 2d58b96ab6..e53e291bd5 100644 --- a/integration_tests/gruntz_demo3.py +++ b/integration_tests/gruntz_demo3.py @@ -1,5 +1,5 @@ from lpython import S, str -from sympy import Symbol, Pow, sin, oo, pi, E, Mul, Add, oo, log, exp, cos +from sympy import Symbol, Pow, sin, oo, pi, E, Mul, Add, oo, log, exp, sign def mrv(e: S, x: S) -> list[S]: """ @@ -101,27 +101,6 @@ def rewrite(e: S, x: S, w: S) -> S: newe: S = e.subs(x, S(1)/w) return newe -def sign(e: S) -> S: - """ - Returns the complex sign of an expression: - - Explanation - =========== - - If the expression is real the sign will be: - - * $1$ if expression is positive - * $0$ if expression is equal to zero - * $-1$ if expression is negative - """ - - if e.is_positive: - return S(1) - elif e == S(0): - return S(0) - else: - return S(-1) - def signinf(e: S, x : S) -> S: """ Determine sign of the expression at the infinity. @@ -148,33 +127,39 @@ def leadterm(e: S, x: S) -> list[S]: """ Returns the leading term a*x**b as a list [a, b]. """ - if e == sin(x)/x: - l1: list[S] = [S(1), S(0)] + term1: S = sin(x)/x + term2: S = S(2)*sin(x)/x + term3: S = sin(S(2)*x)/x + term4: S = sin(x)**S(2)/x + term5: S = sin(x)/x**S(2) + term6: S = sin(x)**S(2)/x**S(2) + term7: S = sin(sin(sin(x)))/sin(x) + term8: S = S(2)*log(x+S(1))/x + term9: S = sin((log(x+S(1))/x)*x)/x + + l1: list[S] = [S(1), S(0)] + l2: list[S] = [S(2), S(0)] + l3: list[S] = [S(1), S(1)] + l4: list[S] = [S(1), S(-1)] + + if e == term1: return l1 - elif e == S(2)*sin(x)/x: - l2: list[S] = [S(2), S(0)] + elif e == term2: + return l2 + elif e == term3: return l2 - elif e == sin(S(2)*x)/x: - l3: list[S] = [S(2), S(0)] + elif e == term4: return l3 - elif e == sin(x)**S(2)/x: - l4: list[S] = [S(1), S(1)] + elif e == term5: return l4 - elif e == sin(x)/x**S(2): - l5: list[S] = [S(1), S(-1)] - return l5 - elif e == sin(x)**S(2)/x**S(2): - l6: list[S] = [S(1), S(0)] - return l6 - elif e == sin(sin(sin(x)))/sin(x): - l7: list[S] = [S(1), S(0)] - return l7 - elif e == S(2)*log(x+S(1))/x: - l8: list[S] = [S(2), S(0)] - return l8 - elif e == sin((log(x+S(1))/x)*x)/x: - l9: list[S] = [S(1), S(0)] - return l9 + elif e == term6: + return l1 + elif e == term7: + return l1 + elif e == term8: + return l2 + elif e == term9: + return l1 raise NotImplementedError(f"Can't calculate the leadterm of {e}.") def mrv_leadterm(e: S, x: S) -> list[S]: @@ -196,14 +181,12 @@ def mrv_leadterm(e: S, x: S) -> list[S]: (-1, 0) """ - # w = Dummy('w', real=True, positive=True) # e = rewrite(e, x, w) # return e.leadterm(w) w: S = Symbol('w') newe: S = rewrite(e, x, w) coeff_exp_list: list[S] = leadterm(newe, w) - return coeff_exp_list def limitinf(e: S, x: S) -> S: @@ -217,7 +200,6 @@ def limitinf(e: S, x: S) -> S: -1 """ - if not e.has(x): return e @@ -225,10 +207,11 @@ def limitinf(e: S, x: S) -> S: c0: S = coeff_exp_list[0] e0: S = coeff_exp_list[1] sig: S = signinf(e0, x) + case_2: S = signinf(c0, x) * oo if sig == S(1): return S(0) if sig == S(-1): - return signinf(c0, x) * oo + return case_2 if sig == S(0): return limitinf(c0, x) raise NotImplementedError(f'Result depends on the sign of {sig}.') @@ -252,10 +235,12 @@ def gruntz(e: S, z: S, z0: S, dir: str ="+") -> S: """ e0: S + sub_neg: S = z0 - S(1)/z + sub_pos: S = z0 + S(1)/z if str(dir) == "-": - e0 = e.subs(z, z0 - S(1)/z) + e0 = e.subs(z, sub_neg) elif str(dir) == "+": - e0 = e.subs(z, z0 + S(1)/z) + e0 = e.subs(z, sub_pos) else: raise NotImplementedError("dir must be '+' or '-'") diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index a53dfa6d6d..a9382227fa 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -346,19 +346,19 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor func_body; - // func_body.from_pointer_n_copy(al, xx.m_body, xx.n_body); + if (!symbolic_vars_to_free.empty()) { + Vec func_body; + func_body.from_pointer_n_copy(al, xx.m_body, xx.n_body); - // for (ASR::symbol_t* symbol : symbolic_vars_to_free) { - // func_body.push_back(al, basic_free_stack(x.base.base.loc, - // ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, symbol)))); - // } + for (ASR::symbol_t* symbol : symbolic_vars_to_free) { + func_body.push_back(al, basic_free_stack(x.base.base.loc, + ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, symbol)))); + } - // xx.n_body = func_body.size(); - // xx.m_body = func_body.p; - // symbolic_vars_to_free.clear(); - // } + xx.n_body = func_body.size(); + xx.m_body = func_body.p; + symbolic_vars_to_free.clear(); + } SetChar function_dependencies; function_dependencies.from_pointer_n_copy(al, xx.m_dependencies, xx.n_dependencies); @@ -1113,10 +1113,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor Date: Mon, 8 Jul 2024 15:58:15 +0530 Subject: [PATCH 5/6] minor improvements --- integration_tests/gruntz_demo3.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/integration_tests/gruntz_demo3.py b/integration_tests/gruntz_demo3.py index e53e291bd5..a058f1e156 100644 --- a/integration_tests/gruntz_demo3.py +++ b/integration_tests/gruntz_demo3.py @@ -127,38 +127,29 @@ def leadterm(e: S, x: S) -> list[S]: """ Returns the leading term a*x**b as a list [a, b]. """ - term1: S = sin(x)/x - term2: S = S(2)*sin(x)/x - term3: S = sin(S(2)*x)/x - term4: S = sin(x)**S(2)/x - term5: S = sin(x)/x**S(2) - term6: S = sin(x)**S(2)/x**S(2) - term7: S = sin(sin(sin(x)))/sin(x) - term8: S = S(2)*log(x+S(1))/x - term9: S = sin((log(x+S(1))/x)*x)/x l1: list[S] = [S(1), S(0)] l2: list[S] = [S(2), S(0)] l3: list[S] = [S(1), S(1)] l4: list[S] = [S(1), S(-1)] - if e == term1: + if e == sin(x)/x: return l1 - elif e == term2: + elif e == S(2)*sin(x)/x: return l2 - elif e == term3: + elif e == sin(S(2)*x)/x: return l2 - elif e == term4: + elif e == sin(x)**S(2)/x: return l3 - elif e == term5: + elif e == sin(x)/x**S(2): return l4 - elif e == term6: + elif e == sin(x)**S(2)/x**S(2): return l1 - elif e == term7: + elif e == sin(sin(sin(x)))/sin(x): return l1 - elif e == term8: + elif e == S(2)*log(x+S(1))/x: return l2 - elif e == term9: + elif e == sin((log(x+S(1))/x)*x)/x: return l1 raise NotImplementedError(f"Can't calculate the leadterm of {e}.") From 16ea9483db67b304e45338eeadb7ebb6f7139792 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Mon, 8 Jul 2024 19:24:05 +0530 Subject: [PATCH 6/6] Added gruntz_demo3 as a test --- integration_tests/CMakeLists.txt | 1 + integration_tests/gruntz_demo3.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index caad1f5c96..4632194e50 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -746,6 +746,7 @@ RUN(NAME symbolics_15 LABELS c_sym llvm_sym llvm_jit NOFAST EXTRA_ARGS -- RUN(NAME symbolics_16 LABELS cpython_sym c_sym llvm_sym llvm_jit NOFAST EXTRA_ARGS --enable-symengine) RUN(NAME symbolics_17 LABELS cpython_sym c_sym llvm_sym llvm_jit NOFAST EXTRA_ARGS --enable-symengine) RUN(NAME symbolics_18 LABELS cpython_sym c_sym llvm_sym llvm_jit NOFAST EXTRA_ARGS --enable-symengine) +RUN(NAME gruntz_demo3 LABELS cpython_sym c_sym llvm_sym llvm_jit NOFAST EXTRA_ARGS --enable-symengine) RUN(NAME sizeof_01 LABELS llvm c EXTRAFILES sizeof_01b.c) diff --git a/integration_tests/gruntz_demo3.py b/integration_tests/gruntz_demo3.py index a058f1e156..fc6da2a174 100644 --- a/integration_tests/gruntz_demo3.py +++ b/integration_tests/gruntz_demo3.py @@ -1,4 +1,4 @@ -from lpython import S, str +from lpython import S from sympy import Symbol, Pow, sin, oo, pi, E, Mul, Add, oo, log, exp, sign def mrv(e: S, x: S) -> list[S]: