Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the awaited attribute to CoroutineMock. #67

Merged
merged 11 commits into from
Jan 30, 2018
187 changes: 177 additions & 10 deletions asynctest/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ def __aiter__():
_is_coroutine = True


try:
# Python 3.5+
_isawaitable = inspect.isawaitable
except AttributeError:
_isawaitable = asyncio.iscoroutine


def _raise(exception):
raise exception

Expand Down Expand Up @@ -141,7 +148,7 @@ def _get_child_mock(self, *args, **kwargs):

_type = type(self)

if (issubclass(_type, MagicMock) and _new_name in async_magic_coroutines):
if issubclass(_type, MagicMock) and _new_name in async_magic_coroutines:
klass = CoroutineMock
elif issubclass(_type, CoroutineMock):
klass = MagicMock
Expand Down Expand Up @@ -180,7 +187,7 @@ def __new__(meta, name, base, namespace):
'_asynctest_get_is_coroutine': _get_is_coroutine,
'_asynctest_set_is_coroutine': _set_is_coroutine,
'is_coroutine': property(_get_is_coroutine, _set_is_coroutine,
"True if the object mocked is a coroutine"),
doc="True if the object mocked is a coroutine"),
'_is_coroutine': property(_get_is_coroutine),
})

Expand Down Expand Up @@ -218,8 +225,6 @@ def _mock_set_async_magics(self):

if getattr(self, "_mock_methods", None) is not None:
these_magics = _async_magics.intersection(self._mock_methods)

remove_magics = set()
remove_magics = _async_magics - these_magics

for entry in remove_magics:
Expand Down Expand Up @@ -376,6 +381,95 @@ class MagicMock(AsyncMagicMixin, unittest.mock.MagicMock,
"""


class _AwaitEvent:
def __init__(self, mock):
self._mock = mock
self._condition = None

@asyncio.coroutine
def wait(self, skip=0):
"""
Wait for await.

:param skip: How many awaits will be skipped.
As a result, the mock should be awaited at least
``skip + 1`` times.
"""
def predicate(mock):
return mock.await_count > skip

return (yield from self.wait_for(predicate))

@asyncio.coroutine
def wait_next(self, skip=0):
"""
Wait for the next await.

Unlike :meth:`wait` that counts any await, mock has to be awaited once more,
disregarding to the current :attr:`asynctest.CoroutineMock.await_count`.

:param skip: How many awaits will be skipped.
As a result, the mock should be awaited at least
``skip + 1`` more times.
"""
await_count = self._mock.await_count

def predicate(mock):
return mock.await_count > await_count + skip
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as for wait(), to be consistent I'd call the argument min_next_wait_count, set 1 as default value, and use >= in the predicate.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find the skip name more to the point: wait for await skipping # awaits.

Please take a look at the rephrased docstrings. If they are still hard to comprehend, I'll change the name.


return (yield from self.wait_for(predicate))

@asyncio.coroutine
def wait_for(self, predicate):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_AwaitEvent.wait_for() and AwaitEvent.notify() should be private (prefixed with "")

Copy link
Contributor Author

@Kentzo Kentzo Jan 9, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think github's parser ate the prefix, what did you mean _?

I wanted to keep wait_for public for a user to introduce more complex conditions (e.g. mock is awaited after being called with certain arguments).

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interresting use case. I didn't think of something like this.

I think that we can add this example in the documentation, and tell that the predicate is only checked when the mock is awaited. I don't want users to have the feeling that it can magically awake after a random event.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, I was thinking about this:

def predicate(mock):
    return mock.await_count > 0 and m.call_args == (("foo", ), {})

t1 = asyncio.ensure_future(mock("foo"))
t2 = mock("bar")  # not yet scheduled
mock.awaited.wait_for(predicate)  # will hang forever: mock.call_args is now based on the last call.

The problem is that the predicate is based on the mock (ie coroutine function), not the coroutine that will be awaited itself.

I'm affraid that this will be confusing as it's not possible to wait_for() a given instance of the coroutine function to be awaited. Maybe we should keep this private and let users write their own logic when they want to do complex things (one can use the wraps argument to customize th behavior of the mock).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should use call_args_list.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won’t be 100% accurate (you can await in different order) but could be sufficient for some tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In other hand for those tests skip would be sufficient as well. Perhaps we should add await_args / await_args_list?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that the use case for checking which exact call has been awaited makes sense, but I'm not conviced with generic/public wait_for() method yet.

I think that we should merge this PR with a private _wait_for() as it's already very useful like this. Then we can work on a new PR to implement wait_call() and wait_next_call() or add optional parameters matching in wait() and wait_next().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I was working on that, ability to await a specific call was only one of use cases I thought of. Another was an ability to await an external consition like setting mock’s attribute to a specific value. Or more general: any modification of mock’s state related to await. I also thought of ability to await external conditions, but couldn’t think of a practical example.

I would opt for having wait_for public as it seems in line with the naming of async primitives. People who would confuse it would probably have problems with wait and wait_next as well. This an advanced method, but it is possible to mark it as such with docs.

What do you say if we keep it public via naming, but exclude from the documentation for now?

Copy link
Contributor Author

@Kentzo Kentzo Jan 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check GreatFruitOmsk@3839b15, it introduces an independent list of "await args" and among other addresses the issue in your example for wait_for.

"""
Wait for a given predicate to become True.

:param predicate: A callable that receives mock which result
will be interpreted as a boolean value.
The final predicate value is the return value.
"""
c = self._get_condition()

try:
yield from c.acquire()

def _predicate():
return predicate(self._mock)

return (yield from c.wait_for(_predicate))
finally:
c.release()

@asyncio.coroutine
def _notify(self):
c = self._get_condition()

try:
yield from c.acquire()
c.notify_all()
finally:
c.release()

def _get_condition(self):
"""
Creation of condition is delayed, to minimize the change of using the wrong loop.

A user may create a mock with _AwaitEvent before selecting the execution loop.
Requiring a user to delay creation is error-prone and inflexible. Instead, condition
is created when user actually starts to use the mock.
"""
# No synchronization is needed:
# - asyncio is thread unsafe
# - there are no awaits here, method will be executed without switching asyncio context.
if self._condition is None:
self._condition = asyncio.Condition()

return self._condition

def __bool__(self):
return self._mock.await_count != 0


class CoroutineMock(Mock):
"""
Enhance :class:`~asynctest.mock.Mock` with features allowing to mock
Expand Down Expand Up @@ -414,6 +508,17 @@ class CoroutineMock(Mock):
:class:`unittest.mock.Mock` object: the wrapped object may have methods
defined as coroutine functions.
"""
#: Property which is set when the mock is awaited. Its ``wait``,
#: ``wait_next`` and ``wait_for`` coroutine methods can be used
#: to synchronize execution.
#:
#: .. versionadded:: 0.12
awaited = unittest.mock._delegating_property('awaited')
#: Number of times the mock has been awaited (or "yielded from").
#:
#: .. versionadded:: 0.12
await_count = unittest.mock._delegating_property('await_count')

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

Expand All @@ -422,15 +527,31 @@ def __init__(self, *args, **kwargs):
# It is set through __dict__ because when spec_set is True, this
# attribute is likely undefined.
self.__dict__['_is_coroutine'] = _is_coroutine
self.__dict__['_mock_awaited'] = _AwaitEvent(self)
self.__dict__['_mock_await_count'] = 0

def _mock_call(_mock_self, *args, **kwargs):
try:
result = super()._mock_call(*args, **kwargs)

if asyncio.iscoroutine(result):
return result
if _isawaitable(result):
@asyncio.coroutine
def proxy():
try:
return (yield from result)
finally:
_mock_self.await_count += 1
yield from _mock_self.awaited._notify()
else:
return asyncio.coroutine(lambda *a, **kw: result)()
@asyncio.coroutine
def proxy():
try:
return result
finally:
_mock_self.await_count += 1
yield from _mock_self.awaited._notify()

return proxy()
except StopIteration as e:
side_effect = _mock_self.side_effect
if side_effect is not None and not callable(side_effect):
Expand All @@ -440,6 +561,38 @@ def _mock_call(_mock_self, *args, **kwargs):
except BaseException as e:
return asyncio.coroutine(_raise)(e)

def assert_awaited(_mock_self):
"""
Assert that the mock was awaited at least once.

.. versionadded:: 0.12
"""
self = _mock_self
if self.await_count == 0:
msg = ("Expected '%s' to have been awaited." %
self._mock_name or 'mock')
raise AssertionError(msg)

def assert_not_awaited(_mock_self):
"""
Assert that the mock was never awaited.

.. versionadded:: 0.12
"""
self = _mock_self
if self.await_count != 0:
msg = ("Expected '%s' to not have been awaited. Awaited %s times." %
(self._mock_name or 'mock', self.await_count))
raise AssertionError(msg)

def reset_mock(self, *args, **kwargs):
"""
See :func:`unittest.mock.Mock.reset_mock()`
"""
super().reset_mock(*args, **kwargs)
self.awaited = _AwaitEvent(self)
self.await_count = 0


def create_autospec(spec, spec_set=False, instance=False, _parent=None,
_name=None, **kwargs):
Expand All @@ -453,7 +606,7 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
If ``spec`` is a coroutine function, and ``instance`` is not ``False``, a
:exc:`RuntimeError` is raised.

versionadded:: 0.12
.. versionadded:: 0.12
"""
if unittest.mock._is_list(spec):
spec = type(spec)
Expand Down Expand Up @@ -495,6 +648,9 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
name=_name, **_kwargs)

if isinstance(spec, unittest.mock.FunctionTypes):
wrapped_mock = mock
# _set_signature returns an object wrapping the mock, not the mock
# itself.
mock = unittest.mock._set_signature(mock, spec)
if is_coroutine_func:
# Can't wrap the mock with asyncio.coroutine because it doesn't
Expand All @@ -503,6 +659,17 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
# _set_signature returns the result of the CoroutineMock itself,
# which is a Coroutine (as defined in CoroutineMock._mock_call)
mock._is_coroutine = _is_coroutine
mock.awaited = _AwaitEvent(mock)
mock.await_count = 0

def assert_awaited(*args, **kwargs):
return wrapped_mock.assert_awaited(*args, **kwargs)

def assert_not_awaited(*args, **kwargs):
return wrapped_mock.assert_not_awaited(*args, **kwargs)

mock.assert_awaited = assert_awaited
mock.assert_not_awaited = assert_not_awaited
else:
unittest.mock._check_signature(spec, mock, is_type, instance)

Expand Down Expand Up @@ -718,11 +885,11 @@ def send(self, value):
if patching.scope == LIMITED]
return super().send(value)

def throw(self, exc):
def throw(self, exc, value=None, traceback=None):
with contextlib.ExitStack() as stack:
[stack.enter_context(patching) for patching in self.patchings
if patching.scope == LIMITED]
return self.gen.throw(exc)
return self.gen.throw(exc, value, traceback)

def __del__(self):
# The generator/coroutine is deleted before it terminated, we must
Expand Down
14 changes: 11 additions & 3 deletions asynctest/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,19 @@ def get_registered_events(selector):
return set(watched_events)


def _format_callback(handle):
if hasattr(asyncio.events, "_format_args_and_kwargs"):
if hasattr(asyncio, "format_helpers"):
# Python 3.7+
def _format_callback(handle):
return asyncio.format_helpers._format_callback(handle._callback,
handle._args, None)
elif hasattr(asyncio.events, "_format_args_and_kwargs"):
# Python 3.5, 3.6
def _format_callback(handle):
return asyncio.events._format_callback(handle._callback, handle._args,
None)
else:
else:
# Python 3.4
def _format_callback(handle):
return asyncio.events._format_callback(handle._callback, handle._args)


Expand Down
Loading