Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Additional type hints for relations database class. (#11205)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Oct 28, 2021
1 parent 0e16b41 commit 56e281b
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 15 deletions.
1 change: 1 addition & 0 deletions changelog.d/11205.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints for the relations datastore.
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ files =
synapse/storage/databases/main/keys.py,
synapse/storage/databases/main/pusher.py,
synapse/storage/databases/main/registration.py,
synapse/storage/databases/main/relations.py,
synapse/storage/databases/main/session.py,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
Expand Down
38 changes: 23 additions & 15 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.

import logging
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union

import attr

from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
AggregationPaginationToken,
Expand Down Expand Up @@ -63,7 +64,7 @@ async def get_relations_for_event(
"""

where_clause = ["relates_to_id = ?"]
where_args = [event_id]
where_args: List[Union[str, int]] = [event_id]

if relation_type is not None:
where_clause.append("relation_type = ?")
Expand All @@ -80,8 +81,8 @@ async def get_relations_for_event(
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
from_token=attr.astuple(from_token) if from_token else None,
to_token=attr.astuple(to_token) if to_token else None,
from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
engine=self.database_engine,
)

Expand All @@ -106,7 +107,9 @@ async def get_relations_for_event(
order,
)

def _get_recent_references_for_event_txn(txn):
def _get_recent_references_for_event_txn(
txn: LoggingTransaction,
) -> PaginationChunk:
txn.execute(sql, where_args + [limit + 1])

last_topo_id = None
Expand Down Expand Up @@ -160,7 +163,7 @@ async def get_aggregation_groups_for_event(
"""

where_clause = ["relates_to_id = ?", "relation_type = ?"]
where_args = [event_id, RelationTypes.ANNOTATION]
where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]

if event_type:
where_clause.append("type = ?")
Expand All @@ -169,8 +172,8 @@ async def get_aggregation_groups_for_event(
having_clause = generate_pagination_where_clause(
direction=direction,
column_names=("COUNT(*)", "MAX(stream_ordering)"),
from_token=attr.astuple(from_token) if from_token else None,
to_token=attr.astuple(to_token) if to_token else None,
from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
engine=self.database_engine,
)

Expand Down Expand Up @@ -199,7 +202,9 @@ async def get_aggregation_groups_for_event(
having_clause=having_clause,
)

def _get_aggregation_groups_for_event_txn(txn):
def _get_aggregation_groups_for_event_txn(
txn: LoggingTransaction,
) -> PaginationChunk:
txn.execute(sql, where_args + [limit + 1])

next_batch = None
Expand Down Expand Up @@ -254,11 +259,12 @@ async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
LIMIT 1
"""

def _get_applicable_edit_txn(txn):
def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
txn.execute(sql, (event_id, RelationTypes.REPLACE))
row = txn.fetchone()
if row:
return row[0]
return None

edit_id = await self.db_pool.runInteraction(
"get_applicable_edit", _get_applicable_edit_txn
Expand All @@ -267,7 +273,7 @@ def _get_applicable_edit_txn(txn):
if not edit_id:
return None

return await self.get_event(edit_id, allow_none=True)
return await self.get_event(edit_id, allow_none=True) # type: ignore[attr-defined]

@cached()
async def get_thread_summary(
Expand All @@ -283,7 +289,9 @@ async def get_thread_summary(
The number of items in the thread and the most recent response, if any.
"""

def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
def _get_thread_summary_txn(
txn: LoggingTransaction,
) -> Tuple[int, Optional[str]]:
# Fetch the count of threaded events and the latest event ID.
# TODO Should this only allow m.room.message events.
sql = """
Expand Down Expand Up @@ -312,7 +320,7 @@ def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
AND relation_type = ?
"""
txn.execute(sql, (event_id, RelationTypes.THREAD))
count = txn.fetchone()[0]
count = txn.fetchone()[0] # type: ignore[index]

return count, latest_event_id

Expand All @@ -322,7 +330,7 @@ def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:

latest_event = None
if latest_event_id:
latest_event = await self.get_event(latest_event_id, allow_none=True)
latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined]

return count, latest_event

Expand Down Expand Up @@ -354,7 +362,7 @@ async def has_user_annotated_event(
LIMIT 1;
"""

def _get_if_user_has_annotated_event(txn):
def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
txn.execute(
sql,
(
Expand Down

0 comments on commit 56e281b

Please sign in to comment.