Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shorten too long test UNIX socket path #3832

Merged
merged 14 commits into from
Jun 11, 2019
100 changes: 98 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
import hashlib
import os
import socket
import ssl
import sys
from hashlib import md5, sha256
from pathlib import Path
from tempfile import TemporaryDirectory
from uuid import uuid4

import pytest
import trustme

pytest_plugins = ['aiohttp.pytest_plugin', 'pytester']

IS_HPUX = sys.platform.startswith('hp-ux')
"""Specifies whether the current runtime is HP-UX."""
IS_LINUX = sys.platform.startswith('linux')
"""Specifies whether the current runtime is HP-UX."""
IS_UNIX = hasattr(socket, 'AF_UNIX')
"""Specifies whether the current runtime is *NIX."""

needs_unix = pytest.mark.skipif(not IS_UNIX, reason='requires UNIX sockets')


@pytest.fixture
def tls_certificate_authority():
Expand Down Expand Up @@ -55,4 +70,85 @@ def tls_certificate_pem_bytes(tls_certificate):
@pytest.fixture
def tls_certificate_fingerprint_sha256(tls_certificate_pem_bytes):
tls_cert_der = ssl.PEM_cert_to_DER_cert(tls_certificate_pem_bytes.decode())
return hashlib.sha256(tls_cert_der).digest()
return sha256(tls_cert_der).digest()


@pytest.fixture
def unix_sockname(tmp_path, tmp_path_factory):
"""Generate an fs path to the UNIX domain socket for testing.
N.B. Different OS kernels have different fs path length limitations
for it. For Linux, it's 108, for HP-UX it's 92 (or higher) depending
on its version. For for most of the BSDs (Open, Free, macOS) it's
mostly 104 but sometimes it can be down to 100.
Ref: https://github.com/aio-libs/aiohttp/issues/3572
"""
if not IS_UNIX:
pytest.skip('requires UNIX sockets')

max_sock_len = 92 if IS_HPUX else 108 if IS_LINUX else 100
"""Amount of bytes allocated for the UNIX socket path by OS kernel.
Ref: https://unix.stackexchange.com/a/367012/27133
"""

sock_file_name = 'unix.sock'
unique_prefix = '{!s}-'.format(uuid4())
unique_prefix_len = len(unique_prefix.encode())

root_tmp_dir = Path('/tmp').resolve()
os_tmp_dir = Path(os.getenv('TMPDIR', '/tmp')).resolve()
original_base_tmp_path = Path(
str(tmp_path_factory.getbasetemp()),
).resolve()

original_base_tmp_path_hash = md5(
str(original_base_tmp_path).encode(),
).hexdigest()

def make_tmp_dir(base_tmp_dir):
return TemporaryDirectory(
dir=str(base_tmp_dir),
prefix='pt-',
suffix='-{!s}'.format(original_base_tmp_path_hash),
)

def assert_sock_fits(sock_path):
sock_path_len = len(sock_path.encode())
# exit-check to verify that it's correct and simplify debugging
# in the future
assert sock_path_len <= max_sock_len, (
'Suggested UNIX socket ({sock_path}) is {sock_path_len} bytes '
'long but the current kernel only has {max_sock_len} bytes '
'allocated to hold it so it must be shorter. '
'See https://github.com/aio-libs/aiohttp/issues/3572 '
'for more info.'
).format_map(locals())

paths = original_base_tmp_path, os_tmp_dir, root_tmp_dir
unique_paths = [p for n, p in enumerate(paths) if p not in paths[:n]]
paths_num = len(unique_paths)

for num, tmp_dir_path in enumerate(paths, 1):
with make_tmp_dir(tmp_dir_path) as tmpd:
tmpd = Path(tmpd).resolve()
sock_path = str(tmpd / sock_file_name)
sock_path_len = len(sock_path.encode())

if num >= paths_num:
# exit-check to verify that it's correct and simplify
# debugging in the future
assert_sock_fits(sock_path)

if sock_path_len <= max_sock_len:
if max_sock_len - sock_path_len >= unique_prefix_len:
# If we're lucky to have extra space in the path,
# let's also make it more unique
sock_path = str(
tmpd / ''.join((unique_prefix, sock_file_name))
)
# Double-checking it:
assert_sock_fits(sock_path)
yield sock_path
return
14 changes: 3 additions & 11 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from aiohttp.helpers import PY_37
from aiohttp.test_utils import make_mocked_coro, unused_port
from aiohttp.tracing import Trace
from conftest import needs_unix


@pytest.fixture()
Expand All @@ -42,11 +43,6 @@ def ssl_key():
return ConnectionKey('localhost', 80, True, None, None, None, None)


@pytest.fixture
def unix_sockname(tmp_path):
return str(tmp_path / 'socket.sock')


@pytest.fixture
def unix_server(loop, unix_sockname):
runners = []
Expand Down Expand Up @@ -1956,8 +1952,7 @@ async def handler(request):
assert r.status == 200


@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'),
reason="requires unix socket")
@needs_unix
async def test_unix_connector_not_found(loop) -> None:
connector = aiohttp.UnixConnector('/' + uuid.uuid4().hex, loop=loop)

Expand All @@ -1968,8 +1963,7 @@ async def test_unix_connector_not_found(loop) -> None:
await connector.connect(req, None, ClientTimeout())


@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'),
reason="requires unix socket")
@needs_unix
async def test_unix_connector_permission(loop) -> None:
loop.create_unix_connection = make_mocked_coro(
raise_exception=PermissionError())
Expand Down Expand Up @@ -2094,8 +2088,6 @@ async def handler(request):
conn.close()


@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'),
reason='requires UNIX sockets')
async def test_unix_connector(unix_server, unix_sockname) -> None:
async def handler(request):
return web.Response()
Expand Down
9 changes: 3 additions & 6 deletions tests/test_web_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,17 +100,14 @@ def test_non_app() -> None:
web.AppRunner(object())


@pytest.mark.skipif(platform.system() == "Windows",
reason="Unix socket support is required")
async def test_addresses(make_runner, tmpdir) -> None:
async def test_addresses(make_runner, unix_sockname) -> None:
_sock = get_unused_port_socket('127.0.0.1')
runner = make_runner()
await runner.setup()
tcp = web.SockSite(runner, _sock)
await tcp.start()
path = str(tmpdir / 'tmp.sock')
unix = web.UnixSite(runner, path)
unix = web.UnixSite(runner, unix_sockname)
await unix.start()
actual_addrs = runner.addresses
expected_host, expected_post = _sock.getsockname()[:2]
assert actual_addrs == [(expected_host, expected_post), path]
assert actual_addrs == [(expected_host, expected_post), unix_sockname]