Skip to content
Snippets Groups Projects
Unverified Commit 316590d1 authored by Erik Johnston's avatar Erik Johnston Committed by GitHub
Browse files

Fix bug in `wait_for_stream_position` (#14856)

We were incorrectly checking if the *local* token had been advanced, rather than the token for the remote instance.

In practice, I don't think this has caused any bugs, because of where we use `wait_for_stream_position`: critically, we don't call it on instances that also write to the given streams, and so the local token will lag behind all remote tokens.
parent 2b084c5b
No related branches found
No related tags found
No related merge requests found
Fix `wait_for_stream_position` to correctly wait for the right instance to advance its token.
...@@ -325,7 +325,7 @@ class ReplicationDataHandler: ...@@ -325,7 +325,7 @@ class ReplicationDataHandler:
# anyway in that case we don't need to wait. # anyway in that case we don't need to wait.
return return
current_position = self._streams[stream_name].current_token(self._instance_name) current_position = self._streams[stream_name].current_token(instance_name)
if position <= current_position: if position <= current_position:
# We're already past the position # We're already past the position
return return
......
...@@ -12,6 +12,10 @@ ...@@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.replication.tcp.commands import PositionCommand, RdataCommand
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
...@@ -71,3 +75,77 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase): ...@@ -71,3 +75,77 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual( self.assertEqual(
len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 1 len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 1
) )
def test_wait_for_stream_position(self) -> None:
    """Check that `wait_for_stream_position` waits for an update from the
    correct instance.

    The test waits on the "caches" stream position of `worker1` and checks
    that:

    * waiting for a position that has already been received over
      replication returns immediately;
    * waiting for a position that has not been received yet does not
      return, even after the master advances its *own* cache ID generator;
    * every pending waiter is woken once the RDATA for that position is
      received.
    """
    store = self.hs.get_datastores().main
    cmd_handler = self.hs.get_replication_command_handler()
    data_handler = self.hs.get_replication_data_handler()

    worker1 = self.make_worker_hs(
        "synapse.app.generic_worker",
        extra_config={
            "worker_name": "worker1",
            "run_background_tasks_on": "worker1",
            "redis": {"enabled": True},
        },
    )

    cache_id_gen = worker1.get_datastores().main._cache_id_gen
    assert cache_id_gen is not None

    self.replicate()

    # First, make sure the master knows that `worker1` exists.
    initial_token = cache_id_gen.get_current_token()
    cmd_handler.send_command(
        PositionCommand("caches", "worker1", initial_token, initial_token)
    )
    self.replicate()

    # Next send out a normal RDATA, and check that waiting for that stream
    # ID returns immediately.
    ctx = cache_id_gen.get_next()
    next_token = self.get_success(ctx.__aenter__())
    self.get_success(ctx.__aexit__(None, None, None))

    cmd_handler.send_command(
        RdataCommand("caches", "worker1", next_token, ("func_name", [], 0))
    )
    self.replicate()

    self.get_success(
        data_handler.wait_for_stream_position("worker1", "caches", next_token)
    )

    # `wait_for_stream_position` should only return once master receives an
    # RDATA from the worker
    ctx = cache_id_gen.get_next()
    next_token = self.get_success(ctx.__aenter__())
    self.get_success(ctx.__aexit__(None, None, None))

    waiter_before_master_advance = defer.ensureDeferred(
        data_handler.wait_for_stream_position("worker1", "caches", next_token)
    )
    self.assertFalse(waiter_before_master_advance.called)

    # ... updating the cache ID gen on the master still shouldn't cause the
    # deferred to wake up.
    master_cache_id_gen = store._cache_id_gen
    # Same guard as for the worker's generator above.
    assert master_cache_id_gen is not None
    ctx = master_cache_id_gen.get_next()
    self.get_success(ctx.__aenter__())
    self.get_success(ctx.__aexit__(None, None, None))

    # The waiter that was already pending must not have been woken ...
    self.assertFalse(waiter_before_master_advance.called)

    # ... and a fresh call made *after* the master's token advanced must
    # not return early either.
    waiter_after_master_advance = defer.ensureDeferred(
        data_handler.wait_for_stream_position("worker1", "caches", next_token)
    )
    self.assertFalse(waiter_after_master_advance.called)

    # ... but receiving the RDATA should wake both waiters.
    cmd_handler.send_command(
        RdataCommand("caches", "worker1", next_token, ("func_name", [], 0))
    )
    self.replicate()

    self.assertTrue(waiter_before_master_advance.called)
    self.assertTrue(waiter_after_master_advance.called)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment