Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Split fetching device keys and signatures into two transactions #8233

Merged
merged 3 commits into from
Sep 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8233.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.
109 changes: 65 additions & 44 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
Expand All @@ -45,8 +46,9 @@ class DeviceKeyLookupResult:
# key) and "signatures" (a signature of the structure by the ed25519 key)
key_json = attr.ib(type=Optional[str])

# cross-signing sigs
signatures = attr.ib(type=Optional[Dict], default=None)
# cross-signing sigs on this device.
# dict from (signing user_id)->(signing device_id)->sig
signatures = attr.ib(type=Optional[Dict[str, Dict[str, str]]], factory=dict)


class EndToEndKeyWorkerStore(SQLBaseStore):
Expand Down Expand Up @@ -133,7 +135,10 @@ async def get_e2e_device_keys_and_signatures(
include_all_devices: bool = False,
include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Fetch a list of device keys, together with their cross-signatures.
"""Fetch a list of device keys

Any cross-signatures made on the keys by the owner of the device are also
included.

Args:
query_list: List of pairs of user_ids and device_ids. Device id can be None
Expand All @@ -154,22 +159,51 @@ async def get_e2e_device_keys_and_signatures(

result = await self.db_pool.runInteraction(
"get_e2e_device_keys",
self._get_e2e_device_keys_and_signatures_txn,
self._get_e2e_device_keys_txn,
query_list,
include_all_devices,
include_deleted_devices,
)

# get the (user_id, device_id) tuples to look up cross-signatures for
signature_query = (
(user_id, device_id)
for user_id, dev in result.items()
for device_id, d in dev.items()
if d is not None
)

for batch in batch_iter(signature_query, 50):
cross_sigs_result = await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_for_devices_txn,
batch,
)

# add each cross-signing signature to the correct device in the result dict.
for (user_id, key_id, device_id, signature) in cross_sigs_result:
target_device_result = result[user_id][device_id]
target_device_signatures = target_device_result.signatures

signing_user_signatures = target_device_signatures.setdefault(
user_id, {}
)
signing_user_signatures[key_id] = signature

log_kv(result)
return result

def _get_e2e_device_keys_and_signatures_txn(
def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Get information on devices from the database

The results include the device's keys and self-signatures, but *not* any
cross-signing signatures which have been added subsequently (for which, see
get_e2e_device_keys_and_signatures)
"""
query_clauses = []
query_params = []
signature_query_clauses = []
signature_query_params = []

if include_all_devices is False:
include_deleted_devices = False
Expand All @@ -180,20 +214,12 @@ def _get_e2e_device_keys_and_signatures_txn(
for (user_id, device_id) in query_list:
query_clause = "user_id = ?"
query_params.append(user_id)
signature_query_clause = "target_user_id = ?"
signature_query_params.append(user_id)

if device_id is not None:
query_clause += " AND device_id = ?"
query_params.append(device_id)
signature_query_clause += " AND target_device_id = ?"
signature_query_params.append(device_id)

signature_query_clause += " AND user_id = ?"
signature_query_params.append(user_id)

query_clauses.append(query_clause)
signature_query_clauses.append(signature_query_clause)

sql = (
"SELECT user_id, device_id, "
Expand Down Expand Up @@ -221,41 +247,36 @@ def _get_e2e_device_keys_and_signatures_txn(
for user_id, device_id in deleted_devices:
result.setdefault(user_id, {})[device_id] = None

# get signatures on the device
signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
" OR ".join("(" + q + ")" for q in signature_query_clauses)
)
return result

txn.execute(signature_sql, signature_query_params)
rows = self.db_pool.cursor_to_dict(txn)

# add each cross-signing signature to the correct device in the result dict.
for row in rows:
signing_user_id = row["user_id"]
signing_key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_device_id = row["target_device_id"]
signature = row["signature"]

target_user_result = result.get(target_user_id)
if not target_user_result:
continue
def _get_e2e_cross_signing_signatures_for_devices_txn(
self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
) -> List[Tuple[str, str, str, str]]:
"""Get cross-signing signatures for a given list of devices

target_device_result = target_user_result.get(target_device_id)
if not target_device_result:
# note that target_device_result will be None for deleted devices.
continue
Returns signatures made by the owners of the devices.

target_device_signatures = target_device_result.signatures
if target_device_signatures is None:
target_device_signatures = target_device_result.signatures = {}
Returns: a list of results; each entry in the list is a tuple of
(user_id, key_id, target_device_id, signature).
"""
signature_query_clauses = []
signature_query_params = []

signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {}
for (user_id, device_id) in device_query:
signature_query_clauses.append(
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
)
signing_user_signatures[signing_key_id] = signature
signature_query_params.extend([user_id, device_id, user_id])

return result
signature_sql = """
SELECT user_id, key_id, target_device_id, signature
FROM e2e_cross_signing_signatures WHERE %s
""" % (
" OR ".join("(" + q + ")" for q in signature_query_clauses)
)

txn.execute(signature_sql, signature_query_params)
return txn.fetchall()

async def get_e2e_one_time_keys(
self, user_id: str, device_id: str, key_ids: List[str]
Expand Down