Skip to content
Snippets Groups Projects
Commit dd2cd931 authored by Mark Haines's avatar Mark Haines
Browse files

Test ratelimiter

parent 436b3c7d
No related branches found
No related tags found
No related merge requests found
...@@ -2,35 +2,64 @@ import collections ...@@ -2,35 +2,64 @@ import collections
class Ratelimiter(object): class Ratelimiter(object):
"""
Ratelimit message sending by user.
"""
def __init__(self): def __init__(self):
self.message_counts = collections.OrderedDict() self.message_counts = collections.OrderedDict()
def prune_message_counts(self, time_now): def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count):
for user_id in self.message_counts.keys(): """Can the user send a message?
message_count, time_start, msg_rate_hz = ( Args:
self.message_counts[user_id] user_id: The user sending a message.
) time_now_s: The time now.
time_delta = time_now - time_start msg_rate_hz: The long term number of messages a user can send in a
if message_count - time_delta * msg_rate_hz > 0: second.
break burst_count: How many messages the user can send before being
else: limited.
del self.message_counts[user_id] Returns:
A pair of a bool indicating if they can send a message now and a
def send_message(self, user_id, time_now, msg_rate_hz, burst_count): time in seconds of when they can next send a message.
self.prune_message_counts(time_now) """
self.prune_message_counts(time_now_s)
message_count, time_start, _ignored = self.message_counts.pop( message_count, time_start, _ignored = self.message_counts.pop(
user_id, (0., time_now, None), user_id, (0., time_now_s, None),
) )
time_delta = time_now - time_start time_delta = time_now_s - time_start
if message_count - time_delta * msg_rate_hz < 0: sent_count = message_count - time_delta * msg_rate_hz
a if sent_count < 0:
if message_count - (time_now - time_start) * msg_rate_hz > burst_count: allowed = True
time_start = time_now_s
messagecount = 1.
elif sent_count > burst_count - 1.:
allowed = False allowed = False
else: else:
allowed = True allowed = True
message_count += 1 message_count += 1
self.message_counts[user_id] = ( self.message_counts[user_id] = (
message_count, time_start, msg_rate_hz message_count, time_start, msg_rate_hz
) )
return allowed
if msg_rate_hz > 0:
time_allowed = (
time_start + (message_count - burst_count + 1) / msg_rate_hz
)
if time_allowed < time_now_s:
time_allowed = time_now_s
else:
time_allowed = -1
return allowed, time_allowed
def prune_message_counts(self, time_now_s):
for user_id in self.message_counts.keys():
message_count, time_start, msg_rate_hz = (
self.message_counts[user_id]
)
time_delta = time_now_s - time_start
if message_count - time_delta * msg_rate_hz > 0:
break
else:
del self.message_counts[user_id]
from synapse.api.ratelimiting import Ratelimiter
import unittest
class TestRatelimiter(unittest.TestCase):
def test_allowed(self):
limiter = Ratelimiter()
allowed, time_allowed = limiter.send_message(
user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1,
)
self.assertTrue(allowed)
self.assertEquals(10., time_allowed)
allowed, time_allowed = limiter.send_message(
user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1,
)
self.assertFalse(allowed)
self.assertEquals(10., time_allowed)
allowed, time_allowed = limiter.send_message(
user_id="test_id", time_now_s=10, msg_rate_hz=0.1, burst_count=1
)
self.assertTrue(allowed)
self.assertEquals(20., time_allowed)
def test_pruning(self):
limiter = Ratelimiter()
allowed, time_allowed = limiter.send_message(
user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1,
)
self.assertIn("test_id_1", limiter.message_counts)
allowed, time_allowed = limiter.send_message(
user_id="test_id_2", time_now_s=10, msg_rate_hz=0.1, burst_count=1
)
self.assertNotIn("test_id_1", limiter.message_counts)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment