Skip to content
Snippets Groups Projects
Commit 93d90765 authored by Erik Johnston's avatar Erik Johnston
Browse files

Initial implementation of federation server rate limiting

parent a0250556
Branches
Tags
No related merge requests found
...@@ -21,7 +21,7 @@ support HTTPS), however individual pairings of servers may decide to ...@@ -21,7 +21,7 @@ support HTTPS), however individual pairings of servers may decide to
communicate over a different (albeit still reliable) protocol. communicate over a different (albeit still reliable) protocol.
""" """
from .server import TransportLayerServer from .server import TransportLayerServer, FederationRateLimiter
from .client import TransportLayerClient from .client import TransportLayerClient
...@@ -55,8 +55,18 @@ class TransportLayer(TransportLayerServer, TransportLayerClient): ...@@ -55,8 +55,18 @@ class TransportLayer(TransportLayerServer, TransportLayerClient):
send requests send requests
""" """
self.keyring = homeserver.get_keyring() self.keyring = homeserver.get_keyring()
self.clock = homeserver.get_clock()
self.server_name = server_name self.server_name = server_name
self.server = server self.server = server
self.client = client self.client = client
self.request_handler = None self.request_handler = None
self.received_handler = None self.received_handler = None
self.ratelimiter = FederationRateLimiter(
self.clock,
window_size=10000,
sleep_limit=10,
sleep_msec=500,
reject_limit=50,
concurrent_requests=3,
)
...@@ -16,9 +16,11 @@ ...@@ -16,9 +16,11 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError, LimitExceededError
from synapse.util.async import sleep
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
import collections
import logging import logging
import simplejson as json import simplejson as json
import re import re
...@@ -27,6 +29,163 @@ import re ...@@ -27,6 +29,163 @@ import re
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FederationRateLimiter(object):
def __init__(self, clock, window_size, sleep_limit, sleep_msec,
reject_limit, concurrent_requests):
self.clock = clock
self.window_size = window_size
self.sleep_limit = sleep_limit
self.sleep_msec = sleep_msec
self.reject_limit = reject_limit
self.concurrent_requests = concurrent_requests
self.ratelimiters = {}
def ratelimit(self, host):
return self.ratelimiters.setdefault(
host,
PerHostRatelimiter(
clock=self.clock,
window_size=self.window_size,
sleep_limit=self.sleep_limit,
sleep_msec=self.sleep_msec,
reject_limit=self.reject_limit,
concurrent_requests=self.concurrent_requests,
)
).ratelimit()
class PerHostRatelimiter(object):
def __init__(self, clock, window_size, sleep_limit, sleep_msec,
reject_limit, concurrent_requests):
self.clock = clock
self.window_size = window_size
self.sleep_limit = sleep_limit
self.sleep_msec = sleep_msec
self.reject_limit = reject_limit
self.concurrent_requests = concurrent_requests
self.sleeping_requests = set()
self.ready_request_queue = collections.OrderedDict()
self.current_processing = set()
self.request_times = []
def is_empty(self):
time_now = self.clock.time_msec()
self.request_times[:] = [
r for r in self.request_times
if time_now - r < self.window_size
]
return not (
self.ready_request_queue
or self.sleeping_requests
or self.current_processing
or self.request_times
)
def ratelimit(self):
request_id = object()
def on_enter():
return self._on_enter(request_id)
def on_exit(exc_type, exc_val, exc_tb):
return self._on_exit(request_id)
return ContextManagerFunction(on_enter, on_exit)
def _on_enter(self, request_id):
time_now = self.clock.time_msec()
self.request_times[:] = [
r for r in self.request_times
if time_now - r < self.window_size
]
queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
if queue_size > self.reject_limit:
raise LimitExceededError(
retry_after_ms=int(
self.window_size / self.sleep_limit
),
)
self.request_times.append(time_now)
def queue_request():
if len(self.current_processing) > self.concurrent_requests:
logger.debug("Ratelimit [%s]: Queue req", id(request_id))
queue_defer = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer
return queue_defer
else:
return defer.succeed(None)
logger.debug("Ratelimit [%s]: len(self.request_times)=%d", id(request_id), len(self.request_times))
logger.debug("Ratelimit [%s]: len(self.request_times)=%d", id(request_id), len(self.request_times))
if len(self.request_times) > self.sleep_limit:
logger.debug("Ratelimit [%s]: sleeping req", id(request_id))
ret_defer = sleep(self.sleep_msec/1000.0)
self.sleeping_requests.add(request_id)
def on_wait_finished(_):
logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
self.sleeping_requests.discard(request_id)
queue_defer = queue_request()
return queue_defer
ret_defer.addBoth(on_wait_finished)
else:
ret_defer = queue_request()
def on_start(r):
logger.debug("Ratelimit [%s]: Processing req", id(request_id))
self.current_processing.add(request_id)
return r
def on_err(r):
self.current_processing.discard(request_id)
return r
def on_both(r):
# Ensure that we've properly cleaned up.
self.sleeping_requests.discard(request_id)
self.ready_request_queue.pop(request_id, None)
return r
ret_defer.addCallbacks(on_start, on_err)
ret_defer.addBoth(on_both)
return ret_defer
def _on_exit(self, request_id):
logger.debug("Ratelimit [%s]: Processed req", id(request_id))
self.current_processing.discard(request_id)
try:
request_id, deferred = self.ready_request_queue.popitem()
self.current_processing.add(request_id)
deferred.callback(None)
except KeyError:
pass
class ContextManagerFunction(object):
def __init__(self, on_enter, on_exit):
self.on_enter = on_enter
self.on_exit = on_exit
def __enter__(self):
if self.on_enter:
return self.on_enter()
def __exit__(self, exc_type, exc_val, exc_tb):
if self.on_exit:
return self.on_exit(exc_type, exc_val, exc_tb)
class TransportLayerServer(object): class TransportLayerServer(object):
"""Handles incoming federation HTTP requests""" """Handles incoming federation HTTP requests"""
...@@ -98,15 +257,23 @@ class TransportLayerServer(object): ...@@ -98,15 +257,23 @@ class TransportLayerServer(object):
def new_handler(request, *args, **kwargs): def new_handler(request, *args, **kwargs):
try: try:
(origin, content) = yield self._authenticate_request(request) (origin, content) = yield self._authenticate_request(request)
response = yield handler( with self.ratelimiter.ratelimit(origin) as d:
origin, content, request.args, *args, **kwargs yield d
) response = yield handler(
origin, content, request.args, *args, **kwargs
)
except: except:
logger.exception("_authenticate_request failed") logger.exception("_authenticate_request failed")
raise raise
defer.returnValue(response) defer.returnValue(response)
return new_handler return new_handler
def rate_limit_origin(self, handler):
def new_handler(origin, *args, **kwargs):
response = yield handler(origin, *args, **kwargs)
defer.returnValue(response)
return new_handler()
@log_function @log_function
def register_received_handler(self, handler): def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data. """ Register a handler that will be fired when we receive data.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment