Skip to content

Commit

Permalink
Don't send AAAA DNS request when domain resolved to IPv4 address (cel…
Browse files Browse the repository at this point in the history
…ery#153)

* Don't patch out whole _connect body in test_AbstractTransport

It should raise test coverage level for the method.

* tests: check that if socket.connect raises, it bubbles up

* tests: cover _connect behavior when multiple entries are found

_connect should continue to iterate through entries returned by
getaddrinfo until it succeeds, or until it depletes the list, at which
point it raises socket.error.

* Don't send AAAA DNS request when domain resolved to IPv4 address

There is no need to send both A and AAAA requests when connecting to a
host since we are interested in a single address only. So if the host
resolves into IPv4 address, we can skip request for AAAA since it will
not be used anyway.

This sounds like a minor optimization, but it comes as a huge win in
case where DNS resolver is not configured for the requested host name,
but the name is listed in /etc/hosts with a IPv4 address. In this case,
resolution is very quick (the file is local, so no DNS request is sent),
except for the fact that we still ask for AAAA record for the host name.
A misconfigured DNS resolver can take time until it will time out the
request (30 seconds is a common length for Linux), which is an
unnecessary delay.

This exact situation hit OpenStack TripleO CI where DNS was not
configured, but resolution is implemented via /etc/hosts file which did
not include IPv6 entries.

OpenStack bug: https://bugs.launchpad.net/neutron/+bug/1696094

* tests: cover recover logic when getaddrinfo raises gaierror

* tests: check NotImplementedError from set_cloexec in _connect

If it's raised, we do nothing.

* Added some more comments in _connect explaining the logic

The logic became a bit hard to follow, so added a bunch of comments
trying to explain the rationale behind the complexity, as well as give
key for intent of specific code blocks.

* _connect: made the socket error message more descriptive

Suggest it's a DNS resolution issue, not a generic connectivity problem.

* Replace an unused variable with _

* _connect: extracted len() results into separate variables
  • Loading branch information
booxter authored and Omer Katz committed Oct 28, 2017
1 parent b59290b commit 1ad97fb
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 23 deletions.
68 changes: 50 additions & 18 deletions amqp/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,26 +109,58 @@ def having_timeout(self, timeout):
sock.settimeout(prev)

def _connect(self, host, port, timeout):
entries = socket.getaddrinfo(
host, port, 0, socket.SOCK_STREAM, SOL_TCP,
)
for i, res in enumerate(entries):
af, socktype, proto, canonname, sa = res
e = None

# Below we are trying to avoid additional DNS requests for AAAA if A
# succeeds. This helps a lot in case when a hostname has an IPv4 entry
# in /etc/hosts but not IPv6. Without the (arguably somewhat twisted)
# logic below, getaddrinfo would attempt to resolve the hostname for
# both IP versions, which would make the resolver talk to configured
# DNS servers. If those servers are for some reason not available
# during resolution attempt (either because of system misconfiguration,
# or network connectivity problem), resolution process locks the
# _connect call for extended time.
addr_types = (socket.AF_INET, socket.AF_INET6)
addr_types_num = len(addr_types)
for n, family in enumerate(addr_types):
# first, resolve the address for a single address family
try:
self.sock = socket.socket(af, socktype, proto)
entries = socket.getaddrinfo(
host, port, family, socket.SOCK_STREAM, SOL_TCP)
entries_num = len(entries)
except socket.gaierror:
# we may have depleted all our options
if n + 1 >= addr_types_num:
# if getaddrinfo succeeded before for another address
# family, reraise the previous socket.error since it's more
# relevant to users
raise (e
if e is not None
else socket.error(
"failed to resolve broker hostname"))
continue

# now that we have address(es) for the hostname, connect to broker
for i, res in enumerate(entries):
af, socktype, proto, _, sa = res
try:
set_cloexec(self.sock, True)
except NotImplementedError:
pass
self.sock.settimeout(timeout)
self.sock.connect(sa)
except socket.error:
self.sock.close()
self.sock = None
if i + 1 >= len(entries):
raise
else:
break
self.sock = socket.socket(af, socktype, proto)
try:
set_cloexec(self.sock, True)
except NotImplementedError:
pass
self.sock.settimeout(timeout)
self.sock.connect(sa)
except socket.error as e:
if self.sock is not None:
self.sock.close()
self.sock = None
# we may have depleted all our options
if i + 1 >= entries_num and n + 1 >= addr_types_num:
raise
else:
# hurray, we established connection
return

def _init_socket(self, socket_settings, read_timeout, write_timeout):
try:
Expand Down
126 changes: 121 additions & 5 deletions t/unit/test_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import socket
import pytest

from case import Mock, patch
from case import ANY, Mock, call, patch

from amqp import transport
from amqp.exceptions import UnexpectedFrame
Expand All @@ -15,6 +15,11 @@
class MockSocket(object):
options = {}

def __init__(self, *args, **kwargs):
super(MockSocket, self).__init__(*args, **kwargs)
self.connected = False
self.sa = None

def setsockopt(self, family, key, value):
if not isinstance(value, int):
raise socket.error()
Expand All @@ -23,6 +28,20 @@ def setsockopt(self, family, key, value):
def getsockopt(self, family, key):
return self.options.get(key, 0)

def settimeout(self, timeout):
self.timeout = timeout

def fileno(self):
return 10

def connect(self, sa):
self.connected = True
self.sa = sa

def close(self):
self.connected = False
self.sa = None


TCP_KEEPIDLE = 4
TCP_KEEPINTVL = 5
Expand Down Expand Up @@ -188,14 +207,13 @@ class test_AbstractTransport:

class Transport(transport._AbstractTransport):

def _connect(self, *args):
pass

def _init_socket(self, *args):
pass

@pytest.fixture(autouse=True)
def setup_transport(self):
@patch('socket.socket.connect')
def setup_transport(self, patching):
self.connect_mock = patching
self.t = self.Transport('localhost:5672', 10)
self.t.connect()

Expand Down Expand Up @@ -305,6 +323,104 @@ def test_write__EINTR(self):
self.t.write('foo')
assert not self.t.connected

def test_connect_socket_fails(self):
self.t.sock = Mock()
self.t.close()
self.connect_mock.side_effect = socket.error
with pytest.raises(socket.error):
self.t.connect()

@patch('socket.socket', return_value=MockSocket())
@patch('socket.getaddrinfo',
return_value=[
(socket.AF_INET, 1, socket.IPPROTO_TCP,
'', ('127.0.0.1', 5672)),
(socket.AF_INET, 1, socket.IPPROTO_TCP,
'', ('127.0.0.2', 5672))
])
def test_connect_multiple_addr_entries_fails(self, getaddrinfo, sock_mock):
self.t.sock = Mock()
self.t.close()
with patch.object(sock_mock.return_value, 'connect',
side_effect=socket.error):
with pytest.raises(socket.error):
self.t.connect()

@patch('socket.socket', return_value=MockSocket())
@patch('socket.getaddrinfo',
return_value=[
(socket.AF_INET, 1, socket.IPPROTO_TCP,
'', ('127.0.0.1', 5672)),
(socket.AF_INET, 1, socket.IPPROTO_TCP,
'', ('127.0.0.2', 5672))
])
def test_connect_multiple_addr_entries_succeed(self, getaddrinfo,
sock_mock):
self.t.sock = Mock()
self.t.close()
with patch.object(sock_mock.return_value, 'connect',
side_effect=(socket.error, None)):
self.t.connect()

@patch('socket.socket', return_value=MockSocket())
@patch('socket.getaddrinfo',
side_effect=[
[(socket.AF_INET, 1, socket.IPPROTO_TCP,
'', ('127.0.0.1', 5672))],
[(socket.AF_INET6, 1, socket.IPPROTO_TCP,
'', ('::1', 5672))]
])
def test_connect_short_curcuit_on_INET_succeed(self, getaddrinfo,
sock_mock):
self.t.sock = Mock()
self.t.close()
self.t.connect()
getaddrinfo.assert_called_with(
'localhost', 5672, socket.AF_INET, ANY, ANY)

@patch('socket.socket', return_value=MockSocket())
@patch('socket.getaddrinfo',
side_effect=[
[(socket.AF_INET, 1, socket.IPPROTO_TCP,
'', ('127.0.0.1', 5672))],
[(socket.AF_INET6, 1, socket.IPPROTO_TCP,
'', ('::1', 5672))]
])
def test_connect_short_curcuit_on_INET_fails(self, getaddrinfo, sock_mock):
self.t.sock = Mock()
self.t.close()
with patch.object(sock_mock.return_value, 'connect',
side_effect=(socket.error, None)):
self.t.connect()
getaddrinfo.assert_has_calls(
[call('localhost', 5672, addr_type, ANY, ANY)
for addr_type in (socket.AF_INET, socket.AF_INET6)])

@patch('socket.getaddrinfo', side_effect=socket.gaierror)
def test_connect_getaddrinfo_raises_gaierror(self, getaddrinfo):
with pytest.raises(socket.error):
self.t.connect()

@patch('socket.socket', return_value=MockSocket())
@patch('socket.getaddrinfo',
side_effect=[
socket.gaierror,
[(socket.AF_INET6, 1, socket.IPPROTO_TCP,
'', ('::1', 5672))]
])
def test_connect_getaddrinfo_raises_gaierror_once_recovers(self, *mocks):
self.t.connect()

@patch('socket.socket', return_value=MockSocket())
@patch('socket.getaddrinfo',
return_value=[(socket.AF_INET, 1, socket.IPPROTO_TCP,
'', ('127.0.0.1', 5672))])
def test_connect_survives_not_implemented_set_cloexec(self, *mocks):
with patch('amqp.transport.set_cloexec',
side_effect=NotImplementedError) as cloexec_mock:
self.t.connect()
assert cloexec_mock.called


class test_SSLTransport:

Expand Down

0 comments on commit 1ad97fb

Please sign in to comment.