Skip to content
Snippets Groups Projects
Unverified Commit c255b0ff authored by Richard van der Hoff's avatar Richard van der Hoff Committed by GitHub
Browse files

Merge pull request #7427 from matrix-org/rav/fix_dropped_messages

Fix lost events on replication reconnection
parents 5b8023dc a8c17da2
No related branches found
No related tags found
No related merge requests found
Add support for running replication over Redis when using workers.
...@@ -81,9 +81,6 @@ class ReplicationCommandHandler: ...@@ -81,9 +81,6 @@ class ReplicationCommandHandler:
self._instance_id = hs.get_instance_id() self._instance_id = hs.get_instance_id()
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
# Set of streams that we've caught up with.
self._streams_connected = set() # type: Set[str]
self._streams = { self._streams = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values() stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream] } # type: Dict[str, Stream]
...@@ -99,9 +96,13 @@ class ReplicationCommandHandler: ...@@ -99,9 +96,13 @@ class ReplicationCommandHandler:
# The factory used to create connections. # The factory used to create connections.
self._factory = None # type: Optional[ReconnectingClientFactory] self._factory = None # type: Optional[ReconnectingClientFactory]
# The currently connected connections. # The currently connected connections. (The list of places we need to send
# outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection] self._connections = [] # type: List[AbstractConnection]
# For each connection, the incoming streams that are coming from that connection
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
LaterGauge( LaterGauge(
"synapse_replication_tcp_resource_total_connections", "synapse_replication_tcp_resource_total_connections",
"", "",
...@@ -257,9 +258,11 @@ class ReplicationCommandHandler: ...@@ -257,9 +258,11 @@ class ReplicationCommandHandler:
# 2. so we don't race with getting a POSITION command and fetching # 2. so we don't race with getting a POSITION command and fetching
# missing RDATA. # missing RDATA.
with await self._position_linearizer.queue(cmd.stream_name): with await self._position_linearizer.queue(cmd.stream_name):
if stream_name not in self._streams_connected: # make sure that we've processed a POSITION for this stream *on this
# If the stream isn't marked as connected then we haven't seen a # connection*. (A POSITION on another connection is no good, as there
# `POSITION` command yet, and so we may have missed some rows. # is no guarantee that we have seen all the intermediate updates.)
sbc = self._streams_by_connection.get(conn)
if not sbc or stream_name not in sbc:
# Let's drop the row for now, on the assumption we'll receive a # Let's drop the row for now, on the assumption we'll receive a
# `POSITION` soon and we'll catch up correctly then. # `POSITION` soon and we'll catch up correctly then.
logger.debug( logger.debug(
...@@ -302,21 +305,25 @@ class ReplicationCommandHandler: ...@@ -302,21 +305,25 @@ class ReplicationCommandHandler:
# Ignore POSITION that are just our own echoes # Ignore POSITION that are just our own echoes
return return
stream = self._streams.get(cmd.stream_name) logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
stream_name = cmd.stream_name
stream = self._streams.get(stream_name)
if not stream: if not stream:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) logger.error("Got POSITION for unknown stream: %s", stream_name)
return return
# We protect catching up with a linearizer in case the replication # We protect catching up with a linearizer in case the replication
# connection reconnects under us. # connection reconnects under us.
with await self._position_linearizer.queue(cmd.stream_name): with await self._position_linearizer.queue(stream_name):
# We're about to go and catch up with the stream, so remove from set # We're about to go and catch up with the stream, so remove from set
# of connected streams. # of connected streams.
self._streams_connected.discard(cmd.stream_name) for streams in self._streams_by_connection.values():
streams.discard(stream_name)
# We clear the pending batches for the stream as the fetching of the # We clear the pending batches for the stream as the fetching of the
# missing updates below will fetch all rows in the batch. # missing updates below will fetch all rows in the batch.
self._pending_batches.pop(cmd.stream_name, []) self._pending_batches.pop(stream_name, [])
# Find where we previously streamed up to. # Find where we previously streamed up to.
current_token = stream.current_token() current_token = stream.current_token()
...@@ -326,6 +333,12 @@ class ReplicationCommandHandler: ...@@ -326,6 +333,12 @@ class ReplicationCommandHandler:
# between then and now. # between then and now.
missing_updates = cmd.token != current_token missing_updates = cmd.token != current_token
while missing_updates: while missing_updates:
logger.info(
"Fetching replication rows for '%s' between %i and %i",
stream_name,
current_token,
cmd.token,
)
( (
updates, updates,
current_token, current_token,
...@@ -341,16 +354,18 @@ class ReplicationCommandHandler: ...@@ -341,16 +354,18 @@ class ReplicationCommandHandler:
for token, rows in _batch_updates(updates): for token, rows in _batch_updates(updates):
await self.on_rdata( await self.on_rdata(
cmd.stream_name, stream_name,
cmd.instance_name, cmd.instance_name,
token, token,
[stream.parse_row(row) for row in rows], [stream.parse_row(row) for row in rows],
) )
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
# We've now caught up to position sent to us, notify handler. # We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(cmd.stream_name, cmd.token) await self._replication_data_handler.on_position(stream_name, cmd.token)
self._streams_connected.add(cmd.stream_name) self._streams_by_connection.setdefault(conn, set()).add(stream_name)
async def on_REMOTE_SERVER_UP( async def on_REMOTE_SERVER_UP(
self, conn: AbstractConnection, cmd: RemoteServerUpCommand self, conn: AbstractConnection, cmd: RemoteServerUpCommand
...@@ -408,6 +423,12 @@ class ReplicationCommandHandler: ...@@ -408,6 +423,12 @@ class ReplicationCommandHandler:
def lost_connection(self, connection: AbstractConnection): def lost_connection(self, connection: AbstractConnection):
"""Called when a connection is closed/lost. """Called when a connection is closed/lost.
""" """
# we no longer need _streams_by_connection for this connection.
streams = self._streams_by_connection.pop(connection, None)
if streams:
logger.info(
"Lost replication connection; streams now disconnected: %s", streams
)
try: try:
self._connections.remove(connection) self._connections.remove(connection)
except ValueError: except ValueError:
......
...@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING ...@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING
import txredisapi import txredisapi
from synapse.logging.context import PreserveLoggingContext from synapse.logging.context import make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import ( from synapse.replication.tcp.commands import (
Command, Command,
...@@ -41,8 +41,14 @@ logger = logging.getLogger(__name__) ...@@ -41,8 +41,14 @@ logger = logging.getLogger(__name__)
class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
"""Connection to redis subscribed to replication stream. """Connection to redis subscribed to replication stream.
Parses incoming messages from redis into replication commands, and passes This class fulfils two functions:
them to `ReplicationCommandHandler`
(a) it implements the twisted Protocol API, where it handles the SUBSCRIBEd redis
connection, parsing *incoming* messages into replication commands, and passing them
to `ReplicationCommandHandler`
(b) it implements the AbstractConnection API, where it sends *outgoing* commands
onto outbound_redis_connection.
Due to the vagaries of `txredisapi` we don't want to have a custom Due to the vagaries of `txredisapi` we don't want to have a custom
constructor, so instead we expect the defined attributes below to be set constructor, so instead we expect the defined attributes below to be set
...@@ -50,8 +56,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): ...@@ -50,8 +56,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
Attributes: Attributes:
handler: The command handler to handle incoming commands. handler: The command handler to handle incoming commands.
stream_name: The *redis* stream name to subscribe to (not anything to stream_name: The *redis* stream name to subscribe to and publish from
do with Synapse replication streams). (not anything to do with Synapse replication streams).
outbound_redis_connection: The connection to redis to use to send outbound_redis_connection: The connection to redis to use to send
commands. commands.
""" """
...@@ -61,13 +67,23 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): ...@@ -61,13 +67,23 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
outbound_redis_connection = None # type: txredisapi.RedisProtocol outbound_redis_connection = None # type: txredisapi.RedisProtocol
def connectionMade(self): def connectionMade(self):
logger.info("Connected to redis")
super().connectionMade() super().connectionMade()
logger.info("Connected to redis instance") run_as_background_process("subscribe-replication", self._send_subscribe)
self.subscribe(self.stream_name)
self.send_command(ReplicateCommand())
self.handler.new_connection(self) self.handler.new_connection(self)
async def _send_subscribe(self):
# it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end.
logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
await make_deferred_yieldable(self.subscribe(self.stream_name))
logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
await self._async_send_command(ReplicateCommand())
logger.info("REPLICATE successfully sent")
def messageReceived(self, pattern: str, channel: str, message: str): def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis. """Received a message from redis.
""" """
...@@ -120,8 +136,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): ...@@ -120,8 +136,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
logger.warning("Unhandled command: %r", cmd) logger.warning("Unhandled command: %r", cmd)
def connectionLost(self, reason): def connectionLost(self, reason):
logger.info("Lost connection to redis")
super().connectionLost(reason) super().connectionLost(reason)
logger.info("Lost connection to redis instance")
self.handler.lost_connection(self) self.handler.lost_connection(self)
def send_command(self, cmd: Command): def send_command(self, cmd: Command):
...@@ -130,6 +146,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): ...@@ -130,6 +146,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
Args: Args:
cmd (Command) cmd (Command)
""" """
run_as_background_process("send-cmd", self._async_send_command, cmd)
async def _async_send_command(self, cmd: Command):
"""Encode a replication command and send it over our outbound connection"""
string = "%s %s" % (cmd.NAME, cmd.to_line()) string = "%s %s" % (cmd.NAME, cmd.to_line())
if "\n" in string: if "\n" in string:
raise Exception("Unexpected newline in command: %r", string) raise Exception("Unexpected newline in command: %r", string)
...@@ -140,15 +160,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): ...@@ -140,15 +160,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# remote instances. # remote instances.
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc() tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
async def _send(): await make_deferred_yieldable(
with PreserveLoggingContext(): self.outbound_redis_connection.publish(self.stream_name, encoded_string)
# Note that we use the other connection as we can't send )
# commands using the subscription connection.
await self.outbound_redis_connection.publish(
self.stream_name, encoded_string
)
run_as_background_process("send-cmd", _send)
class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment