Skip to content

Commit

Permalink
asyncio: Refactor tests: add a base TestCase class
Browse files Browse the repository at this point in the history
  • Loading branch information
vstinner committed Jun 17, 2014
1 parent d6f02fc commit c73701d
Show file tree
Hide file tree
Showing 13 changed files with 145 additions and 219 deletions.
18 changes: 18 additions & 0 deletions Lib/asyncio/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import tempfile
import threading
import time
import unittest
from unittest import mock

from http.server import HTTPServer
Expand Down Expand Up @@ -379,3 +380,20 @@ def get_function_source(func):
if source is None:
raise ValueError("unable to get the source of %r" % (func,))
return source


class TestCase(unittest.TestCase):
def set_event_loop(self, loop, *, cleanup=True):
assert loop is not None
# ensure that the event loop is passed explicitly in asyncio
events.set_event_loop(None)
if cleanup:
self.addCleanup(loop.close)

def new_test_loop(self, gen=None):
loop = TestLoop(gen)
self.set_event_loop(loop)
return loop

def tearDown(self):
events.set_event_loop(None)
11 changes: 4 additions & 7 deletions Lib/test/test_asyncio/test_base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
PY34 = sys.version_info >= (3, 4)


class BaseEventLoopTests(unittest.TestCase):
class BaseEventLoopTests(test_utils.TestCase):

def setUp(self):
self.loop = base_events.BaseEventLoop()
self.loop._selector = mock.Mock()
asyncio.set_event_loop(None)
self.set_event_loop(self.loop)

def test_not_implemented(self):
m = mock.Mock()
Expand Down Expand Up @@ -548,14 +548,11 @@ def connection_lost(self, exc):
self.done.set_result(None)


class BaseEventLoopWithSelectorTests(unittest.TestCase):
class BaseEventLoopWithSelectorTests(test_utils.TestCase):

def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()
self.set_event_loop(self.loop)

@mock.patch('asyncio.base_events.socket')
def test_create_connection_multiple_errors(self, m_socket):
Expand Down
14 changes: 7 additions & 7 deletions Lib/test/test_asyncio/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ class EventLoopTestsMixin:
def setUp(self):
super().setUp()
self.loop = self.create_event_loop()
asyncio.set_event_loop(None)
self.set_event_loop(self.loop)

def tearDown(self):
# just in case if we have transport close callbacks
Expand Down Expand Up @@ -1629,14 +1629,14 @@ def connect(cmd=None, **kwds):

if sys.platform == 'win32':

class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase):
class SelectEventLoopTests(EventLoopTestsMixin, test_utils.TestCase):

def create_event_loop(self):
return asyncio.SelectorEventLoop()

class ProactorEventLoopTests(EventLoopTestsMixin,
SubprocessTestsMixin,
unittest.TestCase):
test_utils.TestCase):

def create_event_loop(self):
return asyncio.ProactorEventLoop()
Expand Down Expand Up @@ -1691,7 +1691,7 @@ def tearDown(self):
if hasattr(selectors, 'KqueueSelector'):
class KqueueEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin,
unittest.TestCase):
test_utils.TestCase):

def create_event_loop(self):
return asyncio.SelectorEventLoop(
Expand All @@ -1716,23 +1716,23 @@ def test_write_pty(self):
if hasattr(selectors, 'EpollSelector'):
class EPollEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin,
unittest.TestCase):
test_utils.TestCase):

def create_event_loop(self):
return asyncio.SelectorEventLoop(selectors.EpollSelector())

if hasattr(selectors, 'PollSelector'):
class PollEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin,
unittest.TestCase):
test_utils.TestCase):

def create_event_loop(self):
return asyncio.SelectorEventLoop(selectors.PollSelector())

# Should always exist.
class SelectEventLoopTests(UnixEventLoopTestsMixin,
SubprocessTestsMixin,
unittest.TestCase):
test_utils.TestCase):

def create_event_loop(self):
return asyncio.SelectorEventLoop(selectors.SelectSelector())
Expand Down
25 changes: 7 additions & 18 deletions Lib/test/test_asyncio/test_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,10 @@ def _fakefunc(f):
return f


class FutureTests(unittest.TestCase):
class FutureTests(test_utils.TestCase):

def setUp(self):
self.loop = test_utils.TestLoop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()
self.loop = self.new_test_loop()

def test_initial_state(self):
f = asyncio.Future(loop=self.loop)
Expand All @@ -30,12 +26,9 @@ def test_initial_state(self):
self.assertTrue(f.cancelled())

def test_init_constructor_default_loop(self):
try:
asyncio.set_event_loop(self.loop)
f = asyncio.Future()
self.assertIs(f._loop, self.loop)
finally:
asyncio.set_event_loop(None)
asyncio.set_event_loop(self.loop)
f = asyncio.Future()
self.assertIs(f._loop, self.loop)

def test_constructor_positional(self):
# Make sure Future doesn't accept a positional argument
Expand Down Expand Up @@ -264,14 +257,10 @@ def test_wrap_future_cancel2(self):
self.assertTrue(f2.cancelled())


class FutureDoneCallbackTests(unittest.TestCase):
class FutureDoneCallbackTests(test_utils.TestCase):

def setUp(self):
self.loop = test_utils.TestLoop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()
self.loop = self.new_test_loop()

def run_briefly(self):
test_utils.run_briefly(self.loop)
Expand Down
68 changes: 20 additions & 48 deletions Lib/test/test_asyncio/test_locks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,10 @@
RGX_REPR = re.compile(STR_RGX_REPR)


class LockTests(unittest.TestCase):
class LockTests(test_utils.TestCase):

def setUp(self):
self.loop = test_utils.TestLoop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()
self.loop = self.new_test_loop()

def test_ctor_loop(self):
loop = mock.Mock()
Expand All @@ -35,12 +31,9 @@ def test_ctor_loop(self):
self.assertIs(lock._loop, self.loop)

def test_ctor_noloop(self):
try:
asyncio.set_event_loop(self.loop)
lock = asyncio.Lock()
self.assertIs(lock._loop, self.loop)
finally:
asyncio.set_event_loop(None)
asyncio.set_event_loop(self.loop)
lock = asyncio.Lock()
self.assertIs(lock._loop, self.loop)

def test_repr(self):
lock = asyncio.Lock(loop=self.loop)
Expand Down Expand Up @@ -240,14 +233,10 @@ def test_context_manager_no_yield(self):
self.assertFalse(lock.locked())


class EventTests(unittest.TestCase):
class EventTests(test_utils.TestCase):

def setUp(self):
self.loop = test_utils.TestLoop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()
self.loop = self.new_test_loop()

def test_ctor_loop(self):
loop = mock.Mock()
Expand All @@ -258,12 +247,9 @@ def test_ctor_loop(self):
self.assertIs(ev._loop, self.loop)

def test_ctor_noloop(self):
try:
asyncio.set_event_loop(self.loop)
ev = asyncio.Event()
self.assertIs(ev._loop, self.loop)
finally:
asyncio.set_event_loop(None)
asyncio.set_event_loop(self.loop)
ev = asyncio.Event()
self.assertIs(ev._loop, self.loop)

def test_repr(self):
ev = asyncio.Event(loop=self.loop)
Expand Down Expand Up @@ -376,14 +362,10 @@ def c1(result):
self.assertTrue(t.result())


class ConditionTests(unittest.TestCase):
class ConditionTests(test_utils.TestCase):

def setUp(self):
self.loop = test_utils.TestLoop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()
self.loop = self.new_test_loop()

def test_ctor_loop(self):
loop = mock.Mock()
Expand All @@ -394,12 +376,9 @@ def test_ctor_loop(self):
self.assertIs(cond._loop, self.loop)

def test_ctor_noloop(self):
try:
asyncio.set_event_loop(self.loop)
cond = asyncio.Condition()
self.assertIs(cond._loop, self.loop)
finally:
asyncio.set_event_loop(None)
asyncio.set_event_loop(self.loop)
cond = asyncio.Condition()
self.assertIs(cond._loop, self.loop)

def test_wait(self):
cond = asyncio.Condition(loop=self.loop)
Expand Down Expand Up @@ -678,14 +657,10 @@ def test_context_manager_no_yield(self):
self.assertFalse(cond.locked())


class SemaphoreTests(unittest.TestCase):
class SemaphoreTests(test_utils.TestCase):

def setUp(self):
self.loop = test_utils.TestLoop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()
self.loop = self.new_test_loop()

def test_ctor_loop(self):
loop = mock.Mock()
Expand All @@ -696,12 +671,9 @@ def test_ctor_loop(self):
self.assertIs(sem._loop, self.loop)

def test_ctor_noloop(self):
try:
asyncio.set_event_loop(self.loop)
sem = asyncio.Semaphore()
self.assertIs(sem._loop, self.loop)
finally:
asyncio.set_event_loop(None)
asyncio.set_event_loop(self.loop)
sem = asyncio.Semaphore()
self.assertIs(sem._loop, self.loop)

def test_initial_value_zero(self):
sem = asyncio.Semaphore(0, loop=self.loop)
Expand Down
7 changes: 4 additions & 3 deletions Lib/test/test_asyncio/test_proactor_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from asyncio import test_utils


class ProactorSocketTransportTests(unittest.TestCase):
class ProactorSocketTransportTests(test_utils.TestCase):

def setUp(self):
self.loop = test_utils.TestLoop()
self.loop = self.new_test_loop()
self.proactor = mock.Mock()
self.loop._proactor = self.proactor
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
Expand Down Expand Up @@ -343,7 +343,7 @@ def test_pause_resume_reading(self):
tr.close()


class BaseProactorEventLoopTests(unittest.TestCase):
class BaseProactorEventLoopTests(test_utils.TestCase):

def setUp(self):
self.sock = mock.Mock(socket.socket)
Expand All @@ -356,6 +356,7 @@ def _socketpair(s):
return (self.ssock, self.csock)

self.loop = EventLoop(self.proactor)
self.set_event_loop(self.loop, cleanup=False)

@mock.patch.object(BaseProactorEventLoop, 'call_soon')
@mock.patch.object(BaseProactorEventLoop, '_socketpair')
Expand Down
Loading

0 comments on commit c73701d

Please sign in to comment.