diff --git a/changelog.d/17850.bugfix b/changelog.d/17850.bugfix
new file mode 100644
index 0000000000000000000000000000000000000000..8ea99c4ef9ac0f2bb38ee272d1811195f131d9e4
--- /dev/null
+++ b/changelog.d/17850.bugfix
@@ -0,0 +1 @@
+Fix bug when some presence and typing timeouts can expire early.
\ No newline at end of file
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 44b109bdfd63593cf199cadaf7c40be394813fb1..95eb1d71859028f57f88c405440fc0a931198fff 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -47,7 +47,6 @@ class WheelTimer(Generic[T]):
         """
         self.bucket_size: int = bucket_size
         self.entries: List[_Entry[T]] = []
-        self.current_tick: int = 0
 
     def insert(self, now: int, obj: T, then: int) -> None:
         """Inserts object into timer.
@@ -78,11 +77,10 @@ class WheelTimer(Generic[T]):
                 self.entries[max(min_key, then_key) - min_key].elements.add(obj)
                 return
 
-        next_key = now_key + 1
         if self.entries:
-            last_key = self.entries[-1].end_key
+            last_key = self.entries[-1].end_key + 1
         else:
-            last_key = next_key
+            last_key = now_key + 1
 
         # Handle the case when `then` is in the past and `entries` is empty.
         then_key = max(last_key, then_key)
diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py
index 173a7cfaeca2d9e5841f57b5391fe0e7ced1454c..6fa575a18e4978a7423653500a6df8c32d876632 100644
--- a/tests/util/test_wheel_timer.py
+++ b/tests/util/test_wheel_timer.py
@@ -28,53 +28,55 @@ class WheelTimerTestCase(unittest.TestCase):
     def test_single_insert_fetch(self) -> None:
         wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
 
-        obj = object()
-        wheel.insert(100, obj, 150)
+        wheel.insert(100, "1", 150)
 
         self.assertListEqual(wheel.fetch(101), [])
         self.assertListEqual(wheel.fetch(110), [])
         self.assertListEqual(wheel.fetch(120), [])
         self.assertListEqual(wheel.fetch(130), [])
         self.assertListEqual(wheel.fetch(149), [])
-        self.assertListEqual(wheel.fetch(156), [obj])
+        self.assertListEqual(wheel.fetch(156), ["1"])
         self.assertListEqual(wheel.fetch(170), [])
 
     def test_multi_insert(self) -> None:
         wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
 
-        obj1 = object()
-        obj2 = object()
-        obj3 = object()
-        wheel.insert(100, obj1, 150)
-        wheel.insert(105, obj2, 130)
-        wheel.insert(106, obj3, 160)
+        wheel.insert(100, "1", 150)
+        wheel.insert(105, "2", 130)
+        wheel.insert(106, "3", 160)
 
         self.assertListEqual(wheel.fetch(110), [])
-        self.assertListEqual(wheel.fetch(135), [obj2])
+        self.assertListEqual(wheel.fetch(135), ["2"])
         self.assertListEqual(wheel.fetch(149), [])
-        self.assertListEqual(wheel.fetch(158), [obj1])
+        self.assertListEqual(wheel.fetch(158), ["1"])
         self.assertListEqual(wheel.fetch(160), [])
-        self.assertListEqual(wheel.fetch(200), [obj3])
+        self.assertListEqual(wheel.fetch(200), ["3"])
         self.assertListEqual(wheel.fetch(210), [])
 
     def test_insert_past(self) -> None:
         wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
 
-        obj = object()
-        wheel.insert(100, obj, 50)
-        self.assertListEqual(wheel.fetch(120), [obj])
+        wheel.insert(100, "1", 50)
+        self.assertListEqual(wheel.fetch(120), ["1"])
 
     def test_insert_past_multi(self) -> None:
         wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
 
-        obj1 = object()
-        obj2 = object()
-        obj3 = object()
-        wheel.insert(100, obj1, 150)
-        wheel.insert(100, obj2, 140)
-        wheel.insert(100, obj3, 50)
-        self.assertListEqual(wheel.fetch(110), [obj3])
+        wheel.insert(100, "1", 150)
+        wheel.insert(100, "2", 140)
+        wheel.insert(100, "3", 50)
+        self.assertListEqual(wheel.fetch(110), ["3"])
         self.assertListEqual(wheel.fetch(120), [])
-        self.assertListEqual(wheel.fetch(147), [obj2])
-        self.assertListEqual(wheel.fetch(200), [obj1])
+        self.assertListEqual(wheel.fetch(147), ["2"])
+        self.assertListEqual(wheel.fetch(200), ["1"])
         self.assertListEqual(wheel.fetch(240), [])
+
+    def test_multi_insert_then_past(self) -> None:
+        wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
+
+        wheel.insert(100, "1", 150)
+        wheel.insert(100, "2", 160)
+        wheel.insert(100, "3", 155)
+
+        self.assertListEqual(wheel.fetch(110), [])
+        self.assertListEqual(wheel.fetch(158), ["1"])