Skip to content
Snippets Groups Projects
Unverified Commit 63593134 authored by Erik Johnston's avatar Erik Johnston Committed by GitHub
Browse files

Some cleanups to device inbox store. (#9041)

parent 9066c2fd
No related branches found
No related tags found
No related merge requests found
Various cleanups to device inbox store.
...@@ -18,7 +18,6 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker ...@@ -18,7 +18,6 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ToDeviceStream from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
...@@ -37,13 +36,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): ...@@ -37,13 +36,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
self._device_inbox_id_gen.get_current_token(), self._device_inbox_id_gen.get_current_token(),
) )
self._last_device_delete_cache = ExpiringCache(
cache_name="last_device_delete_cache",
clock=self._clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
)
def process_replication_rows(self, stream_name, instance_name, token, rows): def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ToDeviceStream.NAME: if stream_name == ToDeviceStream.NAME:
self._device_inbox_id_gen.advance(instance_name, token) self._device_inbox_id_gen.advance(instance_name, token)
......
...@@ -17,7 +17,7 @@ import logging ...@@ -17,7 +17,7 @@ import logging
from typing import List, Tuple from typing import List, Tuple
from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
...@@ -26,6 +26,18 @@ logger = logging.getLogger(__name__) ...@@ -26,6 +26,18 @@ logger = logging.getLogger(__name__)
class DeviceInboxWorkerStore(SQLBaseStore): class DeviceInboxWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
self._last_device_delete_cache = ExpiringCache(
cache_name="last_device_delete_cache",
clock=self._clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
)
def get_to_device_stream_token(self): def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token() return self._device_inbox_id_gen.get_current_token()
...@@ -310,20 +322,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): ...@@ -310,20 +322,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
self._last_device_delete_cache = ExpiringCache(
cache_name="last_device_delete_cache",
clock=self._clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
)
@trace @trace
async def add_messages_to_device_inbox( async def add_messages_to_device_inbox(
self, self,
...@@ -351,16 +349,19 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) ...@@ -351,16 +349,19 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Add the remote messages to the federation outbox. # Add the remote messages to the federation outbox.
# We'll send them to a remote server when we next send a # We'll send them to a remote server when we next send a
# federation transaction to that destination. # federation transaction to that destination.
sql = ( self.db_pool.simple_insert_many_txn(
"INSERT INTO device_federation_outbox" txn,
" (destination, stream_id, queued_ts, messages_json)" table="device_federation_outbox",
" VALUES (?,?,?,?)" values=[
{
"destination": destination,
"stream_id": stream_id,
"queued_ts": now_ms,
"messages_json": json_encoder.encode(edu),
}
for destination, edu in remote_messages_by_destination.items()
],
) )
rows = []
for destination, edu in remote_messages_by_destination.items():
edu_json = json_encoder.encode(edu)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
async with self._device_inbox_id_gen.get_next() as stream_id: async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec() now_ms = self.clock.time_msec()
...@@ -433,32 +434,37 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) ...@@ -433,32 +434,37 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
devices = list(messages_by_device.keys()) devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*": if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids. # Handle wildcard device_ids.
sql = "SELECT device_id FROM devices WHERE user_id = ?" devices = self.db_pool.simple_select_onecol_txn(
txn.execute(sql, (user_id,)) txn,
table="devices",
keyvalues={"user_id": user_id},
retcol="device_id",
)
message_json = json_encoder.encode(messages_by_device["*"]) message_json = json_encoder.encode(messages_by_device["*"])
for row in txn: for device_id in devices:
# Add the message for all devices for this user on this # Add the message for all devices for this user on this
# server. # server.
device = row[0] messages_json_for_user[device_id] = message_json
messages_json_for_user[device] = message_json
else: else:
if not devices: if not devices:
continue continue
clause, args = make_in_list_sql_clause( rows = self.db_pool.simple_select_many_txn(
txn.database_engine, "device_id", devices txn,
table="devices",
keyvalues={"user_id": user_id},
column="device_id",
iterable=devices,
retcols=("device_id",),
) )
sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
# TODO: Maybe this needs to be done in batches if there are for row in rows:
# too many local devices for a given user.
txn.execute(sql, [user_id] + list(args))
for row in txn:
# Only insert into the local inbox if the device exists on # Only insert into the local inbox if the device exists on
# this server # this server
device = row[0] device_id = row["device_id"]
message_json = json_encoder.encode(messages_by_device[device]) message_json = json_encoder.encode(messages_by_device[device_id])
messages_json_for_user[device] = message_json messages_json_for_user[device_id] = message_json
if messages_json_for_user: if messages_json_for_user:
local_by_user_then_device[user_id] = messages_json_for_user local_by_user_then_device[user_id] = messages_json_for_user
...@@ -466,14 +472,17 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) ...@@ -466,14 +472,17 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
if not local_by_user_then_device: if not local_by_user_then_device:
return return
sql = ( self.db_pool.simple_insert_many_txn(
"INSERT INTO device_inbox" txn,
" (user_id, device_id, stream_id, message_json)" table="device_inbox",
" VALUES (?,?,?,?)" values=[
{
"user_id": user_id,
"device_id": device_id,
"stream_id": stream_id,
"message_json": message_json,
}
for user_id, messages_by_device in local_by_user_then_device.items()
for device_id, message_json in messages_by_device.items()
],
) )
rows = []
for user_id, messages_by_device in local_by_user_then_device.items():
for device_id, message_json in messages_by_device.items():
rows.append((user_id, device_id, stream_id, message_json))
txn.executemany(sql, rows)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment