Unverified commit 599c403d authored by David Robertson, committed by GitHub
Allow rate limiters to passively record actions they cannot limit (#13253)

parent 0eb7e697
Preparatory work for a per-room rate limiter on joins.
@@ -27,6 +27,33 @@ class Ratelimiter:
"""
Ratelimit actions marked by arbitrary keys.
(Note that the source code speaks of "actions" and "burst_count" rather than
"tokens" and a "bucket_size".)
This is a "leaky bucket as a meter". For each key to be tracked there is a bucket
containing some number 0 <= T <= `burst_count` of tokens corresponding to previously
permitted requests for that key. Each bucket starts empty, and gradually leaks
tokens at a rate of `rate_hz`.
Upon an incoming request, we must determine:
- the key that this request falls under (which bucket to inspect), and
- the cost C of this request in tokens.
Then, if there is room in the bucket for C tokens (T + C <= `burst_count`),
the request is permitted and `cost` tokens are added to the bucket.
Otherwise the request is denied, and the bucket continues to hold T tokens.
This means that the limiter enforces an average request frequency of `rate_hz`,
while accumulating a buffer of up to `burst_count` requests which can be consumed
instantaneously.
The tricky bit is the leaking. We do not want to have a periodic process which
leaks every bucket! Instead, we track
- the time point when the bucket was last completely empty, and
- how many tokens have been added to the bucket (i.e. how many requests have been permitted) since then.
Then for each incoming request, we can calculate how many tokens have leaked
since this time point, and use that to decide if we should accept or reject the
request.
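As a rough sketch of the bookkeeping described above, assuming only the stored (action_count, time_start, rate_hz) values per key (the helper names below are illustrative and do not appear in this diff):

def _tokens_in_bucket(
    action_count: float, time_start: float, time_now_s: float, rate_hz: float
) -> float:
    # Tokens leak out at `rate_hz` per second since the bucket was last empty,
    # but the bucket can never hold fewer than zero tokens.
    leaked = (time_now_s - time_start) * rate_hz
    return max(0.0, action_count - leaked)

def _would_permit(
    action_count: float,
    time_start: float,
    time_now_s: float,
    rate_hz: float,
    burst_count: float,
    cost: float = 1.0,
) -> bool:
    # Permit the request only if adding `cost` tokens keeps the bucket at or
    # below `burst_count`.
    tokens = _tokens_in_bucket(action_count, time_start, time_now_s, rate_hz)
    return tokens + cost <= burst_count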
Args:
clock: A homeserver clock, for retrieving the current time
rate_hz: The long term number of actions that can be performed in a second.
@@ -41,14 +68,30 @@ class Ratelimiter:
self.burst_count = burst_count
self.store = store
# An ordered dictionary representing the token buckets tracked by this rate
# limiter. Each entry maps a key of arbitrary type to a tuple representing:
# * The number of tokens currently in the bucket,
# * The time point when the bucket was last completely empty, and
# * The rate_hz (leak rate) of this particular bucket.
self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
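For instance (values invented for this example), a bucket for key "a" holding two tokens, last empty at t=100.0 s and leaking at 0.1 Hz, would correspond to an entry like:

# Illustrative contents of the mapping described above:
# 2.0 tokens in the bucket, last empty at t=100.0s, leaking at 0.1 tokens/s.
actions_example = {"a": (2.0, 100.0, 0.1)}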
def _get_key(
self, requester: Optional[Requester], key: Optional[Hashable]
) -> Hashable:
"""Use the requester's MXID as a fallback key if no key is provided."""
if key is None:
if not requester:
raise ValueError("Must supply at least one of `requester` or `key`")
key = requester.user.to_string()
return key
def _get_action_counts(
self, key: Hashable, time_now_s: float
) -> Tuple[float, float, float]:
"""Retrieve the action counts, with a fallback representing an empty bucket."""
return self.actions.get(key, (0.0, time_now_s, 0.0))
async def can_do_action(
self,
requester: Optional[Requester],
@@ -88,11 +131,7 @@ class Ratelimiter:
* The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero
"""
key = self._get_key(requester, key)
if requester:
# Disable rate limiting of users belonging to any AS that is configured
@@ -121,7 +160,7 @@ class Ratelimiter:
self._prune_message_counts(time_now_s)
# Check if there is an existing count entry for this key
action_count, time_start, _ = self._get_action_counts(key, time_now_s)
# Check whether performing another action is allowed
time_delta = time_now_s - time_start
@@ -164,6 +203,37 @@ class Ratelimiter:
return allowed, time_allowed
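For orientation, a hedged usage sketch of the coroutine above, using only the call shape exercised in the tests further down (the handling of a denied action and the key "a" are assumptions of this example):

# Sketch only: consult the limiter before performing an action under key "a".
async def try_action(limiter: Ratelimiter) -> bool:
    allowed, time_allowed = await limiter.can_do_action(requester=None, key="a")
    if not allowed:
        # `time_allowed` is the reactor timestamp at which the action may next
        # be attempted (or -1 if rate_hz <= 0).
        print(f"rate limited; retry at {time_allowed}")
    return allowed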
def record_action(
self,
requester: Optional[Requester],
key: Optional[Hashable] = None,
n_actions: int = 1,
_time_now_s: Optional[float] = None,
) -> None:
"""Record that an action(s) took place, even if they violate the rate limit.
This is useful for tracking the frequency of events that happen across
federation which we still want to impose local rate limits on. For instance, if
we are alice.com monitoring a particular room, we cannot prevent bob.com
from joining users to that room. However, we can track the number of recent
joins in the room and refuse to serve new joins ourselves if there have been too
many in the room across both homeservers.
Args:
requester: The requester that is doing the action, if any.
key: An arbitrary key used to classify an action. Defaults to the
requester's user ID.
n_actions: The number of times the user performed this action. Unlike
can_do_action, this count is always recorded, even if it overfills the
bucket.
_time_now_s: The current time. Optional, defaults to the current time according
to self.clock. Only used by tests.
"""
key = self._get_key(requester, key)
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
action_count, time_start, rate_hz = self._get_action_counts(key, time_now_s)
self.actions[key] = (action_count + n_actions, time_start, rate_hz)
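A hedged sketch of how record_action might pair with can_do_action for the cross-federation join scenario described in the docstring above; the room-keyed limiter and both handler functions are hypothetical illustrations, not introduced by this commit:

# Hypothetical callers of a Ratelimiter keyed on room ID (sketch only).
async def on_local_join(
    limiter: Ratelimiter, requester: Requester, room_id: str
) -> bool:
    # Local joins consult the limiter and may be refused.
    allowed, _ = await limiter.can_do_action(requester, key=room_id)
    return allowed

def on_remote_join(limiter: Ratelimiter, room_id: str) -> None:
    # A join performed by another homeserver cannot be rejected here, but
    # recording it means it still counts towards this room's local budget.
    limiter.record_action(requester=None, key=room_id)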
def _prune_message_counts(self, time_now_s: float) -> None:
"""Remove message count entries that have not exceeded their defined
rate_hz limit
...
@@ -314,3 +314,77 @@ class TestRatelimiter(unittest.HomeserverTestCase):
# Check that we get rate limited after using that token.
self.assertFalse(consume_at(11.1))
def test_record_action_which_doesnt_fill_bucket(self) -> None:
limiter = Ratelimiter(
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
)
# Observe two actions, leaving room in the bucket for one more.
limiter.record_action(requester=None, key="a", n_actions=2, _time_now_s=0.0)
# We should be able to take a new action now.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
)
self.assertTrue(success)
# ... but not two.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
)
self.assertFalse(success)
def test_record_action_which_fills_bucket(self) -> None:
limiter = Ratelimiter(
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
)
# Observe three actions, filling up the bucket.
limiter.record_action(requester=None, key="a", n_actions=3, _time_now_s=0.0)
# We should be unable to take a new action now.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
)
self.assertFalse(success)
# If we wait 10 seconds to leak a token, we should be able to take one action...
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
)
self.assertTrue(success)
# ... but not two.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
)
self.assertFalse(success)
def test_record_action_which_overfills_bucket(self) -> None:
limiter = Ratelimiter(
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
)
# Observe four actions, exceeding the bucket.
limiter.record_action(requester=None, key="a", n_actions=4, _time_now_s=0.0)
# We should be prevented from taking a new action now.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
)
self.assertFalse(success)
# If we wait 10 seconds to leak a token, we should be unable to take an action
# because the bucket is still full.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
)
self.assertFalse(success)
# But after another 10 seconds we leak a second token, giving us room for
# action.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=20.0)
)
self.assertTrue(success)