Commit ce72355d authored by Erik Johnston, committed via GitHub

Fix race in replication (#7226)

Fixes a race between handling `POSITION` and `RDATA` commands. We do this by simply linearizing handling of them.
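For orientation before the diff: "linearizing" here means serializing all command handling for a given stream behind a per-stream lock, which Synapse does with its `Linearizer` utility (`with await self._position_linearizer.queue(...)` in the hunks below). The following is only a minimal asyncio sketch of that idea with hypothetical names, not Synapse's implementation:

import asyncio
from collections import defaultdict
from typing import DefaultDict

class StreamLinearizer:
    """Minimal sketch of per-key linearization: coroutines sharing a key
    (here, a stream name) run one at a time; distinct keys stay concurrent."""

    def __init__(self) -> None:
        # One lock per stream name, created lazily on first use.
        self._locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)

    def queue(self, key: str) -> asyncio.Lock:
        return self._locks[key]

linearizer = StreamLinearizer()

async def handle(stream_name: str, label: str) -> None:
    # With the lock held, a POSITION catch-up can never interleave with
    # RDATA handling for the same stream.
    async with linearizer.queue(stream_name):
        print(f"{label}({stream_name}) start")
        await asyncio.sleep(0.1)
        print(f"{label}({stream_name}) done")

async def main() -> None:
    await asyncio.gather(
        handle("push_rules", "POSITION"),
        handle("push_rules", "RDATA"),
    )

asyncio.run(main())

Run as-is, the second handler always waits for the first to finish, which is exactly the ordering guarantee the commit relies on.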
parent 82498ee9
Changelog entry: Move catchup of replication streams logic to worker.
@@ -189,16 +189,34 @@ class ReplicationCommandHandler:
             logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
             raise
 
-        if cmd.token is None or stream_name not in self._streams_connected:
-            # I.e. either this is part of a batch of updates for this stream (in
-            # which case batch until we get an update for the stream with a non
-            # None token) or we're currently connecting so we queue up rows.
-            self._pending_batches.setdefault(stream_name, []).append(row)
-        else:
-            # Check if this is the last of a batch of updates
-            rows = self._pending_batches.pop(stream_name, [])
-            rows.append(row)
-            await self.on_rdata(stream_name, cmd.token, rows)
+        # We linearize here for two reasons:
+        #   1. so we don't try and concurrently handle multiple rows for the
+        #      same stream, and
+        #   2. so we don't race with getting a POSITION command and fetching
+        #      missing RDATA.
+        with await self._position_linearizer.queue(cmd.stream_name):
+            if stream_name not in self._streams_connected:
+                # If the stream isn't marked as connected then we haven't seen a
+                # `POSITION` command yet, and so we may have missed some rows.
+                # Let's drop the row for now, on the assumption we'll receive a
+                # `POSITION` soon and we'll catch up correctly then.
+                logger.warning(
+                    "Discarding RDATA for unconnected stream %s -> %s",
+                    stream_name,
+                    cmd.token,
+                )
+                return
+
+            if cmd.token is None:
+                # I.e. this is part of a batch of updates for this stream (in
+                # which case batch until we get an update for the stream with a
+                # non None token).
+                self._pending_batches.setdefault(stream_name, []).append(row)
+            else:
+                # Check if this is the last of a batch of updates
+                rows = self._pending_batches.pop(stream_name, [])
+                rows.append(row)
+                await self.on_rdata(stream_name, cmd.token, rows)
 
     async def on_rdata(self, stream_name: str, token: int, rows: list):
         """Called to handle a batch of replication data with a given stream token.
@@ -221,12 +239,13 @@ class ReplicationCommandHandler:
 
         # We protect catching up with a linearizer in case the replication
        # connection reconnects under us.
         with await self._position_linearizer.queue(cmd.stream_name):
-            # We're about to go and catch up with the stream, so mark as connecting
-            # to stop RDATA being handled at the same time by removing stream from
-            # list of connected streams. We also clear any batched up RDATA from
-            # before we got the POSITION.
+            # We're about to go and catch up with the stream, so remove from set
+            # of connected streams.
             self._streams_connected.discard(cmd.stream_name)
-            self._pending_batches.clear()
+
+            # We clear the pending batches for the stream as the fetching of the
+            # missing updates below will fetch all rows in the batch.
+            self._pending_batches.pop(cmd.stream_name, [])
 
             # Find where we previously streamed up to.
             current_token = self._replication_data_handler.get_streams_to_replicate().get(
@@ -239,12 +258,17 @@ class ReplicationCommandHandler:
                 )
                 return
 
-            # Fetch all updates between then and now.
-            limited = True
-            while limited:
-                updates, current_token, limited = await stream.get_updates_since(
-                    current_token, cmd.token
-                )
+            # If the position token matches our current token then we're up to
+            # date and there's nothing to do. Otherwise, fetch all updates
+            # between then and now.
+            missing_updates = cmd.token != current_token
+            while missing_updates:
+                (
+                    updates,
+                    current_token,
+                    missing_updates,
+                ) = await stream.get_updates_since(current_token, cmd.token)
+
                 if updates:
                     await self.on_rdata(
                         cmd.stream_name,
@@ -255,13 +279,6 @@ class ReplicationCommandHandler:
 
             # We've now caught up to position sent to us, notify handler.
             await self._replication_data_handler.on_position(cmd.stream_name, cmd.token)
 
-            # Handle any RDATA that came in while we were catching up.
-            rows = self._pending_batches.pop(cmd.stream_name, [])
-            if rows:
-                await self._replication_data_handler.on_rdata(
-                    cmd.stream_name, rows[-1].token, rows
-                )
-
             self._streams_connected.add(cmd.stream_name)
 
     async def on_SYNC(self, cmd: SyncCommand):
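The loop rewrite above is worth pausing on: the old form (`limited = True; while limited: ...`) always issued at least one fetch, even when the POSITION token already matched the current token. The new form seeds the loop with `cmd.token != current_token`, so an up-to-date stream does no work. A self-contained sketch of the pattern, with a hypothetical `get_updates_since` standing in for the real stream method:

import asyncio
from typing import List, Tuple

# Hypothetical stand-in for a stream's get_updates_since: returns up to
# `limit` updates after `from_token`, the token reached, and whether more remain.
async def get_updates_since(
    from_token: int, upto_token: int, limit: int = 100
) -> Tuple[List[Tuple[int, str]], int, bool]:
    reached = min(from_token + limit, upto_token)
    updates = [(t, f"row-{t}") for t in range(from_token + 1, reached + 1)]
    return updates, reached, reached < upto_token

async def catch_up(current_token: int, position_token: int) -> int:
    # Same shape as the new loop in on_POSITION: if the POSITION token already
    # matches our token there is nothing to fetch, so the loop never runs.
    missing_updates = position_token != current_token
    while missing_updates:
        updates, current_token, missing_updates = await get_updates_since(
            current_token, position_token
        )
        if updates:
            print(f"applying {len(updates)} rows, now at {current_token}")
    return current_token

print(asyncio.run(catch_up(0, 250)))  # three pages: 100, 100, 50 rows

The next file changed is the worker-side stream plumbing.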
@@ -168,12 +168,13 @@ def make_http_update_function(
     async def update_function(
         from_token: int, upto_token: int, limit: int
     ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
-        return await client(
+        result = await client(
             stream_name=stream_name,
             from_token=from_token,
             upto_token=upto_token,
             limit=limit,
         )
+        return result["updates"], result["upto_token"], result["limited"]
 
     return update_function
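This hunk fixes a bug in the HTTP-backed update function used by workers: the replication client resolves to a JSON dict, while callers of an update function expect an `(updates, upto_token, limited)` tuple. A small sketch of the mismatch and the fix, with a stand-in `client` (the dict keys match the hunk; everything else is illustrative):

import asyncio
from typing import Any, Dict, List, Tuple

# Hypothetical stand-in for the replication HTTP client: the real endpoint
# replies with a JSON dict carrying exactly these three keys.
async def client(**kwargs: Any) -> Dict[str, Any]:
    return {"updates": [(7, ("some", "row"))], "upto_token": 7, "limited": False}

async def update_function(
    from_token: int, upto_token: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
    result = await client(
        stream_name="push_rules",  # illustrative value
        from_token=from_token,
        upto_token=upto_token,
        limit=limit,
    )
    # The old code returned the dict itself; a caller that unpacks a 3-tuple
    # would then iterate over the dict's keys. Unpack explicitly instead.
    return result["updates"], result["upto_token"], result["limited"]

print(asyncio.run(update_function(0, 7, 100)))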
@@ -334,6 +334,26 @@ class PushRulesWorkerStore(
             results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
         return results
 
+    def get_all_push_rule_updates(self, last_id, current_id, limit):
+        """Get all the push rules changes that have happened on the server"""
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_push_rule_updates_txn(txn):
+            sql = (
+                "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
+                " op, priority_class, priority, conditions, actions"
+                " FROM push_rules_stream"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
+
+        return self.db.runInteraction(
+            "get_all_push_rule_updates", get_all_push_rule_updates_txn
+        )
+
 
 class PushRuleStore(PushRulesWorkerStore):
     @defer.inlineCallbacks
@@ -685,26 +705,6 @@ class PushRuleStore(PushRulesWorkerStore):
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )
 
-    def get_all_push_rule_updates(self, last_id, current_id, limit):
-        """Get all the push rules changes that have happened on the server"""
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_push_rule_updates_txn(txn):
-            sql = (
-                "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
-                " op, priority_class, priority, conditions, actions"
-                " FROM push_rules_stream"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, current_id, limit))
-            return txn.fetchall()
-
-        return self.db.runInteraction(
-            "get_all_push_rule_updates", get_all_push_rule_updates_txn
-        )
-
     def get_push_rules_stream_token(self):
         """Get the position of the push rules stream.
 
         Returns a pair of a stream id for the push_rules stream and the
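These last two hunks move `get_all_push_rule_updates` unchanged from `PushRuleStore` up into `PushRulesWorkerStore`, so that worker processes (which only load the worker store) can serve catch-up fetches for the push rules stream. A toy sqlite sketch of the query's windowing contract, with a cut-down table standing in for `push_rules_stream`:

import sqlite3

# Toy stand-in for push_rules_stream; the real table has more columns.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE push_rules_stream (stream_id INTEGER, user_id TEXT)")
conn.executemany(
    "INSERT INTO push_rules_stream VALUES (?, ?)",
    [(i, "@user%d:example.com" % i) for i in range(1, 8)],
)

def get_all_push_rule_updates(last_id: int, current_id: int, limit: int):
    # Half-open window (last_id, current_id]: only rows the caller has not
    # yet seen, oldest first, capped at `limit` so catch-up can paginate.
    if last_id == current_id:
        return []
    return conn.execute(
        "SELECT stream_id, user_id FROM push_rules_stream"
        " WHERE ? < stream_id AND stream_id <= ?"
        " ORDER BY stream_id ASC LIMIT ?",
        (last_id, current_id, limit),
    ).fetchall()

print(get_all_push_rule_updates(2, 6, 3))  # first page: stream_ids 3, 4, 5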