Skip to content
Snippets Groups Projects
Commit 22578545 authored by Erik Johnston's avatar Erik Johnston
Browse files

Time out typing over federation

parent 667fcd54
Branches
Tags
No related merge requests found
......@@ -136,9 +136,7 @@ class FederationClient(FederationBase):
sent_edus_counter.inc()
# TODO, add errback, etc.
self._transaction_queue.enqueue_edu(edu, key=key)
return defer.succeed(None)
@log_function
def send_device_messages(self, destination):
......
......@@ -16,10 +16,9 @@
from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
)
from synapse.util.logcontext import preserve_fn
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
from synapse.types import UserID, get_domain_from_id
import logging
......@@ -35,6 +34,13 @@ logger = logging.getLogger(__name__)
RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
# How often we expect remote servers to resend us presence.
FEDERATION_TIMEOUT = 60 * 1000
# How often to resend typing across federation.
FEDERATION_PING_INTERVAL = 40 * 1000
class TypingHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
......@@ -44,7 +50,10 @@ class TypingHandler(object):
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
self.hs = hs
self.clock = hs.get_clock()
self.wheel_timer = WheelTimer()
self.federation = hs.get_replication_layer()
......@@ -53,7 +62,7 @@ class TypingHandler(object):
hs.get_distributor().observe("user_left_room", self.user_left_room)
self._member_typing_until = {} # clock time we expect to stop
self._member_typing_timer = {} # deferreds to manage theabove
self._member_last_federation_poke = {}
# map room IDs to serial numbers
self._room_serials = {}
......@@ -61,12 +70,41 @@ class TypingHandler(object):
# map room IDs to sets of users currently typing
self._room_typing = {}
def tearDown(self):
"""Cancels all the pending timers.
Normally this shouldn't be needed, but it's required from unit tests
to avoid a "Reactor was unclean" warning."""
for t in self._member_typing_timer.values():
self.clock.cancel_call_later(t)
self.clock.looping_call(
self._handle_timeouts,
5000,
)
def _handle_timeouts(self):
logger.info("Handling typing timeout")
now = self.clock.time_msec()
members = set(self.wheel_timer.fetch(now))
for member in members:
if not self.is_typing(member):
# Nothing to do if they're no longer typing
continue
until = self._member_typing_until.get(member, None)
if not until or until < now:
logger.info("Timing out typing for: %s", member.user_id)
preserve_fn(self._stopped_typing)(member)
continue
# Check if we need to resend a keep alive over federation for this
# user.
if self.hs.is_mine_id(member.user_id):
last_fed_poke = self._member_last_federation_poke.get(member, None)
if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL < now:
preserve_fn(self._push_remote)(
member=member,
typing=True
)
def is_typing(self, member):
return member.user_id in self._room_typing.get(member.room_id, [])
@defer.inlineCallbacks
def started_typing(self, target_user, auth_user, room_id, timeout):
......@@ -85,23 +123,23 @@ class TypingHandler(object):
"%s has started typing in %s", target_user_id, room_id
)
until = self.clock.time_msec() + timeout
member = RoomMember(room_id=room_id, user_id=target_user_id)
was_present = member in self._member_typing_until
was_present = member.user_id in self._room_typing.get(room_id, set())
if member in self._member_typing_timer:
self.clock.cancel_call_later(self._member_typing_timer[member])
now = self.clock.time_msec()
self._member_typing_until[member] = now + timeout
def _cb():
logger.debug(
"%s has timed out in %s", target_user.to_string(), room_id
)
self._stopped_typing(member)
self.wheel_timer.insert(
now=now,
obj=member,
then=now + timeout,
)
self._member_typing_until[member] = until
self._member_typing_timer[member] = self.clock.call_later(
timeout / 1000.0, _cb
self.wheel_timer.insert(
now=now,
obj=member,
then=now + FEDERATION_PING_INTERVAL,
)
if was_present:
......@@ -109,8 +147,7 @@ class TypingHandler(object):
defer.returnValue(None)
yield self._push_update(
room_id=room_id,
user_id=target_user_id,
member=member,
typing=True,
)
......@@ -133,10 +170,6 @@ class TypingHandler(object):
member = RoomMember(room_id=room_id, user_id=target_user_id)
if member in self._member_typing_timer:
self.clock.cancel_call_later(self._member_typing_timer[member])
del self._member_typing_timer[member]
yield self._stopped_typing(member)
@defer.inlineCallbacks
......@@ -148,57 +181,53 @@ class TypingHandler(object):
@defer.inlineCallbacks
def _stopped_typing(self, member):
if member not in self._member_typing_until:
if member.user_id not in self._room_typing.get(member.room_id, set()):
# No point
defer.returnValue(None)
self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None)
yield self._push_update(
room_id=member.room_id,
user_id=member.user_id,
member=member,
typing=False,
)
del self._member_typing_until[member]
if member in self._member_typing_timer:
# Don't cancel it - either it already expired, or the real
# stopped_typing() will cancel it
del self._member_typing_timer[member]
@defer.inlineCallbacks
def _push_update(self, room_id, user_id, typing):
users = yield self.state.get_current_user_in_room(room_id)
domains = set(get_domain_from_id(u) for u in users)
def _push_update(self, member, typing):
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
yield self._push_remote(member, typing)
self._push_update_local(
member=member,
typing=typing
)
deferreds = []
for domain in domains:
if domain == self.server_name:
preserve_fn(self._push_update_local)(
room_id=room_id,
user_id=user_id,
typing=typing
)
else:
deferreds.append(preserve_fn(self.federation.send_edu)(
@defer.inlineCallbacks
def _push_remote(self, member, typing):
users = yield self.state.get_current_user_in_room(member.room_id)
self._member_last_federation_poke[member] = self.clock.time_msec()
for domain in set(get_domain_from_id(u) for u in users):
if domain != self.server_name:
self.federation.send_edu(
destination=domain,
edu_type="m.typing",
content={
"room_id": room_id,
"user_id": user_id,
"room_id": member.room_id,
"user_id": member.user_id,
"typing": typing,
},
key=(room_id, user_id),
))
yield preserve_context_over_deferred(
defer.DeferredList(deferreds, consumeErrors=True)
)
key=member,
)
@defer.inlineCallbacks
def _recv_edu(self, origin, content):
room_id = content["room_id"]
user_id = content["user_id"]
member = RoomMember(user_id=user_id, room_id=room_id)
# Check that the string is a valid user id
user = UserID.from_string(user_id)
......@@ -213,26 +242,32 @@ class TypingHandler(object):
domains = set(get_domain_from_id(u) for u in users)
if self.server_name in domains:
logger.info("Got typing update from %s: %r", user_id, content)
now = self.clock.time_msec()
self._member_typing_until[member] = now + FEDERATION_TIMEOUT
self.wheel_timer.insert(
now=now,
obj=member,
then=now + FEDERATION_TIMEOUT,
)
self._push_update_local(
room_id=room_id,
user_id=user_id,
member=member,
typing=content["typing"]
)
def _push_update_local(self, room_id, user_id, typing):
room_set = self._room_typing.setdefault(room_id, set())
def _push_update_local(self, member, typing):
room_set = self._room_typing.setdefault(member.room_id, set())
if typing:
room_set.add(user_id)
room_set.add(member.user_id)
else:
room_set.discard(user_id)
room_set.discard(member.user_id)
self._latest_room_serial += 1
self._room_serials[room_id] = self._latest_room_serial
self._room_serials[member.room_id] = self._latest_room_serial
with PreserveLoggingContext():
self.notifier.on_new_event(
"typing_key", self._latest_room_serial, rooms=[room_id]
)
self.notifier.on_new_event(
"typing_key", self._latest_room_serial, rooms=[member.room_id]
)
def get_all_typing_updates(self, last_id, current_id):
# TODO: Work out a way to do this without scanning the entire state.
......
......@@ -705,12 +705,15 @@ class RoomTypingRestServlet(ClientV1RestServlet):
yield self.presence_handler.bump_presence_active_time(requester.user)
# Limit timeout to stop people from setting silly typing timeouts.
timeout = min(content.get("timeout", 30000), 120000)
if content["typing"]:
yield self.typing_handler.started_typing(
target_user=target_user,
auth_user=requester.user,
room_id=room_id,
timeout=content.get("timeout", 30000),
timeout=timeout,
)
else:
yield self.typing_handler.stopped_typing(
......
......@@ -267,10 +267,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
from synapse.handlers.typing import RoomMember
member = RoomMember(self.room_id, self.u_apple.to_string())
self.handler._member_typing_until[member] = 1002000
self.handler._member_typing_timer[member] = (
self.clock.call_later(1002, lambda: 0)
)
self.handler._room_typing[self.room_id] = set((self.u_apple.to_string(),))
self.handler._room_typing[self.room_id] = set([self.u_apple.to_string()])
self.assertEquals(self.event_source.get_current_key(), 0)
......@@ -330,7 +327,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
},
}])
self.clock.advance_time(11)
self.clock.advance_time(16)
self.on_new_event.assert_has_calls([
call('typing_key', 2, rooms=[self.room_id]),
......
......@@ -105,9 +105,6 @@ class RoomTypingTestCase(RestTestCase):
# Need another user to make notifications actually work
yield self.join(self.room_id, user="@jim:red")
def tearDown(self):
self.hs.get_typing_handler().tearDown()
@defer.inlineCallbacks
def test_set_typing(self):
(code, _) = yield self.mock_resource.trigger(
......@@ -147,7 +144,7 @@ class RoomTypingTestCase(RestTestCase):
self.assertEquals(self.event_source.get_current_key(), 1)
self.clock.advance_time(31)
self.clock.advance_time(36)
self.assertEquals(self.event_source.get_current_key(), 2)
......
......@@ -220,6 +220,7 @@ class MockClock(object):
# list of lists of [absolute_time, callback, expired] in no particular
# order
self.timers = []
self.loopers = []
def time(self):
return self.now
......@@ -240,7 +241,7 @@ class MockClock(object):
return t
def looping_call(self, function, interval):
pass
self.loopers.append([function, interval / 1000., self.now])
def cancel_call_later(self, timer, ignore_errs=False):
if timer[2]:
......@@ -269,6 +270,12 @@ class MockClock(object):
else:
self.timers.append(t)
for looped in self.loopers:
func, interval, last = looped
if last + interval < self.now:
func()
looped[2] = self.now
def advance_time_msec(self, ms):
self.advance_time(ms / 1000.)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment