From 211c31dbd77281d114b9f08ef61c867d56316e8c Mon Sep 17 00:00:00 2001
From: Alexander Udovichenko <udovichenko48@gmail.com>
Date: Tue, 5 Nov 2024 21:08:17 +0300
Subject: [PATCH] Fix WheelTimer implementation that can expired timeout early
 (#17850)

When entries insert in the end of timer queue, then unnecessary entry
inserted (with duplicated key).
This can lead to some timeouts expired early and consume memory.
---
 changelog.d/17850.bugfix       |  1 +
 synapse/util/wheel_timer.py    |  6 ++--
 tests/util/test_wheel_timer.py | 50 ++++++++++++++++++----------------
 3 files changed, 29 insertions(+), 28 deletions(-)
 create mode 100644 changelog.d/17850.bugfix

diff --git a/changelog.d/17850.bugfix b/changelog.d/17850.bugfix
new file mode 100644
index 0000000000..8ea99c4ef9
--- /dev/null
+++ b/changelog.d/17850.bugfix
@@ -0,0 +1 @@
+Fix bug when some presence and typing timeouts can expire early.
\ No newline at end of file
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 44b109bdfd..95eb1d7185 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -47,7 +47,6 @@ class WheelTimer(Generic[T]):
         """
         self.bucket_size: int = bucket_size
         self.entries: List[_Entry[T]] = []
-        self.current_tick: int = 0
 
     def insert(self, now: int, obj: T, then: int) -> None:
         """Inserts object into timer.
@@ -78,11 +77,10 @@ class WheelTimer(Generic[T]):
                 self.entries[max(min_key, then_key) - min_key].elements.add(obj)
                 return
 
-        next_key = now_key + 1
         if self.entries:
-            last_key = self.entries[-1].end_key
+            last_key = self.entries[-1].end_key + 1
         else:
-            last_key = next_key
+            last_key = now_key + 1
 
         # Handle the case when `then` is in the past and `entries` is empty.
         then_key = max(last_key, then_key)
diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py
index 173a7cfaec..6fa575a18e 100644
--- a/tests/util/test_wheel_timer.py
+++ b/tests/util/test_wheel_timer.py
@@ -28,53 +28,55 @@ class WheelTimerTestCase(unittest.TestCase):
     def test_single_insert_fetch(self) -> None:
         wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
 
-        obj = object()
-        wheel.insert(100, obj, 150)
+        wheel.insert(100, "1", 150)
 
         self.assertListEqual(wheel.fetch(101), [])
         self.assertListEqual(wheel.fetch(110), [])
         self.assertListEqual(wheel.fetch(120), [])
         self.assertListEqual(wheel.fetch(130), [])
         self.assertListEqual(wheel.fetch(149), [])
-        self.assertListEqual(wheel.fetch(156), [obj])
+        self.assertListEqual(wheel.fetch(156), ["1"])
         self.assertListEqual(wheel.fetch(170), [])
 
     def test_multi_insert(self) -> None:
         wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
 
-        obj1 = object()
-        obj2 = object()
-        obj3 = object()
-        wheel.insert(100, obj1, 150)
-        wheel.insert(105, obj2, 130)
-        wheel.insert(106, obj3, 160)
+        wheel.insert(100, "1", 150)
+        wheel.insert(105, "2", 130)
+        wheel.insert(106, "3", 160)
 
         self.assertListEqual(wheel.fetch(110), [])
-        self.assertListEqual(wheel.fetch(135), [obj2])
+        self.assertListEqual(wheel.fetch(135), ["2"])
         self.assertListEqual(wheel.fetch(149), [])
-        self.assertListEqual(wheel.fetch(158), [obj1])
+        self.assertListEqual(wheel.fetch(158), ["1"])
         self.assertListEqual(wheel.fetch(160), [])
-        self.assertListEqual(wheel.fetch(200), [obj3])
+        self.assertListEqual(wheel.fetch(200), ["3"])
         self.assertListEqual(wheel.fetch(210), [])
 
     def test_insert_past(self) -> None:
         wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
 
-        obj = object()
-        wheel.insert(100, obj, 50)
-        self.assertListEqual(wheel.fetch(120), [obj])
+        wheel.insert(100, "1", 50)
+        self.assertListEqual(wheel.fetch(120), ["1"])
 
     def test_insert_past_multi(self) -> None:
         wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
 
-        obj1 = object()
-        obj2 = object()
-        obj3 = object()
-        wheel.insert(100, obj1, 150)
-        wheel.insert(100, obj2, 140)
-        wheel.insert(100, obj3, 50)
-        self.assertListEqual(wheel.fetch(110), [obj3])
+        wheel.insert(100, "1", 150)
+        wheel.insert(100, "2", 140)
+        wheel.insert(100, "3", 50)
+        self.assertListEqual(wheel.fetch(110), ["3"])
         self.assertListEqual(wheel.fetch(120), [])
-        self.assertListEqual(wheel.fetch(147), [obj2])
-        self.assertListEqual(wheel.fetch(200), [obj1])
+        self.assertListEqual(wheel.fetch(147), ["2"])
+        self.assertListEqual(wheel.fetch(200), ["1"])
         self.assertListEqual(wheel.fetch(240), [])
+
+    def test_multi_insert_then_past(self) -> None:
+        wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
+
+        wheel.insert(100, "1", 150)
+        wheel.insert(100, "2", 160)
+        wheel.insert(100, "3", 155)
+
+        self.assertListEqual(wheel.fetch(110), [])
+        self.assertListEqual(wheel.fetch(158), ["1"])
-- 
GitLab