Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Avoid overwriting local contexts with retry decorator #479

Merged
merged 2 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
Avoid overwriting local contexts when applying the retry decorator.
10 changes: 8 additions & 2 deletions tenacity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,13 +329,19 @@ def wraps(self, f: WrappedFn) -> WrappedFn:
f, functools.WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__")
)
def wrapped_f(*args: t.Any, **kw: t.Any) -> t.Any:
return self(f, *args, **kw)
# Always create a copy to prevent overwriting the local contexts when
# calling the same wrapped functions multiple times in the same stack
copy = self.copy()
wrapped_f.statistics = copy.statistics # type: ignore[attr-defined]
return copy(f, *args, **kw)

def retry_with(*args: t.Any, **kwargs: t.Any) -> WrappedFn:
return self.copy(*args, **kwargs).wraps(f)

wrapped_f.retry = self # type: ignore[attr-defined]
# Preserve attributes
wrapped_f.retry = wrapped_f # type: ignore[attr-defined]
wrapped_f.retry_with = retry_with # type: ignore[attr-defined]
wrapped_f.statistics = {} # type: ignore[attr-defined]

return wrapped_f # type: ignore[return-value]

Expand Down
13 changes: 9 additions & 4 deletions tenacity/asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,18 +175,23 @@ async def __anext__(self) -> AttemptManager:
raise StopAsyncIteration

def wraps(self, fn: WrappedFn) -> WrappedFn:
fn = super().wraps(fn)
wrapped = super().wraps(fn)
# Ensure wrapper is recognized as a coroutine function.

@functools.wraps(
fn, functools.WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__")
)
async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any:
return await fn(*args, **kwargs)
# Always create a copy to prevent overwriting the local contexts when
# calling the same wrapped functions multiple times in the same stack
copy = self.copy()
async_wrapped.statistics = copy.statistics # type: ignore[attr-defined]
return await copy(fn, *args, **kwargs)

# Preserve attributes
async_wrapped.retry = fn.retry # type: ignore[attr-defined]
async_wrapped.retry_with = fn.retry_with # type: ignore[attr-defined]
async_wrapped.retry = async_wrapped # type: ignore[attr-defined]
async_wrapped.retry_with = wrapped.retry_with # type: ignore[attr-defined]
async_wrapped.statistics = {} # type: ignore[attr-defined]

return async_wrapped # type: ignore[return-value]

Expand Down
118 changes: 118 additions & 0 deletions tests/test_issue_478.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import asyncio
import typing
import unittest

from functools import wraps

from tenacity import RetryCallState, retry


def asynctest(
callable_: typing.Callable[..., typing.Any],
) -> typing.Callable[..., typing.Any]:
@wraps(callable_)
def wrapper(*a: typing.Any, **kw: typing.Any) -> typing.Any:
loop = asyncio.get_event_loop()
return loop.run_until_complete(callable_(*a, **kw))

return wrapper


MAX_RETRY_FIX_ATTEMPTS = 2


class TestIssue478(unittest.TestCase):
def test_issue(self) -> None:
results = []

def do_retry(retry_state: RetryCallState) -> bool:
outcome = retry_state.outcome
assert outcome
ex = outcome.exception()
_subject_: str = retry_state.args[0]

if _subject_ == "Fix": # no retry on fix failure
return False

if retry_state.attempt_number >= MAX_RETRY_FIX_ATTEMPTS:
return False

if ex:
do_fix_work()
return True

return False

@retry(reraise=True, retry=do_retry)
def _do_work(subject: str) -> None:
if subject == "Error":
results.append(f"{subject} is not working")
raise Exception(f"{subject} is not working")
results.append(f"{subject} is working")

def do_any_work(subject: str) -> None:
_do_work(subject)

def do_fix_work() -> None:
_do_work("Fix")

try:
do_any_work("Error")
except Exception as exc:
assert str(exc) == "Error is not working"
else:
assert False, "No exception caught"

assert results == [
"Error is not working",
"Fix is working",
"Error is not working",
]

@asynctest
async def test_async(self) -> None:
results = []

async def do_retry(retry_state: RetryCallState) -> bool:
outcome = retry_state.outcome
assert outcome
ex = outcome.exception()
_subject_: str = retry_state.args[0]

if _subject_ == "Fix": # no retry on fix failure
return False

if retry_state.attempt_number >= MAX_RETRY_FIX_ATTEMPTS:
return False

if ex:
await do_fix_work()
return True

return False

@retry(reraise=True, retry=do_retry)
async def _do_work(subject: str) -> None:
if subject == "Error":
results.append(f"{subject} is not working")
raise Exception(f"{subject} is not working")
results.append(f"{subject} is working")

async def do_any_work(subject: str) -> None:
await _do_work(subject)

async def do_fix_work() -> None:
await _do_work("Fix")

try:
await do_any_work("Error")
except Exception as exc:
assert str(exc) == "Error is not working"
else:
assert False, "No exception caught"

assert results == [
"Error is not working",
"Fix is working",
"Error is not working",
]