Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update mock.py #164

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 17 additions & 25 deletions asynctest/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,20 +353,18 @@ def _run_test_method(self, method):
if asyncio.iscoroutine(result):
self.loop.run_until_complete(result)

@asyncio.coroutine
def doCleanups(self):
async def doCleanups(self):
"""
Execute all cleanup functions. Normally called for you after tearDown.
"""
outcome = self._outcome or unittest.mock._Outcome()
outcome = self._outcome or mock._Outcome()
while self._cleanups:
function, args, kwargs = self._cleanups.pop()
with outcome.testPartExecutor(self):
if asyncio.iscoroutinefunction(function):
yield from function(*args, **kwargs)
await function(*args, **kwargs)
else:
function(*args, **kwargs)

return outcome.success

def addCleanup(self, function, *args, **kwargs):
Expand All @@ -377,8 +375,7 @@ def addCleanup(self, function, *args, **kwargs):
"""
return super().addCleanup(function, *args, **kwargs)

@asyncio.coroutine
def assertAsyncRaises(self, exception, awaitable):
async def assertAsyncRaises(self, exception, awaitable):
"""
Test that an exception of type ``exception`` is raised when an
exception is raised when awaiting ``awaitable``, a future or coroutine.
Expand All @@ -391,40 +388,37 @@ def assertAsyncRaises(self, exception, awaitable):
:see: :meth:`unittest.TestCase.assertRaises()`
"""
with self.assertRaises(exception):
return (yield from awaitable)
await awaitable

@asyncio.coroutine
def assertAsyncRaisesRegex(self, exception, regex, awaitable):
async def assertAsyncRaisesRegex(self, exception, regex, awaitable):
"""
Like :meth:`assertAsyncRaises()` but also tests that ``regex`` matches
on the string representation of the raised exception.

:see: :meth:`unittest.TestCase.assertRaisesRegex()`
"""
with self.assertRaisesRegex(exception, regex):
return (yield from awaitable)
await awaitable

@asyncio.coroutine
def assertAsyncWarns(self, warning, awaitable):
async def assertAsyncWarns(self, warning, awaitable):
"""
Test that a warning is triggered when awaiting ``awaitable``, a future
or a coroutine.

:see: :meth:`unittest.TestCase.assertWarns()`
"""
with self.assertWarns(warning):
return (yield from awaitable)
await awaitable

@asyncio.coroutine
def assertAsyncWarnsRegex(self, warning, regex, awaitable):
async def assertAsyncWarnsRegex(self, warning, regex, awaitable):
"""
Like :meth:`assertAsyncWarns()` but also tests that ``regex`` matches
on the message of the triggered warning.

:see: :meth:`unittest.TestCase.assertWarnsRegex()`
"""
with self.assertWarnsRegex(warning, regex):
return (yield from awaitable)
await awaitable


class FunctionTestCase(TestCase, unittest.FunctionTestCase):
Expand All @@ -446,8 +440,7 @@ def _init_loop(self):
self.loop.time = functools.wraps(self.loop.time)(lambda: self._time)
self._time = 0

@asyncio.coroutine
def advance(self, seconds):
async def advance(self, seconds):
"""
Fast forward time by a number of ``seconds``.

Expand All @@ -468,7 +461,7 @@ def advance(self, seconds):
raise ValueError(
'Cannot go back in time ({} seconds)'.format(seconds))

yield from self._drain_loop()
await self._drain_loop()

target_time = self._time + seconds
while True:
Expand All @@ -477,26 +470,25 @@ def advance(self, seconds):
break

self._time = next_time
yield from self._drain_loop()
await self._drain_loop()

self._time = target_time
yield from self._drain_loop()
await self._drain_loop()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
await self._drain_loop()
await self._drain_loop()


def _next_scheduled(self):
try:
return self.loop._scheduled[0]._when
except IndexError:
return None

@asyncio.coroutine
def _drain_loop(self):
async def _drain_loop(self):
while True:
next_time = self._next_scheduled()
if not self.loop._ready and (next_time is None or
next_time > self._time):
break

yield from asyncio.sleep(0)
await asyncio.sleep(0)
self.loop._TestCase_asynctest_ran = True


Expand Down
5 changes: 2 additions & 3 deletions asynctest/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import asyncio


@asyncio.coroutine
def exhaust_callbacks(loop):
async def exhaust_callbacks(loop):
"""
Run the loop until all ready callbacks are executed.

Expand All @@ -21,4 +20,4 @@ def exhaust_callbacks(loop):
:param loop: event loop
"""
while loop._ready:
yield from asyncio.sleep(0, loop=loop)
await asyncio.sleep(0)
22 changes: 9 additions & 13 deletions asynctest/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,8 +430,7 @@ def __init__(self, mock):
self._mock = mock
self._condition = None

@asyncio.coroutine
def wait(self, skip=0):
async def wait(self, skip=0):
"""
Wait for await.

Expand All @@ -442,10 +441,9 @@ def wait(self, skip=0):
def predicate(mock):
return mock.await_count > skip

return (yield from self.wait_for(predicate))
return await self.wait_for(predicate)

@asyncio.coroutine
def wait_next(self, skip=0):
async def wait_next(self, skip=0):
"""
Wait for the next await.

Expand All @@ -462,10 +460,9 @@ def wait_next(self, skip=0):
def predicate(mock):
return mock.await_count > await_count + skip

return (yield from self.wait_for(predicate))
return await self.wait_for(predicate)

@asyncio.coroutine
def wait_for(self, predicate):
async def wait_for(self, predicate):
"""
Wait for a given predicate to become True.

Expand All @@ -476,21 +473,20 @@ def wait_for(self, predicate):
condition = self._get_condition()

try:
yield from condition.acquire()
await condition.acquire()

def _predicate():
return predicate(self._mock)

return (yield from condition.wait_for(_predicate))
return await condition.wait_for(_predicate)
finally:
condition.release()

@asyncio.coroutine
def _notify(self):
async def _notify(self):
condition = self._get_condition()

try:
yield from condition.acquire()
await condition.acquire()
condition.notify_all()
finally:
condition.release()
Expand Down