Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

automatically reconnect pubsub when reading messages in blocking mode #2281

Merged
12 changes: 9 additions & 3 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,9 +754,15 @@ async def parse_response(self, block: bool = True, timeout: float = 0):

await self.check_health()

if not block and not await self._execute(conn, conn.can_read, timeout=timeout):
return None
response = await self._execute(conn, conn.read_response)
async def try_read():
if not block:
if not await conn.can_read(timeout=timeout):
return None
else:
await conn.connect()
return await conn.read_response()

response = await self._execute(conn, try_read)

if conn.health_check_interval and response == self.health_check_response:
# ignore the health check message as user might not expect it
Expand Down
12 changes: 9 additions & 3 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,9 +1497,15 @@ def parse_response(self, block=True, timeout=0):

self.check_health()

if not block and not self._execute(conn, conn.can_read, timeout=timeout):
return None
response = self._execute(conn, conn.read_response)
def try_read():
if not block:
if not conn.can_read(timeout=timeout):
return None
else:
conn.connect()
return conn.read_response()

response = self._execute(conn, try_read)

if self.is_health_check_response(response):
# ignore the health check message as user might not expect it
Expand Down
20 changes: 15 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,25 @@ def _get_info(redis_url):


def pytest_sessionstart(session):
# during test discovery, e.g. with VS Code, we may not
# have a server running.
redis_url = session.config.getoption("--redis-url")
info = _get_info(redis_url)
version = info["redis_version"]
arch_bits = info["arch_bits"]
cluster_enabled = info["cluster_enabled"]
try:
info = _get_info(redis_url)
version = info["redis_version"]
arch_bits = info["arch_bits"]
cluster_enabled = info["cluster_enabled"]
enterprise = info["enterprise"]
except redis.ConnectionError:
# provide optimistic defaults
version = "10.0.0"
arch_bits = 64
cluster_enabled = False
enterprise = False
REDIS_INFO["version"] = version
REDIS_INFO["arch_bits"] = arch_bits
REDIS_INFO["cluster_enabled"] = cluster_enabled
REDIS_INFO["enterprise"] = info["enterprise"]
REDIS_INFO["enterprise"] = enterprise
# store REDIS_INFO in config so that it is available from "condition strings"
session.config.REDIS_INFO = REDIS_INFO

Expand Down
9 changes: 9 additions & 0 deletions tests/test_asyncio/compat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import asyncio
import sys
from unittest import mock

try:
mock.AsyncMock
except AttributeError:
import mock


def create_task(coroutine):
if sys.version_info[:2] >= (3, 7):
return asyncio.create_task(coroutine)
else:
return asyncio.ensure_future(coroutine)
130 changes: 129 additions & 1 deletion tests/test_asyncio/test_pubsub.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import functools
import socket
from typing import Optional

import async_timeout
Expand All @@ -11,7 +12,7 @@
from redis.typing import EncodableT
from tests.conftest import skip_if_server_version_lt

from .compat import mock
from .compat import create_task, mock


def with_timeout(t):
Expand Down Expand Up @@ -786,3 +787,130 @@ def callback(message):
"pattern": None,
"type": "message",
}


# @pytest.mark.xfail
@pytest.mark.parametrize("method", ["get_message", "listen"])
@pytest.mark.onlynoncluster
class TestPubSubAutoReconnect:
timeout = 2

async def mysetup(self, r, method):
self.messages = asyncio.Queue()
self.pubsub = r.pubsub()
# State: 0 = initial state , 1 = after disconnect, 2 = ConnectionError is seen,
# 3=successfully reconnected 4 = exit
self.state = 0
self.cond = asyncio.Condition()
if method == "get_message":
self.get_message = self.loop_step_get_message
else:
self.get_message = self.loop_step_listen

self.task = create_task(self.loop())
# get the initial connect message
message = await self.messages.get()
assert message == {
"channel": b"foo",
"data": 1,
"pattern": None,
"type": "subscribe",
}

async def mycleanup(self):
message = await self.messages.get()
assert message == {
"channel": b"foo",
"data": 1,
"pattern": None,
"type": "subscribe",
}
# kill thread
async with self.cond:
self.state = 4 # quit
await self.task

async def test_reconnect_socket_error(self, r: redis.Redis, method):
"""
Test that a socket error will cause reconnect
"""
async with async_timeout.timeout(self.timeout):
await self.mysetup(r, method)
# now, disconnect the connection, and wait for it to be re-established
async with self.cond:
assert self.state == 0
self.state = 1
with mock.patch.object(self.pubsub.connection, "_parser") as mockobj:
mockobj.read_response.side_effect = socket.error
mockobj.can_read.side_effect = socket.error
# wait until task noticies the disconnect until we undo the patch
await self.cond.wait_for(lambda: self.state >= 2)
assert not self.pubsub.connection.is_connected
# it is in a disconnecte state
# wait for reconnect
await self.cond.wait_for(lambda: self.pubsub.connection.is_connected)
assert self.state == 3

await self.mycleanup()

async def test_reconnect_disconnect(self, r: redis.Redis, method):
"""
Test that a manual disconnect() will cause reconnect
"""
async with async_timeout.timeout(self.timeout):
await self.mysetup(r, method)
# now, disconnect the connection, and wait for it to be re-established
async with self.cond:
self.state = 1
await self.pubsub.connection.disconnect()
assert not self.pubsub.connection.is_connected
# wait for reconnect
await self.cond.wait_for(lambda: self.pubsub.connection.is_connected)
assert self.state == 3

await self.mycleanup()

async def loop(self):
# reader loop, performing state transitions as it
# discovers disconnects and reconnects
await self.pubsub.subscribe("foo")
while True:
await asyncio.sleep(0.01) # give main thread chance to get lock
async with self.cond:
old_state = self.state
try:
if self.state == 4:
break
# print("state a ", self.state)
got_msg = await self.get_message()
assert got_msg
if self.state in (1, 2):
self.state = 3 # successful reconnect
except redis.ConnectionError:
assert self.state in (1, 2)
self.state = 2 # signal that we noticed the disconnect
finally:
self.cond.notify()
# make sure that we did notice the connection error
# or reconnected without any error
if old_state == 1:
assert self.state in (2, 3)

async def loop_step_get_message(self):
# get a single message via get_message
message = await self.pubsub.get_message(timeout=0.1)
# print(message)
if message is not None:
await self.messages.put(message)
return True
return False

async def loop_step_listen(self):
# get a single message via listen()
try:
async with async_timeout.timeout(0.1):
async for message in self.pubsub.listen():
await self.messages.put(message)
return True
except asyncio.TimeoutError:
return False
127 changes: 127 additions & 0 deletions tests/test_pubsub.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import platform
import queue
import socket
import threading
import time
from unittest import mock
Expand Down Expand Up @@ -608,3 +610,128 @@ def test_pubsub_deadlock(self, master_host):
p = r.pubsub()
p.subscribe("my-channel-1", "my-channel-2")
pool.reset()


@pytest.mark.timeout(5, method="thread")
@pytest.mark.parametrize("method", ["get_message", "listen"])
@pytest.mark.onlynoncluster
class TestPubSubAutoReconnect:
def mysetup(self, r, method):
self.messages = queue.Queue()
self.pubsub = r.pubsub()
self.state = 0
self.cond = threading.Condition()
if method == "get_message":
self.get_message = self.loop_step_get_message
else:
self.get_message = self.loop_step_listen

self.thread = threading.Thread(target=self.loop)
self.thread.daemon = True
self.thread.start()
# get the initial connect message
message = self.messages.get(timeout=1)
assert message == {
"channel": b"foo",
"data": 1,
"pattern": None,
"type": "subscribe",
}

def wait_for_reconnect(self):
self.cond.wait_for(lambda: self.pubsub.connection._sock is not None, timeout=2)
assert self.pubsub.connection._sock is not None # we didn't time out
assert self.state == 3

message = self.messages.get(timeout=1)
assert message == {
"channel": b"foo",
"data": 1,
"pattern": None,
"type": "subscribe",
}

def mycleanup(self):
# kill thread
with self.cond:
self.state = 4 # quit
self.cond.notify()
self.thread.join()

def test_reconnect_socket_error(self, r: redis.Redis, method):
"""
Test that a socket error will cause reconnect
"""
self.mysetup(r, method)
try:
# now, disconnect the connection, and wait for it to be re-established
with self.cond:
self.state = 1
with mock.patch.object(self.pubsub.connection, "_parser") as mockobj:
mockobj.read_response.side_effect = socket.error
mockobj.can_read.side_effect = socket.error
# wait until thread notices the disconnect until we undo the patch
self.cond.wait_for(lambda: self.state >= 2)
assert (
self.pubsub.connection._sock is None
) # it is in a disconnected state
self.wait_for_reconnect()

finally:
self.mycleanup()

def test_reconnect_disconnect(self, r: redis.Redis, method):
"""
Test that a manual disconnect() will cause reconnect
"""
self.mysetup(r, method)
try:
# now, disconnect the connection, and wait for it to be re-established
with self.cond:
self.state = 1
self.pubsub.connection.disconnect()
assert self.pubsub.connection._sock is None
# wait for reconnect
self.wait_for_reconnect()
finally:
self.mycleanup()

def loop(self):
# reader loop, performing state transitions as it
# discovers disconnects and reconnects
self.pubsub.subscribe("foo")
while True:
time.sleep(0.01) # give main thread chance to get lock
with self.cond:
old_state = self.state
try:
if self.state == 4:
break
# print ('state, %s, sock %s' % (state, pubsub.connection._sock))
got_msg = self.get_message()
assert got_msg
if self.state in (1, 2):
self.state = 3 # successful reconnect
except redis.ConnectionError:
assert self.state in (1, 2)
self.state = 2
finally:
self.cond.notify()
# assert that we noticed a connect error, or automatically
# reconnected without error
if old_state == 1:
assert self.state in (2, 3)

def loop_step_get_message(self):
# get a single message via listen()
message = self.pubsub.get_message(timeout=0.1)
if message is not None:
self.messages.put(message)
return True
return False

def loop_step_listen(self):
# get a single message via listen()
for message in self.pubsub.listen():
self.messages.put(message)
return True