Skip to content
Snippets Groups Projects
Commit dd2cd931 authored by Mark Haines's avatar Mark Haines
Browse files

Test ratelimiter

parent 436b3c7d
No related branches found
No related tags found
No related merge requests found
......@@ -2,35 +2,64 @@ import collections
class Ratelimiter(object):
"""
Ratelimit message sending by user.
"""
def __init__(self):
self.message_counts = collections.OrderedDict()
def prune_message_counts(self, time_now):
for user_id in self.message_counts.keys():
message_count, time_start, msg_rate_hz = (
self.message_counts[user_id]
)
time_delta = time_now - time_start
if message_count - time_delta * msg_rate_hz > 0:
break
else:
del self.message_counts[user_id]
def send_message(self, user_id, time_now, msg_rate_hz, burst_count):
self.prune_message_counts(time_now)
def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count):
"""Can the user send a message?
Args:
user_id: The user sending a message.
time_now_s: The time now.
msg_rate_hz: The long term number of messages a user can send in a
second.
burst_count: How many messages the user can send before being
limited.
Returns:
A pair of a bool indicating if they can send a message now and a
time in seconds of when they can next send a message.
"""
self.prune_message_counts(time_now_s)
message_count, time_start, _ignored = self.message_counts.pop(
user_id, (0., time_now, None),
user_id, (0., time_now_s, None),
)
time_delta = time_now - time_start
if message_count - time_delta * msg_rate_hz < 0:
a
if message_count - (time_now - time_start) * msg_rate_hz > burst_count:
time_delta = time_now_s - time_start
sent_count = message_count - time_delta * msg_rate_hz
if sent_count < 0:
allowed = True
time_start = time_now_s
messagecount = 1.
elif sent_count > burst_count - 1.:
allowed = False
else:
allowed = True
message_count += 1
self.message_counts[user_id] = (
message_count, time_start, msg_rate_hz
)
return allowed
if msg_rate_hz > 0:
time_allowed = (
time_start + (message_count - burst_count + 1) / msg_rate_hz
)
if time_allowed < time_now_s:
time_allowed = time_now_s
else:
time_allowed = -1
return allowed, time_allowed
def prune_message_counts(self, time_now_s):
for user_id in self.message_counts.keys():
message_count, time_start, msg_rate_hz = (
self.message_counts[user_id]
)
time_delta = time_now_s - time_start
if message_count - time_delta * msg_rate_hz > 0:
break
else:
del self.message_counts[user_id]
from synapse.api.ratelimiting import Ratelimiter
import unittest
class TestRatelimiter(unittest.TestCase):
def test_allowed(self):
limiter = Ratelimiter()
allowed, time_allowed = limiter.send_message(
user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1,
)
self.assertTrue(allowed)
self.assertEquals(10., time_allowed)
allowed, time_allowed = limiter.send_message(
user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1,
)
self.assertFalse(allowed)
self.assertEquals(10., time_allowed)
allowed, time_allowed = limiter.send_message(
user_id="test_id", time_now_s=10, msg_rate_hz=0.1, burst_count=1
)
self.assertTrue(allowed)
self.assertEquals(20., time_allowed)
def test_pruning(self):
limiter = Ratelimiter()
allowed, time_allowed = limiter.send_message(
user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1,
)
self.assertIn("test_id_1", limiter.message_counts)
allowed, time_allowed = limiter.send_message(
user_id="test_id_2", time_now_s=10, msg_rate_hz=0.1, burst_count=1
)
self.assertNotIn("test_id_1", limiter.message_counts)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment