From b8b905c4ea8a0d922d34d469f7d220f53def1b53 Mon Sep 17 00:00:00 2001
From: Sean Quah <8349537+squahtx@users.noreply.github.com>
Date: Tue, 12 Oct 2021 11:24:05 +0100
Subject: [PATCH] Fix inconsistent behavior of `get_last_client_by_ip` (#10970)

Make `get_last_client_by_ip` return the same dictionary structure
regardless of whether the data has been persisted to the database.

This change will allow slightly cleaner type hints to be applied later
on.
---
 changelog.d/10970.misc                       |  1 +
 synapse/storage/databases/main/client_ips.py | 13 ++++--
 tests/storage/test_client_ips.py             | 43 ++++++++++++++++++++
 3 files changed, 53 insertions(+), 4 deletions(-)
 create mode 100644 changelog.d/10970.misc

diff --git a/changelog.d/10970.misc b/changelog.d/10970.misc
new file mode 100644
index 0000000000..bb75ea79a6
--- /dev/null
+++ b/changelog.d/10970.misc
@@ -0,0 +1 @@
+Fix inconsistent behavior of `get_last_client_by_ip` when reporting data that has not been stored in the database yet.
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index c77acc7c84..6c1ef09049 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -538,15 +538,20 @@ class ClientIpStore(ClientIpWorkerStore):
         """
         ret = await super().get_last_client_ip_by_device(user_id, device_id)
 
-        # Update what is retrieved from the database with data which is pending insertion.
+        # Update what is retrieved from the database with data which is pending
+        # insertion, as if it has already been stored in the database.
         for key in self._batch_row_update:
-            uid, access_token, ip = key
+            uid, _access_token, ip = key
             if uid == user_id:
                 user_agent, did, last_seen = self._batch_row_update[key]
+
+                if did is None:
+                    # These updates don't make it to the `devices` table
+                    continue
+
                 if not device_id or did == device_id:
-                    ret[(user_id, device_id)] = {
+                    ret[(user_id, did)] = {
                         "user_id": user_id,
-                        "access_token": access_token,
                         "ip": ip,
                         "user_agent": user_agent,
                         "device_id": did,
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index dada4f98c9..0e4013ebea 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -146,6 +146,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
             ],
         )
 
+    @parameterized.expand([(False,), (True,)])
+    def test_get_last_client_ip_by_device(self, after_persisting: bool):
+        """Test `get_last_client_ip_by_device` for persisted and unpersisted data"""
+        self.reactor.advance(12345678)
+
+        user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(
+            self.store.store_device(
+                user_id,
+                device_id,
+                "display name",
+            )
+        )
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token", "ip", "user_agent", device_id
+            )
+        )
+
+        if after_persisting:
+            # Trigger the storage loop
+            self.reactor.advance(10)
+
+        result = self.get_success(
+            self.store.get_last_client_ip_by_device(user_id, device_id)
+        )
+
+        self.assertEqual(
+            result,
+            {
+                (user_id, device_id): {
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "ip": "ip",
+                    "user_agent": "user_agent",
+                    "last_seen": 12345678000,
+                },
+            },
+        )
+
     @parameterized.expand([(False,), (True,)])
     def test_get_user_ip_and_agents(self, after_persisting: bool):
         """Test `get_user_ip_and_agents` for persisted and unpersisted data"""
-- 
GitLab