Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cleanup logic in graphql_ws protocol handler #3778

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 25 additions & 25 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3574,23 +3574,23 @@ the attributes of the `IgnoreContext` class.
For example, the following query:
```python
"""
query {
matt: user(name: "matt") {
email
}
andy: user(name: "andy") {
email
address {
city
}
pets {
name
owner {
name
}
}
query {
matt: user(name: "matt") {
email
}
andy: user(name: "andy") {
email
address {
city
}
pets {
name
owner {
name
}
}
}
}
"""
```
can have its depth limited by the following `should_ignore`:
Expand All @@ -3607,17 +3607,17 @@ query_depth_limiter = QueryDepthLimiter(should_ignore=should_ignore)
so that it *effectively* becomes:
```python
"""
query {
andy: user(name: "andy") {
email
pets {
name
owner {
name
}
}
query {
andy: user(name: "andy") {
email
pets {
name
owner {
name
}
}
}
}
"""
```

Expand Down Expand Up @@ -10702,7 +10702,7 @@ from typing import Annotated
class Query:
@strawberry.field
def user_by_id(
id: Annotated[str, strawberry.argument(description="The ID of the user")]
id: Annotated[str, strawberry.argument(description="The ID of the user")],
) -> User: ...
```

Expand Down
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: patch

This release fixes the issue that some subscription resolvers were not canceled if a client unexpectedly disconnected.
2 changes: 1 addition & 1 deletion docs/types/exceptions.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def name(
str,
strawberry.argument(description="This is a description"),
strawberry.argument(description="Another description"),
]
],
) -> str:
return "Name"

Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_pydantic(session: Session, pydantic: str, gql_core: str) -> None:

@session(python=PYTHON_VERSIONS, name="Type checkers tests", tags=["tests"])
def tests_typecheckers(session: Session) -> None:
session.run_always("poetry", "install", external=True)
session.run_always("poetry", "install", "--with", "integrations", external=True)

session.install("pyright")
session.install("pydantic")
Expand Down
100 changes: 85 additions & 15 deletions strawberry/federation/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,59 @@
from strawberry.types.unset import UNSET

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from collections.abc import Iterable, Mapping, Sequence
from typing_extensions import Literal

from strawberry.extensions.field_extension import FieldExtension
from strawberry.permission import BasePermission
from strawberry.types.field import _RESOLVER_TYPE, StrawberryField
from strawberry.types.field import (
_RESOLVER_TYPE,
_RESOLVER_TYPE_ASYNC,
_RESOLVER_TYPE_SYNC,
StrawberryField,
)

from .schema_directives import Override

T = TypeVar("T")

# NOTE: we are separating the sync and async resolvers because using both
# in the same function will cause mypy to raise an error. Not sure if it is a bug


@overload
def field(
*,
resolver: _RESOLVER_TYPE_ASYNC[T],
name: Optional[str] = None,
is_subscription: bool = False,
description: Optional[str] = None,
authenticated: bool = False,
external: bool = False,
inaccessible: bool = False,
policy: Optional[list[list[str]]] = None,
provides: Optional[list[str]] = None,
override: Optional[Union[Override, str]] = None,
requires: Optional[list[str]] = None,
requires_scopes: Optional[list[list[str]]] = None,
tags: Optional[Iterable[str]] = (),
shareable: bool = False,
init: Literal[False] = False,
permission_classes: Optional[list[type[BasePermission]]] = None,
deprecation_reason: Optional[str] = None,
default: Any = dataclasses.MISSING,
default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
metadata: Optional[Mapping[Any, Any]] = None,
directives: Optional[Sequence[object]] = (),
extensions: Optional[list[FieldExtension]] = None,
graphql_type: Optional[Any] = None,
) -> T: ...


@overload
def field(
*,
resolver: _RESOLVER_TYPE[T],
resolver: _RESOLVER_TYPE_SYNC[T],
name: Optional[str] = None,
is_subscription: bool = False,
description: Optional[str] = None,
Expand All @@ -47,9 +84,10 @@ def field(
init: Literal[False] = False,
permission_classes: Optional[list[type[BasePermission]]] = None,
deprecation_reason: Optional[str] = None,
default: Any = UNSET,
default_factory: Union[Callable[..., object], object] = UNSET,
directives: Sequence[object] = (),
default: Any = dataclasses.MISSING,
default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
metadata: Optional[Mapping[Any, Any]] = None,
directives: Optional[Sequence[object]] = (),
extensions: Optional[list[FieldExtension]] = None,
graphql_type: Optional[Any] = None,
) -> T: ...
Expand All @@ -74,17 +112,18 @@ def field(
init: Literal[True] = True,
permission_classes: Optional[list[type[BasePermission]]] = None,
deprecation_reason: Optional[str] = None,
default: Any = UNSET,
default_factory: Union[Callable[..., object], object] = UNSET,
directives: Sequence[object] = (),
default: Any = dataclasses.MISSING,
default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
metadata: Optional[Mapping[Any, Any]] = None,
directives: Optional[Sequence[object]] = (),
extensions: Optional[list[FieldExtension]] = None,
graphql_type: Optional[Any] = None,
) -> Any: ...


@overload
def field(
resolver: _RESOLVER_TYPE[T],
resolver: _RESOLVER_TYPE_ASYNC[T],
*,
name: Optional[str] = None,
is_subscription: bool = False,
Expand All @@ -101,9 +140,38 @@ def field(
shareable: bool = False,
permission_classes: Optional[list[type[BasePermission]]] = None,
deprecation_reason: Optional[str] = None,
default: Any = UNSET,
default_factory: Union[Callable[..., object], object] = UNSET,
directives: Sequence[object] = (),
default: Any = dataclasses.MISSING,
default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
metadata: Optional[Mapping[Any, Any]] = None,
directives: Optional[Sequence[object]] = (),
extensions: Optional[list[FieldExtension]] = None,
graphql_type: Optional[Any] = None,
) -> StrawberryField: ...


@overload
def field(
resolver: _RESOLVER_TYPE_SYNC[T],
*,
name: Optional[str] = None,
is_subscription: bool = False,
description: Optional[str] = None,
authenticated: bool = False,
external: bool = False,
inaccessible: bool = False,
policy: Optional[list[list[str]]] = None,
provides: Optional[list[str]] = None,
override: Optional[Union[Override, str]] = None,
requires: Optional[list[str]] = None,
requires_scopes: Optional[list[list[str]]] = None,
tags: Optional[Iterable[str]] = (),
shareable: bool = False,
permission_classes: Optional[list[type[BasePermission]]] = None,
deprecation_reason: Optional[str] = None,
default: Any = dataclasses.MISSING,
default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
metadata: Optional[Mapping[Any, Any]] = None,
directives: Optional[Sequence[object]] = (),
extensions: Optional[list[FieldExtension]] = None,
graphql_type: Optional[Any] = None,
) -> StrawberryField: ...
Expand All @@ -129,7 +197,8 @@ def field(
deprecation_reason: Optional[str] = None,
default: Any = dataclasses.MISSING,
default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
directives: Sequence[object] = (),
metadata: Optional[Mapping[Any, Any]] = None,
directives: Optional[Sequence[object]] = (),
extensions: Optional[list[FieldExtension]] = None,
graphql_type: Optional[Any] = None,
# This init parameter is used by PyRight to determine whether this field
Expand All @@ -150,7 +219,7 @@ def field(
Tag,
)

directives = list(directives)
directives = list(directives or [])

if authenticated:
directives.append(Authenticated())
Expand Down Expand Up @@ -197,6 +266,7 @@ def field(
default_factory=default_factory,
init=init, # type: ignore
directives=directives,
metadata=metadata,
extensions=extensions,
graphql_type=graphql_type,
)
Expand Down
3 changes: 2 additions & 1 deletion strawberry/relay/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def connection(
default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
metadata: Optional[Mapping[Any, Any]] = None,
directives: Optional[Sequence[object]] = (),
extensions: list[FieldExtension] = (), # type: ignore
extensions: list[FieldExtension] | None = None,
max_results: Optional[int] = None,
# This init parameter is used by pyright to determine whether this field
# is added in the constructor or not. It is not used to change
Expand Down Expand Up @@ -473,6 +473,7 @@ def get_some_nodes(self, age: int) -> Iterable[SomeType]: ...
https://relay.dev/graphql/connections.htm

"""
extensions = extensions or []
f = StrawberryField(
python_name=None,
graphql_name=name,
Expand Down
7 changes: 5 additions & 2 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ async def handle(self) -> None:
with suppress(BaseException):
await self.keep_alive_task

for operation_id in list(self.subscriptions.keys()):
await self.cleanup_operation(operation_id)
await self.cleanup()

async def handle_message(
self,
Expand Down Expand Up @@ -202,6 +201,10 @@ async def cleanup_operation(self, operation_id: str) -> None:
await self.tasks[operation_id]
del self.tasks[operation_id]

async def cleanup(self) -> None:
for operation_id in list(self.tasks.keys()):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: is it ok to just do this for self.tasks and not for self.subscriptions anymore? Wondering if the correct thing would be to do:

Suggested change
for operation_id in list(self.tasks.keys()):
for operation_id in set(self.tasks) | set(self.subscriptions):

Copy link
Contributor Author

@jakub-bacic jakub-bacic Feb 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I can see, each subscription has a matching task with the same operation_id. However, the opposite isn't true - not every task has a matching subscription entry (that was the root cause of this issue and why a simple switch from self.subscriptions to self.tasks fixed that).

Task is added in handle_start:

        result_handler = self.handle_async_results(
            operation_id, query, operation_name, variables
        )
        self.tasks[operation_id] = asyncio.create_task(result_handler)

Meanwhile, the subscription entry is added later in handle_async_results, only after receiving the first value from the asynchronous generator:

            agen_or_err = await self.schema.subscribe(
                query=query,
                variable_values=variables,
                operation_name=operation_name,
                context_value=self.context,
                root_value=self.root_value,
            )
            if isinstance(agen_or_err, PreExecutionError):
                assert agen_or_err.errors
                await self.send_message(
                    {
                        "type": "error",
                        "id": operation_id,
                        "payload": agen_or_err.errors[0].formatted,
                    }
                )
            else:
                self.subscriptions[operation_id] = agen_or_err

That's why if cleanup is executed before the subscription resolver yields any value, the task will have no matching entry in self.subscriptions yet.

Additionally, it's safe to call cleanup_operation for each task because it checks internally if the given operation exists in self.subscriptions before attempting to delete it.

await self.cleanup_operation(operation_id)

async def send_data_message(
self, execution_result: ExecutionResult, operation_id: str
) -> None:
Expand Down
13 changes: 10 additions & 3 deletions tests/views/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ def match_text(self, text_file: Upload, pattern: str) -> str:

@strawberry.type
class Subscription:
active_infinity_subscriptions = 0

@strawberry.subscription
async def echo(self, message: str, delay: float = 0) -> AsyncGenerator[str, None]:
await asyncio.sleep(delay)
Expand All @@ -164,9 +166,14 @@ async def request_ping(self, info: strawberry.Info) -> AsyncGenerator[bool, None

@strawberry.subscription
async def infinity(self, message: str) -> AsyncGenerator[str, None]:
while True:
yield message
await asyncio.sleep(1)
Subscription.active_infinity_subscriptions += 1

try:
while True:
yield message
await asyncio.sleep(1)
finally:
Subscription.active_infinity_subscriptions -= 1

@strawberry.subscription
async def context(self, info: strawberry.Info) -> AsyncGenerator[str, None]:
Expand Down
10 changes: 6 additions & 4 deletions tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
SubscribeMessage,
)
from tests.http.clients.base import DebuggableGraphQLTransportWSHandler
from tests.views.schema import MyExtension, Schema
from tests.views.schema import MyExtension, Schema, Subscription

if TYPE_CHECKING:
from tests.http.clients.base import HttpClient, WebSocketClient
Expand Down Expand Up @@ -1104,12 +1104,14 @@ async def test_unexpected_client_disconnects_are_gracefully_handled(
{
"id": "sub1",
"type": "subscribe",
"payload": {
"query": 'subscription { echo(message: "Hi", delay: 0.5) }'
},
"payload": {"query": 'subscription { infinity(message: "Hi") }'},
}
)
await ws.receive(timeout=2)
assert Subscription.active_infinity_subscriptions == 1

await ws.close()
await asyncio.sleep(1)

assert not process_errors.called
assert Subscription.active_infinity_subscriptions == 0
8 changes: 6 additions & 2 deletions tests/websockets/test_graphql_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ErrorMessage,
StartMessage,
)
from tests.views.schema import MyExtension, Schema
from tests.views.schema import MyExtension, Schema, Subscription

if TYPE_CHECKING:
from tests.http.clients.aiohttp import HttpClient, WebSocketClient
Expand Down Expand Up @@ -806,11 +806,15 @@ async def test_unexpected_client_disconnects_are_gracefully_handled(
"type": "start",
"id": "sub1",
"payload": {
"query": 'subscription { echo(message: "Hi", delay: 0.5) }',
"query": 'subscription { infinity(message: "Hi") }',
},
}
)
await ws.receive_json()
assert Subscription.active_infinity_subscriptions == 1

await ws.close()
await asyncio.sleep(1)

assert not process_errors.called
assert Subscription.active_infinity_subscriptions == 0
Loading