Skip to content

Commit

Permalink
Better control over awaited
Browse files Browse the repository at this point in the history
Allow to skip a given number of awaits
and wait for the next await.

Refs #64
  • Loading branch information
Kentzo committed Jan 9, 2018
1 parent 60aa5a9 commit b785776
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 19 deletions.
57 changes: 47 additions & 10 deletions asynctest/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,12 +383,49 @@ class MagicMock(AsyncMagicMixin, unittest.mock.MagicMock,
"""


class _AwaitEvent(asyncio.Event):
"""
A mix between asyncio.Event and bool.
"""
class _AwaitEvent:
def __init__(self, mock):
self._mock = mock
self._condition = asyncio.Condition()

@asyncio.coroutine
def wait(self, skip=0):
def predicate(mock):
return mock.await_count > skip

return (yield from self.wait_for(predicate))

@asyncio.coroutine
def wait_next(self, skip=0):
await_count = self._mock.await_count

def predicate(mock):
return mock.await_count > await_count + skip

return (yield from self.wait_for(predicate))

@asyncio.coroutine
def wait_for(self, predicate):
try:
yield from self._condition.acquire()

def _predicate():
return predicate(self._mock)

return (yield from self._condition.wait_for(_predicate))
finally:
self._condition.release()

@asyncio.coroutine
def notify(self):
try:
yield from self._condition.acquire()
self._condition.notify_all()
finally:
self._condition.release()

def __bool__(self):
return self.is_set()
return self._mock.await_count != 0


class CoroutineMock(Mock):
Expand Down Expand Up @@ -449,7 +486,7 @@ def __init__(self, *args, **kwargs):
# It is set through __dict__ because when spec_set is True, this
# attribute is likely undefined.
self.__dict__['_is_coroutine'] = _is_coroutine
self.__dict__['_mock_awaited'] = _AwaitEvent()
self.__dict__['_mock_awaited'] = _AwaitEvent(self)
self.__dict__['_mock_await_count'] = 0

def _mock_call(_mock_self, *args, **kwargs):
Expand All @@ -462,16 +499,16 @@ def proxy():
try:
return (yield from result)
finally:
yield from _mock_self.awaited.notify()
_mock_self.await_count += 1
_mock_self.awaited.set()
else:
@asyncio.coroutine
def proxy():
try:
return result
finally:
yield from _mock_self.awaited.notify()
_mock_self.await_count += 1
_mock_self.awaited.set()

return proxy()
except StopIteration as e:
Expand Down Expand Up @@ -512,7 +549,7 @@ def reset_mock(self, *args, **kwargs):
See :func:`unittest.mock.Mock.reset_mock()`
"""
super().reset_mock(*args, **kwargs)
self.awaited = _AwaitEvent()
self.awaited = _AwaitEvent(self)
self.await_count = 0


Expand Down Expand Up @@ -581,7 +618,7 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
# _set_signature returns the result of the CoroutineMock itself,
# which is a Coroutine (as defined in CoroutineMock._mock_call)
mock._is_coroutine = _is_coroutine
mock.awaited = _AwaitEvent()
mock.awaited = _AwaitEvent(mock)
mock.await_count = 0

def assert_awaited(*args, **kwargs):
Expand Down
40 changes: 35 additions & 5 deletions test/test_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import asynctest

from .utils import run_coroutine
from .utils import run_coroutine, replace_loop

if sys.version_info >= (3, 5):
from . import test_mock_await as _using_await
Expand Down Expand Up @@ -308,12 +308,10 @@ def test_awaited_CoroutineMock_sets_awaited(self):
mock = asynctest.mock.CoroutineMock()
run_coroutine(mock())
mock.assert_awaited()
self.assertTrue(mock.awaited.is_set())
self.assertTrue(mock.awaited)

mock.reset_mock()
mock.assert_not_awaited()
self.assertFalse(mock.awaited.is_set())
self.assertFalse(mock.awaited)

@asyncio.coroutine
Expand Down Expand Up @@ -351,16 +349,48 @@ def test_awaited_from_autospec_mock(self):
mock = asynctest.mock.create_autospec(Test)
mock.a_coroutine.assert_not_awaited()
self.assertFalse(mock.a_coroutine.awaited)
self.assertFalse(mock.a_coroutine.awaited.is_set())
self.assertEqual(0, mock.a_coroutine.await_count)

run_coroutine(mock.a_coroutine())

mock.a_coroutine.assert_awaited()
self.assertTrue(mock.a_coroutine.awaited)
self.assertTrue(mock.a_coroutine.awaited.is_set())
self.assertEqual(1, mock.a_coroutine.await_count)

def test_awaited_wait(self):
loop = asyncio.new_event_loop()
with replace_loop(loop):
mock = asynctest.mock.CoroutineMock()
t = asyncio.ensure_future(mock.awaited.wait())
run_coroutine(mock(), loop)
run_coroutine(t, loop)

mock.reset_mock()
t = asyncio.ensure_future(mock.awaited.wait(skip=1))
run_coroutine(mock(), loop)
self.assertFalse(t.done())
run_coroutine(mock(), loop)
run_coroutine(t, loop)

def test_awaited_wait_next(self):
loop = asyncio.new_event_loop()
with replace_loop(loop):
mock = asynctest.mock.CoroutineMock()
run_coroutine(mock(), loop)
t = asyncio.ensure_future(mock.awaited.wait_next())
run_coroutine(asyncio.sleep(0.01), loop)
self.assertFalse(t.done())
run_coroutine(mock(), loop)
run_coroutine(t, loop)

mock.reset_mock()
run_coroutine(mock(), loop)
t = asyncio.ensure_future(mock.awaited.wait_next(skip=1))
run_coroutine(mock(), loop)
self.assertFalse(t.done())
run_coroutine(mock(), loop)
run_coroutine(t, loop)


class TestMockInheritanceModel(unittest.TestCase):
to_test = {
Expand Down
28 changes: 24 additions & 4 deletions test/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,31 @@
# coding: utf-8
import asyncio
import contextlib


def run_coroutine(coroutine):
loop = asyncio.new_event_loop()
def run_coroutine(coroutine, loop=None):
if loop is None:
loop = asyncio.new_event_loop()
close = True
else:
close = False

with replace_loop(loop, close):
return loop.run_until_complete(coroutine)


@contextlib.contextmanager
def replace_loop(loop, close=True):
try:
current_loop = asyncio.get_event_loop()
except:
current_loop = None

asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(coroutine)
yield
finally:
loop.close()
if close:
loop.close()

asyncio.set_event_loop(current_loop)

0 comments on commit b785776

Please sign in to comment.