From 156a14bf3f8488834eb3f61fabe1e8e949d59474 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Fri, 30 Aug 2024 16:47:54 -0400 Subject: [PATCH] formatting and linting --- redisvl/index/index.py | 2 +- redisvl/query/__init__.py | 2 +- redisvl/query/query.py | 15 ++++++++------- tests/integration/test_llmcache.py | 4 +++- tests/integration/test_session_manager.py | 3 +-- tests/unit/test_query_types.py | 8 ++++---- 6 files changed, 18 insertions(+), 16 deletions(-) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 4b2645a4..d5aeb5e2 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -65,7 +65,7 @@ def process_results( unpack_json = ( (storage_type == StorageType.JSON) and isinstance(query, FilterQuery) - and not query._return_fields + and not query._return_fields # type: ignore ) # Process records diff --git a/redisvl/query/__init__.py b/redisvl/query/__init__.py index ecae6bad..8246794f 100644 --- a/redisvl/query/__init__.py +++ b/redisvl/query/__init__.py @@ -13,5 +13,5 @@ "FilterQuery", "RangeQuery", "VectorRangeQuery", - "CountQuery" + "CountQuery", ] diff --git a/redisvl/query/query.py b/redisvl/query/query.py index e111046a..9ba05481 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -28,7 +28,7 @@ def __str__(self) -> str: def _build_query_string(self) -> str: """Build the full Redis query string.""" - pass + raise NotImplementedError("Must be implemented by subclasses") def set_filter(self, filter_expression: Optional[FilterExpression] = None): """Set the filter expression for the query. @@ -45,12 +45,13 @@ def set_filter(self, filter_expression: Optional[FilterExpression] = None): elif isinstance(filter_expression, FilterExpression): self._filter_expression = filter_expression else: - raise TypeError("filter_expression must be of type FilterExpression or None") + raise TypeError( + "filter_expression must be of type FilterExpression or None" + ) # Reset the query string self._query_string = self._build_query_string() - @property def filter(self) -> FilterExpression: """The filter expression for the query.""" @@ -167,8 +168,8 @@ def _build_query_string(self) -> str: return str(self._filter_expression) -class BaseVectorQuery(BaseQuery): - DTYPES: Dict[str, np.dtype] = { +class BaseVectorQuery: + DTYPES: Dict[str, Any] = { "float32": np.float32, "float64": np.float64, } @@ -176,7 +177,7 @@ class BaseVectorQuery(BaseQuery): VECTOR_PARAM: str = "vector" -class VectorQuery(BaseVectorQuery): +class VectorQuery(BaseVectorQuery, BaseQuery): def __init__( self, vector: Union[List[float], bytes], @@ -268,7 +269,7 @@ def params(self) -> Dict[str, Any]: return {self.VECTOR_PARAM: vector} -class VectorRangeQuery(BaseVectorQuery): +class VectorRangeQuery(BaseVectorQuery, BaseQuery): DISTANCE_THRESHOLD_PARAM: str = "distance_threshold" def __init__( diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 9252f2bd..f1c6e262 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -19,7 +19,9 @@ def vectorizer(): @pytest.fixture def cache(vectorizer, redis_url): cache_instance = SemanticCache( - vectorizer=vectorizer, distance_threshold=0.2, redis_url="redis://localhost:6379" + vectorizer=vectorizer, + distance_threshold=0.2, + redis_url="redis://localhost:6379", ) yield cache_instance cache_instance._index.delete(True) # Clean up index diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py index 56943447..20c2955d 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_session_manager.py @@ -464,8 +464,7 @@ def test_semantic_add_and_get_relevant(semantic_session): default_context = semantic_session.get_relevant("list of fruits and vegetables") assert len(default_context) == 5 # 2 pairs of prompt:response, and system assert default_context == semantic_session.get_relevant( - "list of fruits and vegetables", - distance_threshold=0.5 + "list of fruits and vegetables", distance_threshold=0.5 ) # test tool calls can also be returned diff --git a/tests/unit/test_query_types.py b/tests/unit/test_query_types.py index 00081ea5..17426868 100644 --- a/tests/unit/test_query_types.py +++ b/tests/unit/test_query_types.py @@ -1,5 +1,4 @@ import pytest - from redis.commands.search.query import Query from redis.commands.search.result import Result @@ -124,10 +123,11 @@ def test_vector_query(): ["field1", "field2"], dialect=3, num_results=10, - in_order=True + in_order=True, ) assert vector_query._in_order + def test_range_query(): # Create a filter expression filter_expression = Tag("brand") == "Nike" @@ -182,7 +182,7 @@ def test_range_query(): ["field1"], filter_expression, num_results=10, - in_order=True + in_order=True, ) assert range_query._in_order @@ -193,7 +193,7 @@ def test_range_query(): CountQuery(), FilterQuery(), VectorQuery(vector=[1, 2, 3], vector_field_name="vector"), - RangeQuery(vector=[1, 2, 3], vector_field_name="vector") + RangeQuery(vector=[1, 2, 3], vector_field_name="vector"), ], ) def test_query_modifiers(query):