diff --git a/redis/__init__.py b/redis/__init__.py
index 7bf6839453..495d2d99bb 100644
--- a/redis/__init__.py
+++ b/redis/__init__.py
@@ -2,7 +2,6 @@
 from redis import asyncio  # noqa
 from redis.backoff import default_backoff
-from redis.cache import _LocalChace
 from redis.client import Redis, StrictRedis
 from redis.cluster import RedisCluster
 from redis.connection import (
@@ -62,7 +61,6 @@ def int_or_str(value):
     VERSION = tuple([99, 99, 99])
 
 __all__ = [
-    "_LocalChace",
     "AuthenticationError",
     "AuthenticationWrongNumberOfArgsError",
     "BlockingConnectionPool",
diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py
index 569e7ee679..13aa1ffccb 100644
--- a/redis/_parsers/resp3.py
+++ b/redis/_parsers/resp3.py
@@ -6,15 +6,18 @@
 from .base import _AsyncRESPBase, _RESPBase
 from .socket import SERVER_CLOSED_CONNECTION_ERROR
 
+_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"]
+
 
 class _RESP3Parser(_RESPBase):
     """RESP3 protocol implementation"""
 
     def __init__(self, socket_read_size):
         super().__init__(socket_read_size)
-        self.push_handler_func = self.handle_push_response
+        self.pubsub_push_handler_func = self.handle_pubsub_push_response
+        self.invalidation_push_handler_func = None
 
-    def handle_push_response(self, response):
+    def handle_pubsub_push_response(self, response):
         logger = getLogger("push_response")
         logger.info("Push response: " + str(response))
         return response
@@ -114,13 +117,7 @@ def _read_response(self, disable_decoding=False, push_request=False):
                 )
                 for _ in range(int(response))
             ]
-            res = self.push_handler_func(response)
-            if not push_request:
-                return self._read_response(
-                    disable_decoding=disable_decoding, push_request=push_request
-                )
-            else:
-                return res
+            return self.handle_push_response(response, disable_decoding, push_request)
         else:
             raise InvalidResponse(f"Protocol Error: {raw!r}")
@@ -128,16 +125,32 @@ def _read_response(self, disable_decoding=False, push_request=False):
             response = self.encoder.decode(response)
         return response
 
-    def set_push_handler(self, push_handler_func):
-        self.push_handler_func = push_handler_func
+    def handle_push_response(self, response, disable_decoding, push_request):
+        if response[0] in _INVALIDATION_MESSAGE:
+            res = self.invalidation_push_handler_func(response)
+        else:
+            res = self.pubsub_push_handler_func(response)
+        if not push_request:
+            return self._read_response(
+                disable_decoding=disable_decoding, push_request=push_request
+            )
+        else:
+            return res
+
+    def set_pubsub_push_handler(self, pubsub_push_handler_func):
+        self.pubsub_push_handler_func = pubsub_push_handler_func
+
+    def set_invalidation_push_handler(self, invalidations_push_handler_func):
+        self.invalidation_push_handler_func = invalidations_push_handler_func
 
 
 class _AsyncRESP3Parser(_AsyncRESPBase):
     def __init__(self, socket_read_size):
         super().__init__(socket_read_size)
-        self.push_handler_func = self.handle_push_response
+        self.pubsub_push_handler_func = self.handle_pubsub_push_response
+        self.invalidation_push_handler_func = None
 
-    def handle_push_response(self, response):
+    def handle_pubsub_push_response(self, response):
         logger = getLogger("push_response")
         logger.info("Push response: " + str(response))
         return response
@@ -246,13 +259,7 @@ async def _read_response(
                 )
                 for _ in range(int(response))
             ]
-            res = self.push_handler_func(response)
-            if not push_request:
-                return await self._read_response(
-                    disable_decoding=disable_decoding, push_request=push_request
-                )
-            else:
-                return res
+            return await self.handle_push_response(response, disable_decoding, push_request)
         else:
             raise InvalidResponse(f"Protocol Error: {raw!r}")
@@ -260,5 +267,20 @@ async def _read_response(
             response = self.encoder.decode(response)
         return response
 
-    def set_push_handler(self, push_handler_func):
-        self.push_handler_func = push_handler_func
+    async def handle_push_response(self, response, disable_decoding, push_request):
+        if response[0] in _INVALIDATION_MESSAGE:
+            res = self.invalidation_push_handler_func(response)
+        else:
+            res = self.pubsub_push_handler_func(response)
+        if not push_request:
+            return await self._read_response(
+                disable_decoding=disable_decoding, push_request=push_request
+            )
+        else:
+            return res
+
+    def set_pubsub_push_handler(self, pubsub_push_handler_func):
+        self.pubsub_push_handler_func = pubsub_push_handler_func
+
+    def set_invalidation_push_handler(self, invalidations_push_handler_func):
+        self.invalidation_push_handler_func = invalidations_push_handler_func
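# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the dispatch that the new
# handle_push_response adds above, reduced to a standalone runnable example.
# _PushRouter and the lambdas are hypothetical names; only _INVALIDATION_MESSAGE
# and the routing rule come from the diff.
_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"]


class _PushRouter:
    def __init__(self):
        # Same defaults as _RESP3Parser.__init__: pub/sub pushes pass through,
        # invalidation pushes have no handler until one is set.
        self.pubsub_push_handler_func = lambda response: response
        self.invalidation_push_handler_func = None

    def handle_push_response(self, response):
        # Invalidation pushes look like [b"invalidate", [key, ...]] (or None);
        # every other push type (subscribe, message, ...) stays on the pub/sub path.
        if response[0] in _INVALIDATION_MESSAGE:
            return self.invalidation_push_handler_func(response)
        return self.pubsub_push_handler_func(response)


router = _PushRouter()
router.invalidation_push_handler_func = lambda response: print("invalidate:", response[1])
router.handle_push_response([b"invalidate", [b"foo"]])     # prints: invalidate: [b'foo']
router.handle_push_response([b"message", b"chan", b"hi"])  # returned unchanged
# ---------------------------------------------------------------------------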
diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py
index 9e0491f810..eea9612f4a 100644
--- a/redis/asyncio/client.py
+++ b/redis/asyncio/client.py
@@ -37,6 +37,12 @@
 )
 from redis.asyncio.lock import Lock
 from redis.asyncio.retry import Retry
+from redis.cache import (
+    DEFAULT_BLACKLIST,
+    DEFAULT_EVICTION_POLICY,
+    DEFAULT_WHITELIST,
+    _LocalCache,
+)
 from redis.client import (
     EMPTY_RESPONSE,
     NEVER_DECODE,
@@ -60,7 +66,7 @@
     TimeoutError,
     WatchError,
 )
-from redis.typing import ChannelT, EncodableT, KeyT
+from redis.typing import ChannelT, EncodableT, KeysT, KeyT, ResponseT
 from redis.utils import (
     HIREDIS_AVAILABLE,
     _set_info_logger,
@@ -231,6 +237,13 @@ def __init__(
         redis_connect_func=None,
         credential_provider: Optional[CredentialProvider] = None,
         protocol: Optional[int] = 2,
+        cache_enable: bool = False,
+        client_cache: Optional[_LocalCache] = None,
+        cache_max_size: int = 100,
+        cache_ttl: int = 0,
+        cache_eviction_policy: str = DEFAULT_EVICTION_POLICY,
+        cache_blacklist: List[str] = DEFAULT_BLACKLIST,
+        cache_whitelist: List[str] = DEFAULT_WHITELIST,
     ):
         """
         Initialize a new Redis client.
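# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): how the new cache_* keyword
# arguments are meant to be used together. A Redis server on the default
# localhost:6379 speaking RESP3 is assumed, and "lru" is assumed to be a valid
# EvictionPolicy value alongside the "lfu"/"random" values used elsewhere.
import asyncio

import redis.asyncio as redis


async def main():
    r = redis.Redis(
        protocol=3,                     # client-side caching needs RESP3 push messages
        single_connection_client=True,  # invalidations arrive on the command connection
        cache_enable=True,              # build a _LocalCache internally...
        cache_max_size=100,             # ...holding at most 100 commands
        cache_ttl=0,                    # 0 = entries never expire by age
        cache_eviction_policy="lru",    # see EvictionPolicy in redis/cache.py
    )
    await r.set("foo", "bar")
    print(await r.get("foo"))                  # served by the server, then cached locally
    print(r.client_cache.get(("GET", "foo")))  # the same response, straight from the cache
    await r.aclose()                           # also flushes the local cache


asyncio.run(main())
# ---------------------------------------------------------------------------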
@@ -336,6 +349,16 @@ def __init__(
         # on a set of redis commands
         self._single_conn_lock = asyncio.Lock()
 
+        self.client_cache = client_cache
+        if cache_enable:
+            self.client_cache = _LocalCache(
+                cache_max_size, cache_ttl, cache_eviction_policy
+            )
+        if self.client_cache is not None:
+            self.cache_blacklist = cache_blacklist
+            self.cache_whitelist = cache_whitelist
+            self.client_cache_initialized = False
+
     def __repr__(self):
         return (
             f"<{self.__class__.__module__}.{self.__class__.__name__}"
@@ -350,6 +373,10 @@ async def initialize(self: _RedisT) -> _RedisT:
             async with self._single_conn_lock:
                 if self.connection is None:
                     self.connection = await self.connection_pool.get_connection("_")
+                    if self.client_cache is not None:
+                        self.connection._parser.set_invalidation_push_handler(
+                            self._cache_invalidation_process
+                        )
         return self
 
     def set_response_callback(self, command: str, callback: ResponseCallbackT):
@@ -568,6 +595,8 @@ async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
             close_connection_pool is None and self.auto_close_connection_pool
         ):
             await self.connection_pool.disconnect()
+        if self.client_cache:
+            self.client_cache.flush()
 
     @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
     async def close(self, close_connection_pool: Optional[bool] = None) -> None:
@@ -596,29 +625,95 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
         ):
             raise error
 
+    def _cache_invalidation_process(
+        self, data: List[Union[str, Optional[List[str]]]]
+    ) -> None:
+        """
+        Invalidate (delete) all redis commands associated with a specific key.
+        `data` is a list where the first element is the invalidation message
+        and the second element is the list of keys to invalidate
+        (if the list of keys is None, all keys are invalidated).
+        """
+        if data[1] is not None:
+            for key in data[1]:
+                self.client_cache.invalidate(str_if_bytes(key))
+        else:
+            self.client_cache.flush()
+
+    async def _get_from_local_cache(self, command: str):
+        """
+        If the command is in the local cache, return the response
+        """
+        if (
+            self.client_cache is None
+            or command[0] in self.cache_blacklist
+            or command[0] not in self.cache_whitelist
+        ):
+            return None
+        while not self.connection._is_socket_empty():
+            await self.connection.read_response(push_request=True)
+        return self.client_cache.get(command)
+
+    def _add_to_local_cache(
+        self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
+    ):
+        """
+        Add the command and response to the local cache if the command
+        is allowed to be cached
+        """
+        if (
+            self.client_cache is not None
+            and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist)
+            and (self.cache_whitelist == [] or command[0] in self.cache_whitelist)
+        ):
+            self.client_cache.set(command, response, keys)
+
+    def delete_from_local_cache(self, command: str):
+        """
+        Delete the command from the local cache
+        """
+        try:
+            self.client_cache.delete(command)
+        except AttributeError:
+            pass
+
     # COMMAND EXECUTION AND PROTOCOL PARSING
 
     async def execute_command(self, *args, **options):
         """Execute a command and return a parsed response"""
         await self.initialize()
-        options.pop("keys", None)  # the keys are used only for client side caching
-        pool = self.connection_pool
         command_name = args[0]
-        conn = self.connection or await pool.get_connection(command_name, **options)
+        keys = options.pop("keys", None)  # keys are used only for client side caching
+        response_from_cache = await self._get_from_local_cache(args)
+        if response_from_cache is not None:
+            return response_from_cache
+        else:
+            pool = self.connection_pool
+            conn = self.connection or await pool.get_connection(command_name, **options)
 
-        if self.single_connection_client:
-            await self._single_conn_lock.acquire()
-        try:
-            return await conn.retry.call_with_retry(
-                lambda: self._send_command_parse_response(
-                    conn, command_name, *args, **options
-                ),
-                lambda error: self._disconnect_raise(conn, error),
-            )
-        finally:
             if self.single_connection_client:
-                self._single_conn_lock.release()
-            if not self.connection:
-                await pool.release(conn)
+                await self._single_conn_lock.acquire()
+            try:
+                if self.client_cache is not None and not self.client_cache_initialized:
+                    await conn.retry.call_with_retry(
+                        lambda: self._send_command_parse_response(
+                            conn, "CLIENT", *("CLIENT", "TRACKING", "ON")
+                        ),
+                        lambda error: self._disconnect_raise(conn, error),
+                    )
+                    self.client_cache_initialized = True
+                response = await conn.retry.call_with_retry(
+                    lambda: self._send_command_parse_response(
+                        conn, command_name, *args, **options
+                    ),
+                    lambda error: self._disconnect_raise(conn, error),
+                )
+                self._add_to_local_cache(args, response, keys)
+                return response
+            finally:
+                if self.single_connection_client:
+                    self._single_conn_lock.release()
+                if not self.connection:
+                    await pool.release(conn)
 
     async def parse_response(
         self, connection: Connection, command_name: Union[str, bytes], **options
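# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the control flow the new
# execute_command implements, as framework-free Python. A plain dict stands in
# for _LocalCache and send_to_server is a hypothetical placeholder for the
# retry/parse machinery; the real method additionally sends CLIENT TRACKING ON
# once and records which keys each command touched, both omitted here.
def execute_command_cached(args, local_cache, send_to_server):
    cached = local_cache.get(args)    # 1. try the local cache first
    if cached is not None:
        return cached
    response = send_to_server(*args)  # 2. miss: run the command on the server
    local_cache[args] = response      # 3. remember the response for next time
    return response


cache = {}
print(execute_command_cached(("GET", "foo"), cache, lambda *a: b"bar"))  # miss -> b'bar'
print(execute_command_cached(("GET", "foo"), cache, lambda *a: b"???"))  # hit  -> b'bar'
# ---------------------------------------------------------------------------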
@@ -866,7 +961,7 @@ async def connect(self):
         else:
             await self.connection.connect()
         if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
-            self.connection._parser.set_push_handler(self.push_handler_func)
+            self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
 
     async def _disconnect_raise_connect(self, conn, error):
         """
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
index bbd438fc0b..39f75a5f13 100644
--- a/redis/asyncio/connection.py
+++ b/redis/asyncio/connection.py
@@ -645,6 +645,10 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]
             output.append(SYM_EMPTY.join(pieces))
         return output
 
+    def _is_socket_empty(self):
+        """Check if the socket is empty"""
+        return not self._reader.at_eof()
+
 
 class Connection(AbstractConnection):
     "Manages TCP communication to and from a Redis server"
diff --git a/redis/cache.py b/redis/cache.py
index 5a689d0ebd..d920702339 100644
--- a/redis/cache.py
+++ b/redis/cache.py
@@ -159,7 +159,7 @@ class EvictionPolicy(Enum):
     RANDOM = "random"
 
 
-class _LocalChace:
+class _LocalCache:
     """
     A caching mechanism for storing redis commands and their responses.
 
@@ -220,6 +220,7 @@ def get(self, command: str) -> ResponseT:
         if command in self.cache:
             if self._is_expired(command):
                 self.delete(command)
+                return
             self._update_access(command)
             return self.cache[command]["response"]
 
@@ -266,28 +267,28 @@ def _update_access(self, command: str):
         Args:
             command (str): The redis command.
         """
-        if self.eviction_policy == EvictionPolicy.LRU:
+        if self.eviction_policy == EvictionPolicy.LRU.value:
             self.cache.move_to_end(command)
-        elif self.eviction_policy == EvictionPolicy.LFU:
+        elif self.eviction_policy == EvictionPolicy.LFU.value:
             self.cache[command]["access_count"] = (
                 self.cache.get(command, {}).get("access_count", 0) + 1
             )
             self.cache.move_to_end(command)
-        elif self.eviction_policy == EvictionPolicy.RANDOM:
+        elif self.eviction_policy == EvictionPolicy.RANDOM.value:
             pass  # Random eviction doesn't require updates
 
     def _evict(self):
         """Evict a redis command from the cache based on the eviction policy."""
         if self._is_expired(self.commands_ttl_list[0]):
             self.delete(self.commands_ttl_list[0])
-        elif self.eviction_policy == EvictionPolicy.LRU:
+        elif self.eviction_policy == EvictionPolicy.LRU.value:
             self.cache.popitem(last=False)
-        elif self.eviction_policy == EvictionPolicy.LFU:
+        elif self.eviction_policy == EvictionPolicy.LFU.value:
             min_access_command = min(
                 self.cache, key=lambda k: self.cache[k].get("access_count", 0)
             )
             self.cache.pop(min_access_command)
-        elif self.eviction_policy == EvictionPolicy.RANDOM:
+        elif self.eviction_policy == EvictionPolicy.RANDOM.value:
             random_command = random.choice(list(self.cache.keys()))
             self.cache.pop(random_command)
 
@@ -322,5 +323,6 @@ def invalidate(self, key: KeyT):
         """
         if key not in self.key_commands_map:
             return
-        for command in self.key_commands_map[key]:
+        commands = list(self.key_commands_map[key])
+        for command in commands:
             self.delete(command)
diff --git a/redis/client.py b/redis/client.py
index 0af7e050d6..7f2c8d290d 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -4,7 +4,7 @@
 import time
 import warnings
 from itertools import chain
-from typing import Any, Callable, Dict, List, Optional, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 from redis._parsers.encoders import Encoder
 from redis._parsers.helpers import (
@@ -17,7 +17,7 @@
     DEFAULT_BLACKLIST,
     DEFAULT_EVICTION_POLICY,
     DEFAULT_WHITELIST,
-    _LocalChace,
+    _LocalCache,
 )
 from redis.commands import (
     CoreCommands,
@@ -211,7 +211,7 @@ def __init__(
         credential_provider: Optional[CredentialProvider] = None,
         protocol: Optional[int] = 2,
         cache_enable: bool = False,
-        client_cache: Optional[_LocalChace] = None,
+        client_cache: Optional[_LocalCache] = None,
         cache_max_size: int = 100,
         cache_ttl: int = 0,
         cache_eviction_policy: str = DEFAULT_EVICTION_POLICY,
@@ -326,12 +326,16 @@ def __init__(
 
         self.client_cache = client_cache
         if cache_enable:
-            self.client_cache = _LocalChace(
+            self.client_cache = _LocalCache(
                 cache_max_size, cache_ttl, cache_eviction_policy
             )
         if self.client_cache is not None:
             self.cache_blacklist = cache_blacklist
             self.cache_whitelist = cache_whitelist
+            self.client_tracking_on()
+            self.connection._parser.set_invalidation_push_handler(
+                self._cache_invalidation_process
+            )
 
     def __repr__(self) -> str:
         return (
@@ -358,6 +362,21 @@ def set_response_callback(self, command: str, callback: Callable) -> None:
         """Set a custom Response Callback"""
         self.response_callbacks[command] = callback
 
+    def _cache_invalidation_process(
+        self, data: List[Union[str, Optional[List[str]]]]
+    ) -> None:
+        """
+        Invalidate (delete) all redis commands associated with a specific key.
+        `data` is a list where the first element is the invalidation message
+        and the second element is the list of keys to invalidate
+        (if the list of keys is None, all keys are invalidated).
+        """
+        if data[1] is not None:
+            for key in data[1]:
+                self.client_cache.invalidate(str_if_bytes(key))
+        else:
+            self.client_cache.flush()
+
     def load_external_module(self, funcname, func) -> None:
         """
         This function can be used to add externally defined redis modules,
@@ -530,6 +549,8 @@ def close(self):
 
         if self.auto_close_connection_pool:
             self.connection_pool.disconnect()
+        if self.client_cache:
+            self.client_cache.flush()
 
     def _send_command_parse_response(self, conn, command_name, *args, **options):
         """
@@ -561,9 +582,13 @@ def _get_from_local_cache(self, command: str):
             or command[0] not in self.cache_whitelist
         ):
             return None
+        while not self.connection._is_socket_empty():
+            self.connection.read_response(push_request=True)
         return self.client_cache.get(command)
 
-    def _add_to_local_cache(self, command: str, response: ResponseT, keys: List[KeysT]):
+    def _add_to_local_cache(
+        self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
+    ):
         """
         Add the command and response to the local cache if the command
         is allowed to be cached
@@ -819,7 +844,7 @@ def execute_command(self, *args):
                 # were listening to when we were disconnected
                 self.connection.register_connect_callback(self.on_connect)
                 if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
-                    self.connection._parser.set_push_handler(self.push_handler_func)
+                    self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
             connection = self.connection
             kwargs = {"check_health": not self.subscribed}
             if not self.subscribed:
diff --git a/redis/cluster.py b/redis/cluster.py
index 0405b0547c..8032173e66 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -1778,7 +1778,7 @@ def execute_command(self, *args):
             # were listening to when we were disconnected
             self.connection.register_connect_callback(self.on_connect)
             if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
-                self.connection._parser.set_push_handler(self.push_handler_func)
+                self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
         connection = self.connection
         self._execute(connection, connection.send_command, *args)
diff --git a/redis/connection.py b/redis/connection.py
index c201224e35..35a4ff4a37 100644
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -1,5 +1,6 @@
 import copy
 import os
+import select
 import socket
 import ssl
 import sys
@@ -572,6 +573,11 @@ def pack_commands(self, commands):
             output.append(SYM_EMPTY.join(pieces))
         return output
 
+    def _is_socket_empty(self):
+        """Check if the socket is empty"""
+        r, _, _ = select.select([self._sock], [], [], 0)
+        return not bool(r)
+
 
 class Connection(AbstractConnection):
     "Manages TCP communication to and from a Redis server"
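# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): why the connections gain an
# "is the socket empty?" probe. Before answering from the local cache,
# _get_from_local_cache drains any buffered push messages so pending
# invalidations are applied first. The standalone check below mirrors the
# select()-based Connection._is_socket_empty added just above; the asyncio
# variant earlier in the patch relies on the StreamReader instead.
import select
import socket


def has_pending_data(sock):
    readable, _, _ = select.select([sock], [], [], 0)  # non-blocking readability poll
    return bool(readable)


a, b = socket.socketpair()
print(has_pending_data(a))   # False: nothing buffered
b.sendall(b">3\r\n")         # pretend a RESP3 push header arrived
print(has_pending_data(a))   # True: data to drain before trusting the cache
a.close()
b.close()
# ---------------------------------------------------------------------------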
diff --git a/tests/test_asyncio/test_cache.py b/tests/test_asyncio/test_cache.py
new file mode 100644
index 0000000000..c837acfed1
--- /dev/null
+++ b/tests/test_asyncio/test_cache.py
@@ -0,0 +1,129 @@
+import time
+
+import pytest
+import redis.asyncio as redis
+from redis.utils import HIREDIS_AVAILABLE
+
+
+@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
+async def test_get_from_cache():
+    r = redis.Redis(cache_enable=True, single_connection_client=True, protocol=3)
+    r2 = redis.Redis(protocol=3)
+    # add key to redis
+    await r.set("foo", "bar")
+    # get key from redis and save in local cache
+    assert await r.get("foo") == b"bar"
+    # get key from local cache
+    assert r.client_cache.get(("GET", "foo")) == b"bar"
+    # change key in redis (cause invalidation)
+    await r2.set("foo", "barbar")
+    # send any command to redis (process invalidation in background)
+    await r.ping()
+    # the command is not in the local cache anymore
+    assert r.client_cache.get(("GET", "foo")) is None
+    # get key from redis
+    assert await r.get("foo") == b"barbar"
+
+    await r.aclose()
+
+
+@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
+async def test_cache_max_size():
+    r = redis.Redis(
+        cache_enable=True, cache_max_size=3, single_connection_client=True, protocol=3
+    )
+    # add 3 keys to redis
+    await r.set("foo", "bar")
+    await r.set("foo2", "bar2")
+    await r.set("foo3", "bar3")
+    # get 3 keys from redis and save in local cache
+    assert await r.get("foo") == b"bar"
+    assert await r.get("foo2") == b"bar2"
+    assert await r.get("foo3") == b"bar3"
+    # get the 3 keys from local cache
+    assert r.client_cache.get(("GET", "foo")) == b"bar"
+    assert r.client_cache.get(("GET", "foo2")) == b"bar2"
+    assert r.client_cache.get(("GET", "foo3")) == b"bar3"
+    # add 1 more key to redis (exceed the max size)
+    await r.set("foo4", "bar4")
+    assert await r.get("foo4") == b"bar4"
+    # the first key is not in the local cache anymore
+    assert r.client_cache.get(("GET", "foo")) is None
+
+    await r.aclose()
+
+
+@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
+async def test_cache_ttl():
+    r = redis.Redis(
+        cache_enable=True, cache_ttl=1, single_connection_client=True, protocol=3
+    )
+    # add key to redis
+    await r.set("foo", "bar")
+    # get key from redis and save in local cache
+    assert await r.get("foo") == b"bar"
+    # get key from local cache
+    assert r.client_cache.get(("GET", "foo")) == b"bar"
+    # wait for the key to expire
+    time.sleep(1)
+    # the key is not in the local cache anymore
+    assert r.client_cache.get(("GET", "foo")) is None
+
+    await r.aclose()
+
+
+@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
+async def test_cache_lfu_eviction():
+    r = redis.Redis(
+        cache_enable=True,
+        cache_max_size=3,
+        cache_eviction_policy="lfu",
+        single_connection_client=True,
+        protocol=3,
+    )
+    # add 3 keys to redis
+    await r.set("foo", "bar")
+    await r.set("foo2", "bar2")
+    await r.set("foo3", "bar3")
+    # get 3 keys from redis and save in local cache
+    assert await r.get("foo") == b"bar"
+    assert await r.get("foo2") == b"bar2"
+    assert await r.get("foo3") == b"bar3"
+    # change the order of the keys in the cache
+    assert r.client_cache.get(("GET", "foo")) == b"bar"
+    assert r.client_cache.get(("GET", "foo")) == b"bar"
+    assert r.client_cache.get(("GET", "foo3")) == b"bar3"
+    # add 1 more key to redis (exceed the max size)
+    await r.set("foo4", "bar4")
+    assert await r.get("foo4") == b"bar4"
+    # test the eviction policy
+    assert len(r.client_cache.cache) == 3
+    assert r.client_cache.get(("GET", "foo")) == b"bar"
+    assert r.client_cache.get(("GET", "foo2")) is None
+
+    await r.aclose()
+
+
+@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
+async def test_cache_decode_response():
+    r = redis.Redis(
+        decode_responses=True,
+        cache_enable=True,
+        single_connection_client=True,
+        protocol=3,
+    )
+    await r.set("foo", "bar")
+    # get key from redis and save in local cache
+    assert await r.get("foo") == "bar"
+    # get key from local cache
+    assert r.client_cache.get(("GET", "foo")) == "bar"
+    # change key in redis (cause invalidation)
+    await r.set("foo", "barbar")
+    # send any command to redis (process invalidation in background)
+    await r.ping()
+    # the command is not in the local cache anymore
+    assert r.client_cache.get(("GET", "foo")) is None
+    # get key from redis
+    assert await r.get("foo") == "barbar"
+
+    await r.aclose()
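# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): exercising the renamed
# _LocalCache directly, as the tests above do through r.client_cache. This is
# an internal helper, so the positional constructor call simply mirrors how
# Redis.__init__ builds it: _LocalCache(cache_max_size, cache_ttl, cache_eviction_policy).
from redis.cache import DEFAULT_EVICTION_POLICY, _LocalCache

cache = _LocalCache(3, 0, DEFAULT_EVICTION_POLICY)  # max 3 entries, no TTL
cache.set(("GET", "foo"), b"bar", ["foo"])          # command tuple, response, touched keys
print(cache.get(("GET", "foo")))                    # b'bar'

cache.invalidate("foo")                             # what an invalidation push triggers
print(cache.get(("GET", "foo")))                    # None: every command touching "foo" is gone

cache.flush()                                       # what close()/aclose() do on shutdown
# ---------------------------------------------------------------------------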
diff --git a/tests/test_cache.py b/tests/test_cache.py
new file mode 100644
index 0000000000..45621fe77e
--- /dev/null
+++ b/tests/test_cache.py
@@ -0,0 +1,119 @@
+import time
+
+import pytest
+import redis
+from redis.utils import HIREDIS_AVAILABLE
+
+
+@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
+def test_get_from_cache():
+    r = redis.Redis(cache_enable=True, single_connection_client=True, protocol=3)
+    r2 = redis.Redis(protocol=3)
+    # add key to redis
+    r.set("foo", "bar")
+    # get key from redis and save in local cache
+    assert r.get("foo") == b"bar"
+    # get key from local cache
+    assert r.client_cache.get(("GET", "foo")) == b"bar"
+    # change key in redis (cause invalidation)
+    r2.set("foo", "barbar")
+    # send any command to redis (process invalidation in background)
+    r.ping()
+    # the command is not in the local cache anymore
+    assert r.client_cache.get(("GET", "foo")) is None
+    # get key from redis
+    assert r.get("foo") == b"barbar"
+
+
+@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
+def test_cache_max_size():
+    r = redis.Redis(
+        cache_enable=True, cache_max_size=3, single_connection_client=True, protocol=3
+    )
+    # add 3 keys to redis
+    r.set("foo", "bar")
+    r.set("foo2", "bar2")
+    r.set("foo3", "bar3")
+    # get 3 keys from redis and save in local cache
+    assert r.get("foo") == b"bar"
+    assert r.get("foo2") == b"bar2"
+    assert r.get("foo3") == b"bar3"
+    # get the 3 keys from local cache
+    assert r.client_cache.get(("GET", "foo")) == b"bar"
+    assert r.client_cache.get(("GET", "foo2")) == b"bar2"
+    assert r.client_cache.get(("GET", "foo3")) == b"bar3"
+    # add 1 more key to redis (exceed the max size)
+    r.set("foo4", "bar4")
+    assert r.get("foo4") == b"bar4"
+    # the first key is not in the local cache anymore
+    assert r.client_cache.get(("GET", "foo")) is None
+
+
+@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
+def test_cache_ttl():
+    r = redis.Redis(
+        cache_enable=True, cache_ttl=1, single_connection_client=True, protocol=3
+    )
+    # add key to redis
+    r.set("foo", "bar")
+    # get key from redis and save in local cache
+    assert r.get("foo") == b"bar"
+    # get key from local cache
+    assert r.client_cache.get(("GET", "foo")) == b"bar"
+    # wait for the key to expire
+    time.sleep(1)
+    # the key is not in the local cache anymore
+    assert r.client_cache.get(("GET", "foo")) is None
+
+
+@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
+def test_cache_lfu_eviction():
+    r = redis.Redis(
+        cache_enable=True,
+        cache_max_size=3,
+        cache_eviction_policy="lfu",
+        single_connection_client=True,
+        protocol=3,
+    )
+    # add 3 keys to redis
+    r.set("foo", "bar")
+    r.set("foo2", "bar2")
+    r.set("foo3", "bar3")
+    # get 3 keys from redis and save in local cache
+    assert r.get("foo") == b"bar"
+    assert r.get("foo2") == b"bar2"
+    assert r.get("foo3") == b"bar3"
+    # change the order of the keys in the cache
+    assert r.client_cache.get(("GET", "foo")) == b"bar"
+    assert r.client_cache.get(("GET", "foo")) == b"bar"
+    assert r.client_cache.get(("GET", "foo3")) == b"bar3"
+    # add 1 more key to redis (exceed the max size)
+    r.set("foo4", "bar4")
+    assert r.get("foo4") == b"bar4"
+    # test the eviction policy
+    assert len(r.client_cache.cache) == 3
+    assert r.client_cache.get(("GET", "foo")) == b"bar"
+    assert r.client_cache.get(("GET", "foo2")) is None
+
+
+@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
+def test_cache_decode_response():
+    r = redis.Redis(
+        decode_responses=True,
+        cache_enable=True,
+        single_connection_client=True,
+        protocol=3,
+    )
+    r.set("foo", "bar")
+    # get key from redis and save in local cache
+    assert r.get("foo") == "bar"
+    # get key from local cache
+    assert r.client_cache.get(("GET", "foo")) == "bar"
+    # change key in redis (cause invalidation)
+    r.set("foo", "barbar")
+    # send any command to redis (process invalidation in background)
+    r.ping()
+    # the command is not in the local cache anymore
+    assert r.client_cache.get(("GET", "foo")) is None
+    # get key from redis
+    assert r.get("foo") == "barbar"
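# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): besides cache_enable=True, both
# new constructors also accept a pre-built cache via client_cache=. Internal
# API, positional _LocalCache arguments as above, and a local RESP3-capable
# server are all assumptions here; under the hood the client issues
# CLIENT TRACKING ON so the server pushes invalidations for the keys it reads.
import redis
from redis.cache import DEFAULT_EVICTION_POLICY, _LocalCache

cache = _LocalCache(100, 0, DEFAULT_EVICTION_POLICY)
r = redis.Redis(protocol=3, single_connection_client=True, client_cache=cache)

r.set("foo", "bar")
assert r.get("foo") == b"bar"               # fetched from the server, stored in `cache`
assert cache.get(("GET", "foo")) == b"bar"  # the same object the client consults on a hit
r.close()                                   # flushes the local cache as well
# ---------------------------------------------------------------------------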