diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 470a29e..262e760 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ default_language_version: - python: python3.7 + python: python3.8 repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.2.0 @@ -12,7 +12,7 @@ repos: - id: trailing-whitespace exclude: README.md - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort name: isort (python) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 3e37ce5..161848f 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -3,7 +3,7 @@ import typing import uuid from decimal import Decimal -from typing import Any, Optional, Union, cast, Dict, TypeVar +from typing import Any, Dict, Optional, TypeVar, Union, cast from sqlalchemy import types as sqa_types from sqlalchemy.dialects import postgresql @@ -101,12 +101,12 @@ def is_column_nullable(column): def convert_sqlalchemy_relationship( - relationship_prop, - obj_type, - connection_field_factory, - batching, - orm_field_name, - **field_kwargs, + relationship_prop, + obj_type, + connection_field_factory, + batching, + orm_field_name, + **field_kwargs, ): """ :param sqlalchemy.RelationshipProperty relationship_prop: @@ -147,7 +147,7 @@ def dynamic_type(): def _convert_o2o_or_m2o_relationship( - relationship_prop, obj_type, batching, orm_field_name, **field_kwargs + relationship_prop, obj_type, batching, orm_field_name, **field_kwargs ): """ Convert one-to-one or many-to-one relationshsip. Return an object field. @@ -175,7 +175,7 @@ def _convert_o2o_or_m2o_relationship( def _convert_o2m_or_m2m_relationship( - relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs + relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs ): """ Convert one-to-many or many-to-many relationshsip. Return a list field or a connection field. @@ -281,11 +281,11 @@ def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): @singledispatchbymatchfunction def convert_sqlalchemy_type( # noqa - type_arg: Any, - column: Optional[Union[MapperProperty, hybrid_property]] = None, - registry: Registry = None, - replace_type_vars: typing.Dict[str, Any] = None, - **kwargs, + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + replace_type_vars: typing.Dict[str, Any] = None, + **kwargs, ): if replace_type_vars and type_arg in replace_type_vars: return replace_type_vars[type_arg] @@ -301,7 +301,7 @@ def convert_sqlalchemy_type( # noqa @convert_sqlalchemy_type.register(safe_isinstance(DeclarativeMeta)) def convert_sqlalchemy_model_using_registry( - type_arg: Any, registry: Registry = None, **kwargs + type_arg: Any, registry: Registry = None, **kwargs ): registry_ = registry or get_global_registry() @@ -353,8 +353,8 @@ def convert_column_to_string(type_arg: Any, **kwargs): @convert_sqlalchemy_type.register(column_type_eq(sqa_utils.UUIDType)) @convert_sqlalchemy_type.register(column_type_eq(uuid.UUID)) def convert_column_to_uuid( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.UUID @@ -362,8 +362,8 @@ def convert_column_to_uuid( @convert_sqlalchemy_type.register(column_type_eq(sqa_types.DateTime)) @convert_sqlalchemy_type.register(column_type_eq(datetime.datetime)) def convert_column_to_datetime( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.DateTime @@ -371,8 +371,8 @@ def convert_column_to_datetime( @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Time)) @convert_sqlalchemy_type.register(column_type_eq(datetime.time)) def convert_column_to_time( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.Time @@ -380,8 +380,8 @@ def convert_column_to_time( @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Date)) @convert_sqlalchemy_type.register(column_type_eq(datetime.date)) def convert_column_to_date( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.Date @@ -390,10 +390,10 @@ def convert_column_to_date( @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Integer)) @convert_sqlalchemy_type.register(column_type_eq(int)) def convert_column_to_int_or_id( - type_arg: Any, - column: Optional[Union[MapperProperty, hybrid_property]] = None, - registry: Registry = None, - **kwargs, + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, ): # fixme drop the primary key processing from here in another pr if column is not None: @@ -405,8 +405,8 @@ def convert_column_to_int_or_id( @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Boolean)) @convert_sqlalchemy_type.register(column_type_eq(bool)) def convert_column_to_boolean( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.Boolean @@ -416,8 +416,8 @@ def convert_column_to_boolean( @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Numeric)) @convert_sqlalchemy_type.register(column_type_eq(sqa_types.BigInteger)) def convert_column_to_float( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.Float @@ -425,10 +425,10 @@ def convert_column_to_float( @convert_sqlalchemy_type.register(column_type_eq(postgresql.ENUM)) @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Enum)) def convert_enum_to_enum( - type_arg: Any, - column: Optional[Union[MapperProperty, hybrid_property]] = None, - registry: Registry = None, - **kwargs, + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, ): if column is None or isinstance(column, hybrid_property): raise Exception("SQL-Enum conversion requires a column") @@ -439,9 +439,9 @@ def convert_enum_to_enum( # TODO Make ChoiceType conversion consistent with other enums @convert_sqlalchemy_type.register(column_type_eq(sqa_utils.ChoiceType)) def convert_choice_to_enum( - type_arg: sqa_utils.ChoiceType, - column: Optional[Union[MapperProperty, hybrid_property]] = None, - **kwargs, + type_arg: sqa_utils.ChoiceType, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + **kwargs, ): if column is None or isinstance(column, hybrid_property): raise Exception("ChoiceType conversion requires a column") @@ -457,8 +457,8 @@ def convert_choice_to_enum( @convert_sqlalchemy_type.register(column_type_eq(sqa_utils.ScalarListType)) def convert_scalar_list_to_list( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return graphene.List(graphene.String) @@ -474,10 +474,10 @@ def init_array_list_recursive(inner_type, n): @convert_sqlalchemy_type.register(column_type_eq(sqa_types.ARRAY)) @convert_sqlalchemy_type.register(column_type_eq(postgresql.ARRAY)) def convert_array_to_list( - type_arg: Any, - column: Optional[Union[MapperProperty, hybrid_property]] = None, - registry: Registry = None, - **kwargs, + type_arg: Any, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, ): if column is None or isinstance(column, hybrid_property): raise Exception("SQL-Array conversion requires a column") @@ -496,8 +496,8 @@ def convert_array_to_list( @convert_sqlalchemy_type.register(column_type_eq(postgresql.JSON)) @convert_sqlalchemy_type.register(column_type_eq(postgresql.JSONB)) def convert_json_to_string( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return JSONString @@ -505,18 +505,18 @@ def convert_json_to_string( @convert_sqlalchemy_type.register(column_type_eq(sqa_utils.JSONType)) @convert_sqlalchemy_type.register(column_type_eq(sqa_types.JSON)) def convert_json_type_to_string( - type_arg: Any, - **kwargs, + type_arg: Any, + **kwargs, ): return JSONString @convert_sqlalchemy_type.register(column_type_eq(sqa_types.Variant)) def convert_variant_to_impl_type( - type_arg: sqa_types.Variant, - column: Optional[Union[MapperProperty, hybrid_property]] = None, - registry: Registry = None, - **kwargs, + type_arg: sqa_types.Variant, + column: Optional[Union[MapperProperty, hybrid_property]] = None, + registry: Registry = None, + **kwargs, ): if column is None or isinstance(column, hybrid_property): raise Exception("Vaiant conversion requires a column") @@ -547,7 +547,7 @@ def is_union(type_arg: Any, **kwargs) -> bool: def graphene_union_for_py_union( - obj_types: typing.List[graphene.ObjectType], registry + obj_types: typing.List[graphene.ObjectType], registry ) -> graphene.Union: union_type = registry.get_union_for_object_types(obj_types) @@ -590,8 +590,8 @@ def convert_sqlalchemy_hybrid_property_union(type_arg: Any, **kwargs): # Now check if every type is instance of an ObjectType if not all( - isinstance(graphene_type, type(graphene.ObjectType)) - for graphene_type in graphene_types + isinstance(graphene_type, type(graphene.ObjectType)) + for graphene_type in graphene_types ): raise ValueError( "Cannot convert hybrid_property Union to graphene.Union: the Union contains scalars. " diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index 5a4e684..bb42272 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -16,6 +16,7 @@ "BaseTypeFilterSelf", Dict[str, Any], InputObjectTypeContainer ) + class SQLAlchemyFilterInputField(graphene.InputField): def __init__( self, @@ -42,6 +43,7 @@ def __init__( self.model_attr = model_attr + def _get_functions_by_regex( regex: str, subtract_regex: str, class_: Type ) -> List[Tuple[str, Dict[str, Any]]]: @@ -305,11 +307,13 @@ def execute_filters( return query, clauses + class SQLEnumFilter(FieldFilter): """Basic Filter for Scalars in Graphene. We want this filter to use Dynamic fields so it provides the base filtering methods ("eq, nEq") for different types of scalars. The Dynamic fields will resolve to Meta.filtered_type""" + class Meta: graphene_type = graphene.Enum @@ -326,11 +330,13 @@ def n_eq_filter( ) -> Union[Tuple[Query, Any], Any]: return not_(field == val.value) + class PyEnumFilter(FieldFilter): """Basic Filter for Scalars in Graphene. We want this filter to use Dynamic fields so it provides the base filtering methods ("eq, nEq") for different types of scalars. The Dynamic fields will resolve to Meta.filtered_type""" + class Meta: graphene_type = graphene.Enum diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 2a45e78..b959d22 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -2,16 +2,14 @@ from collections import defaultdict from typing import TYPE_CHECKING, List, Type +from sqlalchemy.types import Enum as SQLAlchemyEnumType + import graphene from graphene import Enum from graphene.types.base import BaseType -from sqlalchemy.types import Enum as SQLAlchemyEnumType if TYPE_CHECKING: # pragma: no_cover - from .filters import ( - FieldFilter, - BaseTypeFilter, - RelationshipFilter, ) + from .filters import BaseTypeFilter, FieldFilter, RelationshipFilter class Registry(object): @@ -32,14 +30,15 @@ def __init__(self): def _init_base_filters(self): import graphene_sqlalchemy.filters as gsqa_filters - from .filters import (FieldFilter) + from .filters import FieldFilter + field_filter_classes = [ filter_cls[1] for filter_cls in inspect.getmembers(gsqa_filters, inspect.isclass) if ( - filter_cls[1] is not FieldFilter - and FieldFilter in filter_cls[1].__mro__ - and getattr(filter_cls[1]._meta, "graphene_type", False) + filter_cls[1] is not FieldFilter + and FieldFilter in filter_cls[1].__mro__ + and getattr(filter_cls[1]._meta, "graphene_type", False) ) ] for field_filter_class in field_filter_classes: @@ -100,7 +99,7 @@ def register_sort_enum(self, obj_type, sort_enum: Enum): from .types import SQLAlchemyObjectType if not isinstance(obj_type, type) or not issubclass( - obj_type, SQLAlchemyObjectType + obj_type, SQLAlchemyObjectType ): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) @@ -113,7 +112,7 @@ def get_sort_enum_for_object_type(self, obj_type: graphene.ObjectType): return self._registry_sort_enums.get(obj_type) def register_union_type( - self, union: Type[graphene.Union], obj_types: List[Type[graphene.ObjectType]] + self, union: Type[graphene.Union], obj_types: List[Type[graphene.ObjectType]] ): if not issubclass(union, graphene.Union): raise TypeError("Expected graphene.Union, but got: {!r}".format(union)) @@ -131,7 +130,7 @@ def get_union_for_object_types(self, obj_types: List[Type[graphene.ObjectType]]) # Filter Scalar Fields of Object Types def register_filter_for_scalar_type( - self, scalar_type: Type[graphene.Scalar], filter_obj: Type["FieldFilter"] + self, scalar_type: Type[graphene.Scalar], filter_obj: Type["FieldFilter"] ): from .filters import FieldFilter @@ -143,7 +142,7 @@ def register_filter_for_scalar_type( self._registry_scalar_filters[scalar_type] = filter_obj def get_filter_for_sql_enum_type( - self, enum_type: Type[graphene.Enum] + self, enum_type: Type[graphene.Enum] ) -> Type["FieldFilter"]: from .filters import SQLEnumFilter @@ -156,7 +155,7 @@ def get_filter_for_sql_enum_type( return filter_type def get_filter_for_py_enum_type( - self, enum_type: Type[graphene.Enum] + self, enum_type: Type[graphene.Enum] ) -> Type["FieldFilter"]: from .filters import PyEnumFilter @@ -169,7 +168,7 @@ def get_filter_for_py_enum_type( return filter_type def get_filter_for_scalar_type( - self, scalar_type: Type[graphene.Scalar] + self, scalar_type: Type[graphene.Scalar] ) -> Type["FieldFilter"]: from .filters import FieldFilter @@ -184,7 +183,7 @@ def get_filter_for_scalar_type( # TODO register enums automatically def register_filter_for_enum_type( - self, enum_type: Type[graphene.Enum], filter_obj: Type["FieldFilter"] + self, enum_type: Type[graphene.Enum], filter_obj: Type["FieldFilter"] ): from .filters import FieldFilter @@ -197,9 +196,9 @@ def register_filter_for_enum_type( # Filter Base Types def register_filter_for_base_type( - self, - base_type: Type[BaseType], - filter_obj: Type["BaseTypeFilter"], + self, + base_type: Type[BaseType], + filter_obj: Type["BaseTypeFilter"], ): from .filters import BaseTypeFilter @@ -207,9 +206,7 @@ def register_filter_for_base_type( raise TypeError("Expected BaseType, but got: {!r}".format(base_type)) if not issubclass(filter_obj, BaseTypeFilter): - raise TypeError( - "Expected BaseTypeFilter, but got: {!r}".format(filter_obj) - ) + raise TypeError("Expected BaseTypeFilter, but got: {!r}".format(filter_obj)) self._registry_base_type_filters[base_type] = filter_obj def get_filter_for_base_type(self, base_type: Type[BaseType]): @@ -217,7 +214,7 @@ def get_filter_for_base_type(self, base_type: Type[BaseType]): # Filter Relationships between base types def register_relationship_filter_for_base_type( - self, base_type: BaseType, filter_obj: Type["RelationshipFilter"] + self, base_type: BaseType, filter_obj: Type["RelationshipFilter"] ): from .filters import RelationshipFilter @@ -231,7 +228,7 @@ def register_relationship_filter_for_base_type( self._registry_relationship_filters[base_type] = filter_obj def get_relationship_filter_for_base_type( - self, base_type: Type[BaseType] + self, base_type: Type[BaseType] ) -> "RelationshipFilter": return self._registry_relationship_filters.get(base_type) diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py index 9489011..2c749da 100644 --- a/graphene_sqlalchemy/tests/conftest.py +++ b/graphene_sqlalchemy/tests/conftest.py @@ -1,15 +1,15 @@ -from typing_extensions import Literal - -import graphene import pytest import pytest_asyncio from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker +from typing_extensions import Literal +import graphene from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4 -from .models import Base, CompositeFullName + from ..converter import convert_sqlalchemy_composite from ..registry import reset_global_registry +from .models import Base, CompositeFullName if SQL_VERSION_HIGHER_EQUAL_THAN_1_4: from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 7ec6de3..12554dc 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -6,6 +6,7 @@ from decimal import Decimal from typing import List, Optional +# fmt: off from sqlalchemy import ( Column, Date, @@ -23,14 +24,15 @@ from sqlalchemy.sql.type_api import TypeEngine from graphene_sqlalchemy.tests.utils import wrap_select_func -from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, SQL_VERSION_HIGHER_EQUAL_THAN_2 +from graphene_sqlalchemy.utils import ( + SQL_VERSION_HIGHER_EQUAL_THAN_1_4, + SQL_VERSION_HIGHER_EQUAL_THAN_2, +) -# fmt: off -import sqlalchemy if SQL_VERSION_HIGHER_EQUAL_THAN_2: - from sqlalchemy.sql.sqltypes import HasExpressionLookup # noqa # isort:skip + from sqlalchemy.sql.sqltypes import HasExpressionLookup # noqa # isort:skip else: - from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter as HasExpressionLookup # noqa # isort:skip + from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter as HasExpressionLookup # noqa # isort:skip # fmt: on PetKind = Enum("cat", "dog", name="pet_kind") diff --git a/graphene_sqlalchemy/tests/models_batching.py b/graphene_sqlalchemy/tests/models_batching.py index 5dde366..e0f5d4b 100644 --- a/graphene_sqlalchemy/tests/models_batching.py +++ b/graphene_sqlalchemy/tests/models_batching.py @@ -2,16 +2,7 @@ import enum -from sqlalchemy import ( - Column, - Date, - Enum, - ForeignKey, - Integer, - String, - Table, - func, -) +from sqlalchemy import Column, Date, Enum, ForeignKey, Integer, String, Table, func from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import column_property, relationship diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 60bf505..1b2e0ec 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -1,13 +1,10 @@ import enum import sys -from typing import Dict, Tuple, Union, TypeVar +from typing import Dict, Tuple, TypeVar, Union -import graphene import pytest import sqlalchemy import sqlalchemy_utils as sqa_utils -from graphene.relay import Node -from graphene.types.structures import Structure from sqlalchemy import Column, func, types from sqlalchemy.dialects import postgresql from sqlalchemy.ext.declarative import declarative_base @@ -15,15 +12,10 @@ from sqlalchemy.inspection import inspect from sqlalchemy.orm import column_property, composite -from .models import ( - Article, - CompositeFullName, - Pet, - Reporter, - ShoppingCart, - ShoppingCartItem, -) -from .utils import wrap_select_func +import graphene +from graphene.relay import Node +from graphene.types.structures import Structure + from ..converter import ( convert_sqlalchemy_column, convert_sqlalchemy_composite, @@ -45,6 +37,7 @@ ShoppingCart, ShoppingCartItem, ) +from .utils import wrap_select_func def mock_resolver(): @@ -210,12 +203,11 @@ def test_converter_replace_type_var(): replace_type_vars = {T: graphene.String} - field_type = convert_sqlalchemy_type( - T, replace_type_vars=replace_type_vars - ) + field_type = convert_sqlalchemy_type(T, replace_type_vars=replace_type_vars) assert field_type == graphene.String + @pytest.mark.skipif( sys.version_info < (3, 10), reason="|-Style Unions are unsupported in python < 3.10" ) @@ -225,9 +217,9 @@ def prop_method() -> int | str: return "not allowed in gql schema" with pytest.raises( - ValueError, - match=r"Cannot convert hybrid_property Union to " - r"graphene.Union: the Union contains scalars. \.*", + ValueError, + match=r"Cannot convert hybrid_property Union to " + r"graphene.Union: the Union contains scalars. \.*", ): get_hybrid_property_type(prop_method) @@ -481,7 +473,9 @@ class TestEnum(enum.IntEnum): def test_should_columproperty_convert(): field = get_field_from_column( - column_property(wrap_select_func(func.sum(func.cast(id, types.Integer))).where(id == 1)) + column_property( + wrap_select_func(func.sum(func.cast(id, types.Integer))).where(id == 1) + ) ) assert field.type == graphene.Int @@ -840,8 +834,8 @@ class Meta: ) for ( - hybrid_prop_name, - hybrid_prop_expected_return_type, + hybrid_prop_name, + hybrid_prop_expected_return_type, ) in shopping_cart_item_expected_types.items(): hybrid_prop_field = ShoppingCartItemType._meta.fields[hybrid_prop_name] @@ -852,7 +846,7 @@ class Meta: str(hybrid_prop_expected_return_type), ) assert ( - hybrid_prop_field.description is None + hybrid_prop_field.description is None ) # "doc" is ignored by hybrid property ################################################### @@ -900,8 +894,8 @@ class Meta: ) for ( - hybrid_prop_name, - hybrid_prop_expected_return_type, + hybrid_prop_name, + hybrid_prop_expected_return_type, ) in shopping_cart_expected_types.items(): hybrid_prop_field = ShoppingCartType._meta.fields[hybrid_prop_name] @@ -912,5 +906,5 @@ class Meta: str(hybrid_prop_expected_return_type), ) assert ( - hybrid_prop_field.description is None + hybrid_prop_field.description is None ) # "doc" is ignored by hybrid property diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 026247c..4acf89a 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -1,8 +1,12 @@ -import graphene import pytest -from graphene import Connection, relay from sqlalchemy.sql.operators import is_ +import graphene +from graphene import Connection, relay + +from ..fields import SQLAlchemyConnectionField +from ..filters import FloatFilter +from ..types import ORMField, SQLAlchemyObjectType from .models import ( Article, Editor, @@ -16,10 +20,6 @@ Tag, ) from .utils import eventually_await_session, to_std_dicts -from ..fields import SQLAlchemyConnectionField -from ..filters import FloatFilter -from ..types import ORMField, SQLAlchemyObjectType - # TODO test that generated schema is correct for all examples with: # with open('schema.gql', 'w') as fp: @@ -257,7 +257,17 @@ async def test_filter_enum(session): } """ expected = { - "reporters": {"edges": [{"node": {"firstName": "Jane", "lastName": "Roe", "favoritePetKind": "DOG"}}]}, + "reporters": { + "edges": [ + { + "node": { + "firstName": "Jane", + "lastName": "Roe", + "favoritePetKind": "DOG", + } + } + ] + }, } schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) @@ -916,8 +926,20 @@ async def test_filter_logic_or(session): expected = { "reporters": { "edges": [ - {"node": {"firstName": "John", "lastName": "Woe", "favoritePetKind": "CAT"}}, - {"node": {"firstName": "Jane", "lastName": "Roe", "favoritePetKind": "DOG"}}, + { + "node": { + "firstName": "John", + "lastName": "Woe", + "favoritePetKind": "CAT", + } + }, + { + "node": { + "firstName": "Jane", + "lastName": "Roe", + "favoritePetKind": "DOG", + } + }, ] } } @@ -1109,7 +1131,7 @@ async def test_filter_hybrid_property(session): result = to_std_dicts(result.data) assert len(result["carts"]["edges"]) == 1 assert ( - len(result["carts"]["edges"][0]["node"]["hybridPropShoppingCartItemList"]) == 2 + len(result["carts"]["edges"][0]["node"]["hybridPropShoppingCartItemList"]) == 2 ) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index b99f023..6b5ab4d 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -6,9 +6,13 @@ from inspect import isawaitable from typing import Any, Optional, Type, Union -import graphene import sqlalchemy -from graphene import Field, InputField, Dynamic +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import ColumnProperty, CompositeProperty, RelationshipProperty +from sqlalchemy.orm.exc import NoResultFound + +import graphene +from graphene import Dynamic, Field, InputField from graphene.relay import Connection, Node from graphene.types.base import BaseType from graphene.types.interface import Interface, InterfaceOptions @@ -16,9 +20,6 @@ from graphene.types.unmountedtype import UnmountedType from graphene.types.utils import yank_fields_from_attrs from graphene.utils.orderedtype import OrderedType -from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import ColumnProperty, CompositeProperty, RelationshipProperty -from sqlalchemy.orm.exc import NoResultFound from .converter import ( convert_sqlalchemy_column, @@ -31,7 +32,7 @@ sort_argument_for_object_type, sort_enum_for_object_type, ) -from .filters import BaseTypeFilter, FieldFilter, RelationshipFilter, SQLAlchemyFilterInputField +from .filters import BaseTypeFilter, RelationshipFilter, SQLAlchemyFilterInputField from .registry import Registry, get_global_registry from .resolvers import get_attr_resolver, get_custom_resolver from .utils import ( @@ -51,17 +52,17 @@ class ORMField(OrderedType): def __init__( - self, - model_attr=None, - type_=None, - required=None, - description=None, - deprecation_reason=None, - batching=None, - create_filter=None, - filter_type: Optional[Type] = None, - _creation_counter=None, - **field_kwargs, + self, + model_attr=None, + type_=None, + required=None, + description=None, + deprecation_reason=None, + batching=None, + create_filter=None, + filter_type: Optional[Type] = None, + _creation_counter=None, + **field_kwargs, ): """ Use this to override fields automatically generated by SQLAlchemyObjectType. @@ -127,7 +128,7 @@ class Meta: def get_or_create_relationship_filter( - base_type: Type[BaseType], registry: Registry + base_type: Type[BaseType], registry: Registry ) -> Type[RelationshipFilter]: relationship_filter = registry.get_relationship_filter_for_base_type(base_type) @@ -146,11 +147,11 @@ def get_or_create_relationship_filter( def filter_field_from_field( - field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], - type_, - registry: Registry, - model_attr: Any, - model_attr_name: str + field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], + type_, + registry: Registry, + model_attr: Any, + model_attr_name: str, ) -> Optional[graphene.InputField]: # Field might be a SQLAlchemyObjectType, due to hybrid properties if issubclass(type_, SQLAlchemyObjectType): @@ -174,9 +175,7 @@ def filter_field_from_field( def resolve_dynamic_relationship_filter( - field: graphene.Dynamic, - registry: Registry, - model_attr_name: str + field: graphene.Dynamic, registry: Registry, model_attr_name: str ) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: # Resolve Dynamic Type type_ = get_nullable_type(field.get_type()) @@ -208,11 +207,11 @@ def resolve_dynamic_relationship_filter( def filter_field_from_type_field( - field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], - registry: Registry, - filter_type: Optional[Type], - model_attr: Any, - model_attr_name: str + field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]], + registry: Registry, + filter_type: Optional[Type], + model_attr: Any, + model_attr_name: str, ) -> Optional[Union[graphene.InputField, graphene.Dynamic]]: # If a custom filter type was set for this field, use it here if filter_type: @@ -223,7 +222,11 @@ def filter_field_from_type_field( # If the generated field is Dynamic, it is always a relationship # (due to graphene-sqlalchemy's conversion mechanism). elif isinstance(field, graphene.Dynamic): - return Dynamic(partial(resolve_dynamic_relationship_filter, field, registry, model_attr_name)) + return Dynamic( + partial( + resolve_dynamic_relationship_filter, field, registry, model_attr_name + ) + ) # Unsupported but theoretically possible cases, please drop us an issue with reproduction if you need them elif isinstance(field, graphene.List) or isinstance(field._type, graphene.List): # Pure lists are not yet supported @@ -234,9 +237,23 @@ def filter_field_from_type_field( # Order matters, this comes last as field._type == list also matches Field elif isinstance(field, graphene.Field): if inspect.isfunction(field._type) or isinstance(field._type, partial): - return Dynamic(lambda: filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr, model_attr_name)) + return Dynamic( + lambda: filter_field_from_field( + field, + get_nullable_type(field.type), + registry, + model_attr, + model_attr_name, + ) + ) else: - return filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr, model_attr_name) + return filter_field_from_field( + field, + get_nullable_type(field.type), + registry, + model_attr, + model_attr_name, + ) def get_polymorphic_on(model): @@ -252,14 +269,14 @@ def get_polymorphic_on(model): def construct_fields_and_filters( - obj_type, - model, - registry, - only_fields, - exclude_fields, - batching, - create_filters, - connection_field_factory, + obj_type, + model, + registry, + only_fields, + exclude_fields, + batching, + create_filters, + connection_field_factory, ): """ Construct all the fields for a SQLAlchemyObjectType. @@ -296,9 +313,9 @@ def construct_fields_and_filters( auto_orm_field_names = [] for attr_name, attr in all_model_attrs.items(): if ( - (only_fields and attr_name not in only_fields) - or (attr_name in exclude_fields) - or attr_name == polymorphic_on + (only_fields and attr_name not in only_fields) + or (attr_name in exclude_fields) + or attr_name == polymorphic_on ): continue auto_orm_field_names.append(attr_name) @@ -390,21 +407,21 @@ class SQLAlchemyBase(BaseType): @classmethod def __init_subclass_with_meta__( - cls, - model=None, - registry=None, - skip_registry=False, - only_fields=(), - exclude_fields=(), - connection=None, - connection_class=None, - use_connection=None, - interfaces=(), - id=None, - batching=False, - connection_field_factory=None, - _meta=None, - **options, + cls, + model=None, + registry=None, + skip_registry=False, + only_fields=(), + exclude_fields=(), + connection=None, + connection_class=None, + use_connection=None, + interfaces=(), + id=None, + batching=False, + connection_field_factory=None, + _meta=None, + **options, ): # We always want to bypass this hook unless we're defining a concrete # `SQLAlchemyObjectType` or `SQLAlchemyInterface`. @@ -528,6 +545,7 @@ def get_node(cls, info, id): session = get_session(info.context) if isinstance(session, AsyncSession): + async def get_result() -> Any: return await session.get(cls._meta.model, id) @@ -659,7 +677,7 @@ def __init_subclass_with_meta__(cls, _meta=None, **options): if hasattr(_meta.model, "__mapper__"): polymorphic_identity = _meta.model.__mapper__.polymorphic_identity assert ( - polymorphic_identity is None + polymorphic_identity is None ), '{}: An interface cannot map to a concrete type (polymorphic_identity is "{}")'.format( cls.__name__, polymorphic_identity )