Commit 0e719f23, authored by Erik Johnston, committed via GitHub

Thread through instance name to replication client. (#7369)

For in-memory streams, when fetching updates on workers we need to query the source of the stream, which is currently hard-coded to be the master. This PR threads the source instance name we received via `POSITION` through to the update function in each stream, so that it can be passed to the replication client for in-memory streams.
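In outline: when a worker receives a `POSITION` command it now passes the sending instance's name into `Stream.get_updates_since`, and the stream's update function forwards it, so in-memory streams can direct their HTTP catch-up request at the actual writer. A minimal sketch of the new flow, using names taken from the diff below:

    # handler.py: fetch missing updates from the instance named in the
    # POSITION command, rather than from a hard-coded master
    updates, current_token, missing_updates = await stream.get_updates_since(
        cmd.instance_name, current_token, cmd.token
    )

    # streams/_base.py: for in-memory streams the update function wraps an
    # HTTP client, which now receives the target instance
    result = await client(
        instance_name=instance_name,
        stream_name=stream_name,
        from_token=from_token,
        upto_token=upto_token,
    )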
parent 3085cde5
@@ -646,13 +646,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
         else:
             self.send_handler = None
 
-    async def on_rdata(self, stream_name, token, rows):
-        await super(GenericWorkerReplicationHandler, self).on_rdata(
-            stream_name, token, rows
-        )
-        await self.process_and_notify(stream_name, token, rows)
+    async def on_rdata(self, stream_name, instance_name, token, rows):
+        await super().on_rdata(stream_name, instance_name, token, rows)
+        await self._process_and_notify(stream_name, instance_name, token, rows)
 
-    async def process_and_notify(self, stream_name, token, rows):
+    async def _process_and_notify(self, stream_name, instance_name, token, rows):
         try:
             if self.send_handler:
                 await self.send_handler.process_replication_rows(
...
@@ -16,6 +16,7 @@
 import abc
 import logging
 import re
+from inspect import signature
 from typing import Dict, List, Tuple
 
 from six import raise_from
@@ -60,6 +61,8 @@ class ReplicationEndpoint(object):
     must call `register` to register the path with the HTTP server.
 
     Requests can be sent by calling the client returned by `make_client`.
+    Requests are sent to the master process by default, but can be sent to
+    other named processes by specifying an `instance_name` keyword argument.
 
     Attributes:
         NAME (str): A name for the endpoint, added to the path as well as used
@@ -91,6 +94,16 @@ class ReplicationEndpoint(object):
                 hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
             )
 
+        # We reserve `instance_name` as a parameter for sending requests, so we
+        # assert here that subclasses don't try and use the name.
+        assert (
+            "instance_name" not in self.PATH_ARGS
+        ), "`instance_name` is a reserved parameter name"
+        assert (
+            "instance_name"
+            not in signature(self.__class__._serialize_payload).parameters
+        ), "`instance_name` is a reserved parameter name"
+
         assert self.METHOD in ("PUT", "POST", "GET")
 
     @abc.abstractmethod
@@ -135,7 +148,11 @@ class ReplicationEndpoint(object):
         @trace(opname="outgoing_replication_request")
         @defer.inlineCallbacks
-        def send_request(**kwargs):
+        def send_request(instance_name="master", **kwargs):
+            # Currently we only support sending requests to the master process.
+            if instance_name != "master":
+                raise Exception("Unknown instance")
+
             data = yield cls._serialize_payload(**kwargs)
 
             url_args = [
...
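For illustration, a hedged sketch of how a caller uses the reserved keyword argument; `ReplicationGetStreamUpdates` and `make_client` are from this diff, but the stream name and token values here are made up, and anything other than "master" currently raises:

    client = ReplicationGetStreamUpdates.make_client(hs)

    # Defaults to instance_name="master":
    result = await client(stream_name="typing", from_token=0, upto_token=10)

    # Explicit target; only "master" is accepted until multi-instance
    # routing is implemented:
    result = await client(
        instance_name="master", stream_name="typing", from_token=0, upto_token=10
    )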
@@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
     def __init__(self, hs):
         super().__init__(hs)
 
+        self._instance_name = hs.get_instance_name()
+
         # We pull the streams from the replication streamer (if we try and make
         # them ourselves we end up in an import loop).
         self.streams = hs.get_replication_streamer().get_streams()
@@ -67,7 +69,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
         upto_token = parse_integer(request, "upto_token", required=True)
 
         updates, upto_token, limited = await stream.get_updates_since(
-            from_token, upto_token
+            self._instance_name, from_token, upto_token
         )
 
         return (
...
@@ -86,17 +86,19 @@ class ReplicationDataHandler:
     def __init__(self, store: BaseSlavedStore):
         self.store = store
 
-    async def on_rdata(self, stream_name: str, token: int, rows: list):
+    async def on_rdata(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
         """Called to handle a batch of replication data with a given stream token.
 
         By default this just pokes the slave store. Can be overridden in subclasses to
         handle more.
 
         Args:
-            stream_name (str): name of the replication stream for this batch of rows
-            token (int): stream token for this batch of rows
-            rows (list): a list of Stream.ROW_TYPE objects as returned by
-                Stream.parse_row.
+            stream_name: name of the replication stream for this batch of rows
+            instance_name: the instance that wrote the rows.
+            token: stream token for this batch of rows
+            rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
         """
         self.store.process_replication_rows(stream_name, token, rows)
...
@@ -278,19 +278,24 @@ class ReplicationCommandHandler:
             # Check if this is the last of a batch of updates
             rows = self._pending_batches.pop(stream_name, [])
             rows.append(row)
-            await self.on_rdata(stream_name, cmd.token, rows)
+            await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
 
-    async def on_rdata(self, stream_name: str, token: int, rows: list):
+    async def on_rdata(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
         """Called to handle a batch of replication data with a given stream token.
 
         Args:
             stream_name: name of the replication stream for this batch of rows
+            instance_name: the instance that wrote the rows.
             token: stream token for this batch of rows
             rows: a list of Stream.ROW_TYPE objects as returned by
                 Stream.parse_row.
         """
         logger.debug("Received rdata %s -> %s", stream_name, token)
-        await self._replication_data_handler.on_rdata(stream_name, token, rows)
+        await self._replication_data_handler.on_rdata(
+            stream_name, instance_name, token, rows
+        )
 
     async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
         if cmd.instance_name == self._instance_name:
@@ -325,7 +330,9 @@ class ReplicationCommandHandler:
                     updates,
                     current_token,
                     missing_updates,
-                ) = await stream.get_updates_since(current_token, cmd.token)
+                ) = await stream.get_updates_since(
+                    cmd.instance_name, current_token, cmd.token
+                )
 
                 # TODO: add some tests for this
@@ -334,7 +341,10 @@ class ReplicationCommandHandler:
                 for token, rows in _batch_updates(updates):
                     await self.on_rdata(
-                        cmd.stream_name, token, [stream.parse_row(row) for row in rows],
+                        cmd.stream_name,
+                        cmd.instance_name,
+                        token,
+                        [stream.parse_row(row) for row in rows],
                     )
 
                 # We've now caught up to position sent to us, notify handler.
...
@@ -16,7 +16,7 @@
 import logging
 from collections import namedtuple
-from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
+from typing import Any, Awaitable, Callable, List, Optional, Tuple
 
 import attr
@@ -53,6 +53,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
 #
 # The arguments are:
 #
+#  * instance_name: the writer of the stream
 #  * from_token: the previous stream token: the starting point for fetching the
 #    updates
 #  * to_token: the new stream token: the point to get updates up to
@@ -62,7 +63,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
 # If there are more updates available, it should set `limited` in the result, and
 # it will be called again to get the next batch.
 #
-UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
+UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
 
 
 class Stream(object):
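For reference, an update function matching the widened `UpdateFunction` type now takes the writer's instance name as its first argument. A minimal conforming sketch, modelled on the `_stub_update_function` further down (the no-op body is illustrative):

    async def update_function(
        instance_name: str, from_token: Token, upto_token: Token, limit: int
    ) -> StreamUpdateResult:
        # Database-backed streams can ignore instance_name; in-memory streams
        # use it to ask the writer for updates over HTTP.
        return [], upto_token, False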
@@ -93,6 +94,7 @@ class Stream(object):
     def __init__(
         self,
+        local_instance_name: str,
         current_token_function: Callable[[], Token],
         update_function: UpdateFunction,
     ):
@@ -108,9 +110,11 @@ class Stream(object):
         stream tokens. See the UpdateFunction type definition for more info.
 
         Args:
+            local_instance_name: The instance name of the current process
             current_token_function: callback to get the current token, as above
             update_function: callback to get stream updates, as above
         """
+        self.local_instance_name = local_instance_name
         self.current_token = current_token_function
         self.update_function = update_function
@@ -135,14 +139,14 @@ class Stream(object):
         """
         current_token = self.current_token()
         updates, current_token, limited = await self.get_updates_since(
-            self.last_token, current_token
+            self.local_instance_name, self.last_token, current_token
         )
         self.last_token = current_token
 
         return updates, current_token, limited
 
     async def get_updates_since(
-        self, from_token: Token, upto_token: Token
+        self, instance_name: str, from_token: Token, upto_token: Token
     ) -> StreamUpdateResult:
         """Like get_updates except allows specifying from when we should
         stream updates
@@ -160,19 +164,19 @@ class Stream(object):
             return [], upto_token, False
 
         updates, upto_token, limited = await self.update_function(
-            from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
+            instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
         )
         return updates, upto_token, limited
 
 
 def db_query_to_update_function(
-    query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
+    query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
 ) -> UpdateFunction:
     """Wraps a db query function which returns a list of rows to make it
     suitable for use as an `update_function` for the Stream class
     """
 
-    async def update_function(from_token, upto_token, limit):
+    async def update_function(instance_name, from_token, upto_token, limit):
         rows = await query_function(from_token, upto_token, limit)
         updates = [(row[0], row[1:]) for row in rows]
         limited = False
@@ -193,10 +197,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
     client = ReplicationGetStreamUpdates.make_client(hs)
 
     async def update_function(
-        from_token: int, upto_token: int, limit: int
+        instance_name: str, from_token: int, upto_token: int, limit: int
     ) -> StreamUpdateResult:
         result = await client(
-            stream_name=stream_name, from_token=from_token, upto_token=upto_token,
+            instance_name=instance_name,
+            stream_name=stream_name,
+            from_token=from_token,
+            upto_token=upto_token,
         )
         return result["updates"], result["upto_token"], result["limited"]
@@ -226,6 +233,7 @@ class BackfillStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_current_backfill_token,
             db_query_to_update_function(store.get_all_new_backfill_event_rows),
         )
@@ -261,7 +269,9 @@ class PresenceStream(Stream):
             # Query master process
             update_function = make_http_update_function(hs, self.NAME)
 
-        super().__init__(store.get_current_presence_token, update_function)
+        super().__init__(
+            hs.get_instance_name(), store.get_current_presence_token, update_function
+        )
 
 
 class TypingStream(Stream):
@@ -284,7 +294,9 @@ class TypingStream(Stream):
             # Query master process
             update_function = make_http_update_function(hs, self.NAME)
 
-        super().__init__(typing_handler.get_current_token, update_function)
+        super().__init__(
+            hs.get_instance_name(), typing_handler.get_current_token, update_function
+        )
 
 
 class ReceiptsStream(Stream):
@@ -305,6 +317,7 @@ class ReceiptsStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_max_receipt_stream_id,
             db_query_to_update_function(store.get_all_updated_receipts),
         )
@@ -322,14 +335,16 @@ class PushRulesStream(Stream):
     def __init__(self, hs):
         self.store = hs.get_datastore()
         super(PushRulesStream, self).__init__(
-            self._current_token, self._update_function
+            hs.get_instance_name(), self._current_token, self._update_function
         )
 
     def _current_token(self) -> int:
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token
 
-    async def _update_function(self, from_token: Token, to_token: Token, limit: int):
+    async def _update_function(
+        self, instance_name: str, from_token: Token, to_token: Token, limit: int
+    ):
         rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
 
         limited = False
@@ -356,6 +371,7 @@ class PushersStream(Stream):
         store = hs.get_datastore()
 
         super().__init__(
+            hs.get_instance_name(),
             store.get_pushers_stream_token,
             db_query_to_update_function(store.get_all_updated_pushers_rows),
         )
@@ -387,6 +403,7 @@ class CachesStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_cache_stream_token,
             db_query_to_update_function(store.get_all_updated_caches),
         )
@@ -412,6 +429,7 @@ class PublicRoomsStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_current_public_room_stream_id,
             db_query_to_update_function(store.get_all_new_public_rooms),
         )
@@ -432,6 +450,7 @@ class DeviceListsStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_device_stream_token,
             db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
         )
@@ -449,6 +468,7 @@ class ToDeviceStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_to_device_stream_token,
             db_query_to_update_function(store.get_all_new_device_messages),
         )
@@ -468,6 +488,7 @@ class TagAccountDataStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_max_account_data_stream_id,
             db_query_to_update_function(store.get_all_updated_tags),
         )
@@ -487,6 +508,7 @@ class AccountDataStream(Stream):
     def __init__(self, hs):
         self.store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             self.store.get_max_account_data_stream_id,
             db_query_to_update_function(self._update_function),
         )
@@ -517,6 +539,7 @@ class GroupServerStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_group_stream_token,
             db_query_to_update_function(store.get_all_groups_changes),
         )
@@ -534,6 +557,7 @@ class UserSignatureStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_device_stream_token,
             db_query_to_update_function(
                 store.get_all_user_signature_changes_for_remotes
...
@@ -118,11 +118,17 @@ class EventsStream(Stream):
     def __init__(self, hs):
         self._store = hs.get_datastore()
         super().__init__(
-            self._store.get_current_events_token, self._update_function,
+            hs.get_instance_name(),
+            self._store.get_current_events_token,
+            self._update_function,
         )
 
     async def _update_function(
-        self, from_token: Token, current_token: Token, target_row_count: int
+        self,
+        instance_name: str,
+        from_token: Token,
+        current_token: Token,
+        target_row_count: int,
     ) -> StreamUpdateResult:
 
         # the events stream merges together three separate sources:
...
@@ -48,8 +48,8 @@ class FederationStream(Stream):
             current_token = lambda: 0
             update_function = self._stub_update_function
 
-        super().__init__(current_token, update_function)
+        super().__init__(hs.get_instance_name(), current_token, update_function)
 
     @staticmethod
-    async def _stub_update_function(from_token, upto_token, limit):
+    async def _stub_update_function(instance_name, from_token, upto_token, limit):
         return [], upto_token, False
@@ -183,8 +183,8 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler):
         # list of received (stream_name, token, row) tuples
         self.received_rdata_rows = []  # type: List[Tuple[str, int, Any]]
 
-    async def on_rdata(self, stream_name, token, rows):
-        await super().on_rdata(stream_name, token, rows)
+    async def on_rdata(self, stream_name, instance_name, token, rows):
+        await super().on_rdata(stream_name, instance_name, token, rows)
         for r in rows:
             self.received_rdata_rows.append((stream_name, token, r))
...
@@ -41,7 +41,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
         # there should be one RDATA command
         self.test_handler.on_rdata.assert_called_once()
-        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "receipts")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]  # type: ReceiptsStream.ReceiptsStreamRow
@@ -71,7 +71,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
         # We should now have caught up and get the missing data
         self.test_handler.on_rdata.assert_called_once()
-        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "receipts")
         self.assertEqual(token, 3)
         self.assertEqual(1, len(rdata_rows))
...
@@ -47,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
         self.test_handler.on_rdata.assert_called_once()
-        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]  # type: TypingStream.TypingStreamRow
@@ -74,7 +74,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assertEqual(int(request.args[b"from_token"][0]), token)
 
         self.test_handler.on_rdata.assert_called_once()
-        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]
...