diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index d28907ac858..3d46ed544f6 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -1,5 +1,6 @@ import sys import asyncio +import json import warnings from . import hdrs @@ -285,6 +286,15 @@ def receive_bytes(self): msg.data)) return msg.data + @asyncio.coroutine + def receive_json(self, *, loads=json.loads): + msg = yield from self.receive() + if msg.tp != MsgType.text: + raise TypeError( + "Received message {}:{!r} is not str".format(msg.tp, msg.data) + ) + return msg.json(loads=loads) + def write(self, data): raise RuntimeError("Cannot call .write() for websocket") diff --git a/aiohttp/websocket.py b/aiohttp/websocket.py index 2ee0eff2144..f78ba3e0058 100644 --- a/aiohttp/websocket.py +++ b/aiohttp/websocket.py @@ -4,6 +4,7 @@ import binascii import collections import hashlib +import json import os import random import sys @@ -57,7 +58,6 @@ hdrs.SEC_WEBSOCKET_KEY, hdrs.SEC_WEBSOCKET_PROTOCOL) -Message = collections.namedtuple('Message', ['tp', 'data', 'extra']) UNPACK_LEN2 = Struct('!H').unpack_from UNPACK_LEN3 = Struct('!Q').unpack_from @@ -69,6 +69,18 @@ MSG_SIZE = 2 ** 14 +_MessageBase = collections.namedtuple('Message', ['tp', 'data', 'extra']) + + +class Message(_MessageBase): + def json(self, *, loads=json.loads): + """Return parsed JSON data. + + .. versionadded:: 0.22 + """ + return loads(self.data) + + class WebSocketError(Exception): """WebSocket protocol parser error.""" diff --git a/docs/web_reference.rst b/docs/web_reference.rst index 7219e5673e5..601af50638b 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -260,9 +260,9 @@ like one using :meth:`Request.copy`. async def json(self, *, loads=json.loads): body = await self.text() - return loader(body) + return loads(body) - :param callable loader: any :term:`callable` that accepts + :param callable loads: any :term:`callable` that accepts :class:`str` and returns :class:`dict` with parsed JSON (:func:`json.loads` by default). @@ -894,7 +894,7 @@ WebSocketResponse .. coroutinemethod:: receive_str() - A :ref:`coroutine` that calls :meth:`receive_mgs` but + A :ref:`coroutine` that calls :meth:`receive` but also asserts the message type is :const:`~aiohttp.websocket.MSG_TEXT`. @@ -904,7 +904,7 @@ WebSocketResponse .. coroutinemethod:: receive_bytes() - A :ref:`coroutine` that calls :meth:`receive_mgs` but + A :ref:`coroutine` that calls :meth:`receive` but also asserts the message type is :const:`~aiohttp.websocket.MSG_BINARY`. @@ -912,6 +912,24 @@ WebSocketResponse :raise TypeError: if message is :const:`~aiohttp.websocket.MSG_TEXT`. + .. coroutinemethod:: receive_json(*, loads=json.loads) + + A :ref:`coroutine` that calls :meth:`receive`, asserts the + message type is :const:`~aiohttp.websocket.MSG_TEXT`, and loads the JSON + string to a Python dict. + + :param callable loads: any :term:`callable` that accepts + :class:`str` and returns :class:`dict` + with parsed JSON (:func:`json.loads` by + default). + + :return dict: loaded JSON content + + :raise TypeError: if message is :const:`~aiohttp.websocket.MSG_BINARY`. + :raise ValueError: if message is not valid JSON. + + .. versionadded:: 0.22 + .. versionadded:: 0.14 diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index cf3dd2bcf8d..e7c5a5f6d55 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -1,439 +1,87 @@ -import asyncio -import base64 -import hashlib -import os -import socket -import unittest - -import aiohttp -from aiohttp import helpers, web, websocket - - -WS_KEY = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - - -class TestWebWebSocketFunctional(unittest.TestCase): +"""HTTP websocket server functional tests""" - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() - - def find_unused_port(self): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(('127.0.0.1', 0)) - port = s.getsockname()[1] - s.close() - return port - - @asyncio.coroutine - def create_server(self, method, path, handler): - app = web.Application(loop=self.loop) - app.router.add_route(method, path, handler) +import asyncio +import pytest +from aiohttp import web - port = self.find_unused_port() - srv = yield from self.loop.create_server( - app.make_handler(), '127.0.0.1', port) - url = "http://127.0.0.1:{}".format(port) + path - self.addCleanup(srv.close) - return app, srv, url +@pytest.mark.run_loop +def test_websocket_json(create_app_and_client): @asyncio.coroutine - def connect_ws(self, url, protocol=None): - sec_key = base64.b64encode(os.urandom(16)) - - conn = aiohttp.TCPConnector(loop=self.loop) - self.addCleanup(conn.close) - - headers = { - 'UPGRADE': 'WebSocket', - 'CONNECTION': 'Upgrade', - 'SEC-WEBSOCKET-VERSION': '13', - 'SEC-WEBSOCKET-KEY': sec_key.decode(), - } - - if protocol: - headers['SEC-WEBSOCKET-PROTOCOL'] = protocol - - # send request - response = yield from aiohttp.request( - 'get', url, - headers=headers, - connector=conn, - loop=self.loop) - self.addCleanup(response.close, True) - - self.assertEqual(101, response.status) - self.assertEqual(response.headers.get('upgrade', '').lower(), - 'websocket') - self.assertEqual(response.headers.get('connection', '').lower(), - 'upgrade') - - key = response.headers.get('sec-websocket-accept', '').encode() - match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()) - self.assertEqual(key, match) - - # switch to websocket protocol - connection = response.connection - reader = connection.reader.set_parser(websocket.WebSocketParser) - writer = websocket.WebSocketWriter(connection.writer) - - return response, reader, writer - - def test_send_recv_text(self): - - closed = helpers.create_future(self.loop) - - @asyncio.coroutine - def handler(request): - ws = web.WebSocketResponse() - yield from ws.prepare(request) - msg = yield from ws.receive_str() - ws.send_str(msg+'/answer') - yield from ws.close() - closed.set_result(1) - return ws - - @asyncio.coroutine - def go(): - _, _, url = yield from self.create_server('GET', '/', handler) - resp, reader, writer = yield from self.connect_ws(url) - writer.send('ask') - msg = yield from reader.read() - self.assertEqual(msg.tp, websocket.MSG_TEXT) - self.assertEqual('ask/answer', msg.data) - - msg = yield from reader.read() - self.assertEqual(msg.tp, websocket.MSG_CLOSE) - self.assertEqual(msg.data, 1000) - self.assertEqual(msg.extra, '') - - writer.close() - - yield from closed - resp.close() - - self.loop.run_until_complete(go()) - - def test_send_recv_bytes(self): - - closed = helpers.create_future(self.loop) - - @asyncio.coroutine - def handler(request): - ws = web.WebSocketResponse() - yield from ws.prepare(request) - - msg = yield from ws.receive_bytes() - ws.send_bytes(msg+b'/answer') - yield from ws.close() - closed.set_result(1) - return ws - - @asyncio.coroutine - def go(): - _, _, url = yield from self.create_server('GET', '/', handler) - resp, reader, writer = yield from self.connect_ws(url) - writer.send(b'ask', binary=True) - msg = yield from reader.read() - self.assertEqual(msg.tp, websocket.MSG_BINARY) - self.assertEqual(b'ask/answer', msg.data) - - msg = yield from reader.read() - self.assertEqual(msg.tp, websocket.MSG_CLOSE) - self.assertEqual(msg.data, 1000) - self.assertEqual(msg.extra, '') - - writer.close() - yield from closed - resp.close() - - self.loop.run_until_complete(go()) - - def test_auto_pong_with_closing_by_peer(self): - - closed = helpers.create_future(self.loop) - - @asyncio.coroutine - def handler(request): - ws = web.WebSocketResponse() - yield from ws.prepare(request) - yield from ws.receive() - - msg = yield from ws.receive() - self.assertEqual(msg.tp, web.MsgType.close) - self.assertEqual(msg.data, 1000) - self.assertEqual(msg.extra, 'exit message') - closed.set_result(None) - return ws - - @asyncio.coroutine - def go(): - _, _, url = yield from self.create_server('GET', '/', handler) - resp, reader, writer = yield from self.connect_ws(url) - writer.ping() - writer.send('ask') - - msg = yield from reader.read() - self.assertEqual(msg.tp, websocket.MSG_PONG) - writer.close(1000, 'exit message') - yield from closed - resp.close() - - self.loop.run_until_complete(go()) - - def test_ping(self): - - closed = helpers.create_future(self.loop) - - @asyncio.coroutine - def handler(request): - ws = web.WebSocketResponse() - yield from ws.prepare(request) + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + msg = yield from ws.receive() - ws.ping('data') - yield from ws.receive() - closed.set_result(None) - return ws + msg_json = msg.json() + answer = msg_json['test'] + ws.send_str(answer) - @asyncio.coroutine - def go(): - _, _, url = yield from self.create_server('GET', '/', handler) - resp, reader, writer = yield from self.connect_ws(url) - msg = yield from reader.read() - self.assertEqual(msg.tp, websocket.MSG_PING) - self.assertEqual(msg.data, b'data') - writer.pong() - writer.close(2, 'exit message') - yield from closed - resp.close() + yield from ws.close() + return ws - self.loop.run_until_complete(go()) + app, client = yield from create_app_and_client() + app.router.add_route('GET', '/', handler) - def test_client_ping(self): + ws = yield from client.ws_connect('/') + expected_value = 'value' + payload = '{"test": "%s"}' % expected_value + ws.send_str(payload) - closed = helpers.create_future(self.loop) + resp = yield from ws.receive() + assert resp.data == expected_value - @asyncio.coroutine - def handler(request): - ws = web.WebSocketResponse() - yield from ws.prepare(request) - yield from ws.receive() - closed.set_result(None) - return ws - - @asyncio.coroutine - def go(): - _, _, url = yield from self.create_server('GET', '/', handler) - resp, reader, writer = yield from self.connect_ws(url) - writer.ping('data') - msg = yield from reader.read() - self.assertEqual(msg.tp, websocket.MSG_PONG) - self.assertEqual(msg.data, b'data') - writer.pong() - writer.close() - yield from closed - resp.close() - - self.loop.run_until_complete(go()) - - def test_pong(self): - - closed = helpers.create_future(self.loop) - - @asyncio.coroutine - def handler(request): - ws = web.WebSocketResponse(autoping=False) - yield from ws.prepare(request) - - msg = yield from ws.receive() - self.assertEqual(msg.tp, web.MsgType.ping) - ws.pong('data') - - msg = yield from ws.receive() - self.assertEqual(msg.tp, web.MsgType.close) - self.assertEqual(msg.data, 1000) - self.assertEqual(msg.extra, 'exit message') - closed.set_result(None) - return ws - - @asyncio.coroutine - def go(): - _, _, url = yield from self.create_server('GET', '/', handler) - resp, reader, writer = yield from self.connect_ws(url) - writer.ping('data') - msg = yield from reader.read() - self.assertEqual(msg.tp, websocket.MSG_PONG) - self.assertEqual(msg.data, b'data') - writer.close(1000, 'exit message') - - yield from closed - resp.close() - - self.loop.run_until_complete(go()) - - def test_change_status(self): - - closed = helpers.create_future(self.loop) - - @asyncio.coroutine - def handler(request): - ws = web.WebSocketResponse() - ws.set_status(200) - self.assertEqual(200, ws.status) - yield from ws.prepare(request) - self.assertEqual(101, ws.status) - yield from ws.close() - closed.set_result(None) - return ws - - @asyncio.coroutine - def go(): - _, _, url = yield from self.create_server('GET', '/', handler) - resp, _, writer = yield from self.connect_ws(url) - writer.close() - yield from closed - resp.close() - - self.loop.run_until_complete(go()) - - def test_handle_protocol(self): - - closed = helpers.create_future(self.loop) - - @asyncio.coroutine - def handler(request): - ws = web.WebSocketResponse(protocols=('foo', 'bar')) - yield from ws.prepare(request) - yield from ws.close() - self.assertEqual('bar', ws.protocol) - closed.set_result(None) - return ws - - @asyncio.coroutine - def go(): - _, _, url = yield from self.create_server('GET', '/', handler) - resp, _, writer = yield from self.connect_ws(url, 'eggs, bar') - writer.close() - - yield from closed - resp.close() - - self.loop.run_until_complete(go()) - - def test_server_close_handshake(self): - - closed = helpers.create_future(self.loop) - - @asyncio.coroutine - def handler(request): - ws = web.WebSocketResponse(protocols=('foo', 'bar')) - yield from ws.prepare(request) - yield from ws.close() - closed.set_result(None) - return ws - - @asyncio.coroutine - def go(): - _, _, url = yield from self.create_server('GET', '/', handler) - resp, reader, writer = yield from self.connect_ws(url, 'eggs, bar') - - msg = yield from reader.read() - self.assertEqual(msg.tp, websocket.MSG_CLOSE) - writer.close() - yield from closed - resp.close() - - self.loop.run_until_complete(go()) - - def test_client_close_handshake(self): - - closed = helpers.create_future(self.loop) - - @asyncio.coroutine - def handler(request): - ws = web.WebSocketResponse( - autoclose=False, protocols=('foo', 'bar')) - yield from ws.prepare(request) - - msg = yield from ws.receive() - self.assertEqual(msg.tp, web.MsgType.close) - self.assertFalse(ws.closed) - yield from ws.close() - self.assertTrue(ws.closed) - self.assertEqual(ws.close_code, 1007) - - msg = yield from ws.receive() - self.assertEqual(msg.tp, web.MsgType.closed) - - closed.set_result(None) - return ws - - @asyncio.coroutine - def go(): - _, _, url = yield from self.create_server('GET', '/', handler) - resp, reader, writer = yield from self.connect_ws(url, 'eggs, bar') - - writer.close(code=1007) - msg = yield from reader.read() - self.assertEqual(msg.tp, websocket.MSG_CLOSE) - yield from closed - resp.close() - - self.loop.run_until_complete(go()) - - def test_server_close_handshake_server_eats_client_messages(self): - - closed = helpers.create_future(self.loop) - - @asyncio.coroutine - def handler(request): - ws = web.WebSocketResponse(protocols=('foo', 'bar')) - yield from ws.prepare(request) +@pytest.mark.run_loop +def test_websocket_json_invalid_message(create_app_and_client): + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + msg = yield from ws.receive() + + try: + msg.json() + except ValueError: + ws.send_str("ValueError raised: '%s'" % msg.data) + else: + raise Exception("No ValueError was raised") + finally: yield from ws.close() - closed.set_result(None) - return ws + return ws - @asyncio.coroutine - def go(): - _, _, url = yield from self.create_server('GET', '/', handler) - response, reader, writer = yield from self.connect_ws( - url, 'eggs, bar') + app, client = yield from create_app_and_client() + app.router.add_route('GET', '/', handler) - msg = yield from reader.read() - self.assertEqual(msg.tp, websocket.MSG_CLOSE) + ws = yield from client.ws_connect('/') + payload = 'NOT A VALID JSON STRING' + ws.send_str(payload) - writer.send('text') - writer.send(b'bytes', binary=True) - writer.ping() + resp = yield from ws.receive() + assert payload in resp.data - writer.close() - yield from closed - response.close() +@pytest.mark.run_loop +def test_websocket_receive_json(create_app_and_client): + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) - self.loop.run_until_complete(go()) + data = yield from ws.receive_json() + answer = data['test'] + ws.send_str(answer) - def test_receive_msg(self): - @asyncio.coroutine - def handler(request): - ws = web.WebSocketResponse() - yield from ws.prepare(request) + yield from ws.close() + return ws - with self.assertWarns(DeprecationWarning): - msg = yield from ws.receive_msg() - self.assertEqual(msg.data, b'data') - yield from ws.close() - return ws + app, client = yield from create_app_and_client() + app.router.add_route('GET', '/', handler) - @asyncio.coroutine - def go(): - _, _, url = yield from self.create_server('GET', '/', handler) - resp = yield from aiohttp.ws_connect(url, loop=self.loop) - resp.send_bytes(b'data') - yield from resp.close() + ws = yield from client.ws_connect('/') + expected_value = 'value' + payload = '{"test": "%s"}' % expected_value + ws.send_str(payload) - self.loop.run_until_complete(go()) + resp = yield from ws.receive() + assert resp.data == expected_value diff --git a/tests/test_web_websocket_functional_oldstyle.py b/tests/test_web_websocket_functional_oldstyle.py new file mode 100644 index 00000000000..cf3dd2bcf8d --- /dev/null +++ b/tests/test_web_websocket_functional_oldstyle.py @@ -0,0 +1,439 @@ +import asyncio +import base64 +import hashlib +import os +import socket +import unittest + +import aiohttp +from aiohttp import helpers, web, websocket + + +WS_KEY = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + +class TestWebWebSocketFunctional(unittest.TestCase): + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def find_unused_port(self): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(('127.0.0.1', 0)) + port = s.getsockname()[1] + s.close() + return port + + @asyncio.coroutine + def create_server(self, method, path, handler): + app = web.Application(loop=self.loop) + app.router.add_route(method, path, handler) + + port = self.find_unused_port() + srv = yield from self.loop.create_server( + app.make_handler(), '127.0.0.1', port) + url = "http://127.0.0.1:{}".format(port) + path + self.addCleanup(srv.close) + return app, srv, url + + @asyncio.coroutine + def connect_ws(self, url, protocol=None): + sec_key = base64.b64encode(os.urandom(16)) + + conn = aiohttp.TCPConnector(loop=self.loop) + self.addCleanup(conn.close) + + headers = { + 'UPGRADE': 'WebSocket', + 'CONNECTION': 'Upgrade', + 'SEC-WEBSOCKET-VERSION': '13', + 'SEC-WEBSOCKET-KEY': sec_key.decode(), + } + + if protocol: + headers['SEC-WEBSOCKET-PROTOCOL'] = protocol + + # send request + response = yield from aiohttp.request( + 'get', url, + headers=headers, + connector=conn, + loop=self.loop) + self.addCleanup(response.close, True) + + self.assertEqual(101, response.status) + self.assertEqual(response.headers.get('upgrade', '').lower(), + 'websocket') + self.assertEqual(response.headers.get('connection', '').lower(), + 'upgrade') + + key = response.headers.get('sec-websocket-accept', '').encode() + match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()) + self.assertEqual(key, match) + + # switch to websocket protocol + connection = response.connection + reader = connection.reader.set_parser(websocket.WebSocketParser) + writer = websocket.WebSocketWriter(connection.writer) + + return response, reader, writer + + def test_send_recv_text(self): + + closed = helpers.create_future(self.loop) + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + msg = yield from ws.receive_str() + ws.send_str(msg+'/answer') + yield from ws.close() + closed.set_result(1) + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + resp, reader, writer = yield from self.connect_ws(url) + writer.send('ask') + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_TEXT) + self.assertEqual('ask/answer', msg.data) + + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_CLOSE) + self.assertEqual(msg.data, 1000) + self.assertEqual(msg.extra, '') + + writer.close() + + yield from closed + resp.close() + + self.loop.run_until_complete(go()) + + def test_send_recv_bytes(self): + + closed = helpers.create_future(self.loop) + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + + msg = yield from ws.receive_bytes() + ws.send_bytes(msg+b'/answer') + yield from ws.close() + closed.set_result(1) + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + resp, reader, writer = yield from self.connect_ws(url) + writer.send(b'ask', binary=True) + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_BINARY) + self.assertEqual(b'ask/answer', msg.data) + + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_CLOSE) + self.assertEqual(msg.data, 1000) + self.assertEqual(msg.extra, '') + + writer.close() + yield from closed + resp.close() + + self.loop.run_until_complete(go()) + + def test_auto_pong_with_closing_by_peer(self): + + closed = helpers.create_future(self.loop) + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + yield from ws.receive() + + msg = yield from ws.receive() + self.assertEqual(msg.tp, web.MsgType.close) + self.assertEqual(msg.data, 1000) + self.assertEqual(msg.extra, 'exit message') + closed.set_result(None) + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + resp, reader, writer = yield from self.connect_ws(url) + writer.ping() + writer.send('ask') + + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_PONG) + writer.close(1000, 'exit message') + yield from closed + resp.close() + + self.loop.run_until_complete(go()) + + def test_ping(self): + + closed = helpers.create_future(self.loop) + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + + ws.ping('data') + yield from ws.receive() + closed.set_result(None) + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + resp, reader, writer = yield from self.connect_ws(url) + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_PING) + self.assertEqual(msg.data, b'data') + writer.pong() + writer.close(2, 'exit message') + yield from closed + resp.close() + + self.loop.run_until_complete(go()) + + def test_client_ping(self): + + closed = helpers.create_future(self.loop) + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + + yield from ws.receive() + closed.set_result(None) + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + resp, reader, writer = yield from self.connect_ws(url) + writer.ping('data') + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_PONG) + self.assertEqual(msg.data, b'data') + writer.pong() + writer.close() + yield from closed + resp.close() + + self.loop.run_until_complete(go()) + + def test_pong(self): + + closed = helpers.create_future(self.loop) + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse(autoping=False) + yield from ws.prepare(request) + + msg = yield from ws.receive() + self.assertEqual(msg.tp, web.MsgType.ping) + ws.pong('data') + + msg = yield from ws.receive() + self.assertEqual(msg.tp, web.MsgType.close) + self.assertEqual(msg.data, 1000) + self.assertEqual(msg.extra, 'exit message') + closed.set_result(None) + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + resp, reader, writer = yield from self.connect_ws(url) + writer.ping('data') + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_PONG) + self.assertEqual(msg.data, b'data') + writer.close(1000, 'exit message') + + yield from closed + resp.close() + + self.loop.run_until_complete(go()) + + def test_change_status(self): + + closed = helpers.create_future(self.loop) + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + ws.set_status(200) + self.assertEqual(200, ws.status) + yield from ws.prepare(request) + self.assertEqual(101, ws.status) + yield from ws.close() + closed.set_result(None) + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + resp, _, writer = yield from self.connect_ws(url) + writer.close() + yield from closed + resp.close() + + self.loop.run_until_complete(go()) + + def test_handle_protocol(self): + + closed = helpers.create_future(self.loop) + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse(protocols=('foo', 'bar')) + yield from ws.prepare(request) + yield from ws.close() + self.assertEqual('bar', ws.protocol) + closed.set_result(None) + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + resp, _, writer = yield from self.connect_ws(url, 'eggs, bar') + writer.close() + + yield from closed + resp.close() + + self.loop.run_until_complete(go()) + + def test_server_close_handshake(self): + + closed = helpers.create_future(self.loop) + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse(protocols=('foo', 'bar')) + yield from ws.prepare(request) + yield from ws.close() + closed.set_result(None) + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + resp, reader, writer = yield from self.connect_ws(url, 'eggs, bar') + + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_CLOSE) + writer.close() + yield from closed + resp.close() + + self.loop.run_until_complete(go()) + + def test_client_close_handshake(self): + + closed = helpers.create_future(self.loop) + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse( + autoclose=False, protocols=('foo', 'bar')) + yield from ws.prepare(request) + + msg = yield from ws.receive() + self.assertEqual(msg.tp, web.MsgType.close) + self.assertFalse(ws.closed) + yield from ws.close() + self.assertTrue(ws.closed) + self.assertEqual(ws.close_code, 1007) + + msg = yield from ws.receive() + self.assertEqual(msg.tp, web.MsgType.closed) + + closed.set_result(None) + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + resp, reader, writer = yield from self.connect_ws(url, 'eggs, bar') + + writer.close(code=1007) + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_CLOSE) + yield from closed + resp.close() + + self.loop.run_until_complete(go()) + + def test_server_close_handshake_server_eats_client_messages(self): + + closed = helpers.create_future(self.loop) + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse(protocols=('foo', 'bar')) + yield from ws.prepare(request) + yield from ws.close() + closed.set_result(None) + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + response, reader, writer = yield from self.connect_ws( + url, 'eggs, bar') + + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_CLOSE) + + writer.send('text') + writer.send(b'bytes', binary=True) + writer.ping() + + writer.close() + yield from closed + + response.close() + + self.loop.run_until_complete(go()) + + def test_receive_msg(self): + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + + with self.assertWarns(DeprecationWarning): + msg = yield from ws.receive_msg() + self.assertEqual(msg.data, b'data') + yield from ws.close() + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + resp = yield from aiohttp.ws_connect(url, loop=self.loop) + resp.send_bytes(b'data') + yield from resp.close() + + self.loop.run_until_complete(go())