Skip to content
Snippets Groups Projects
Commit 012b4c19 authored by Richard van der Hoff's avatar Richard van der Hoff
Browse files

Implement updating devices

You can update the displayname of devices now.
parent 436bffd1
Branches
Tags
No related merge requests found
......@@ -141,6 +141,30 @@ class DeviceHandler(BaseHandler):
yield self.store.user_delete_access_tokens(user_id,
device_id=device_id)
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
""" Update the given device
Args:
user_id (str):
device_id (str):
content (dict): body of update request
Returns:
defer.Deferred:
"""
try:
yield self.store.update_device(
user_id,
device_id,
new_display_name=content.get("display_name")
)
except errors.StoreError, e:
if e.code == 404:
raise errors.NotFoundError()
else:
raise
def _update_device_from_client_ips(device, client_ips):
......
......@@ -13,19 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import logging
from synapse.http.servlet import RestServlet
from twisted.internet import defer
from synapse.http import servlet
from ._base import client_v2_patterns
import logging
logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet):
class DevicesRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
def __init__(self, hs):
......@@ -47,7 +45,7 @@ class DevicesRestServlet(RestServlet):
defer.returnValue((200, {"devices": devices}))
class DeviceRestServlet(RestServlet):
class DeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
releases=[], v2_alpha=False)
......@@ -84,6 +82,18 @@ class DeviceRestServlet(RestServlet):
)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
body = servlet.parse_json_object_from_request(request)
yield self.device_handler.update_device(
requester.user.to_string(),
device_id,
body
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
DevicesRestServlet(hs).register(http_server)
......
......@@ -81,7 +81,7 @@ class DeviceStore(SQLBaseStore):
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to retrieve
device_id (str): The ID of the device to delete
Returns:
defer.Deferred
"""
......@@ -91,6 +91,31 @@ class DeviceStore(SQLBaseStore):
desc="delete_device",
)
def update_device(self, user_id, device_id, new_display_name=None):
"""Update a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to update
new_display_name (str|None): new displayname for device; None
to leave unchanged
Raises:
StoreError: if the device is not found
Returns:
defer.Deferred
"""
updates = {}
if new_display_name is not None:
updates["display_name"] = new_display_name
if not updates:
return defer.succeed(None)
return self._simple_update_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
updatevalues=updates,
desc="update_device",
)
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""Retrieve all of a user's registered devices.
......
......@@ -140,6 +140,22 @@ class DeviceTestCase(unittest.TestCase):
# we'd like to check the access token was invalidated, but that's a
# bit of a PITA.
@defer.inlineCallbacks
def test_update_device(self):
yield self._record_users()
update = {"display_name": "new display"}
yield self.handler.update_device(user1, "abc", update)
res = yield self.handler.get_device(user1, "abc")
self.assertEqual(res["display_name"], "new display")
@defer.inlineCallbacks
def test_update_unknown_device(self):
update = {"display_name": "new_display"}
with self.assertRaises(synapse.api.errors.NotFoundError):
yield self.handler.update_device("user_id", "unknown_device_id",
update)
@defer.inlineCallbacks
def _record_users(self):
......
......@@ -15,6 +15,7 @@
from twisted.internet import defer
import synapse.api.errors
import tests.unittest
import tests.utils
......@@ -67,3 +68,38 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
"device_id": "device2",
"display_name": "display_name 2",
}, res["device2"])
@defer.inlineCallbacks
def test_update_device(self):
yield self.store.store_device(
"user_id", "device_id", "display_name 1"
)
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
yield self.store.update_device(
"user_id", "device_id",
)
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"])
# do the update
yield self.store.update_device(
"user_id", "device_id",
new_display_name="display_name 2",
)
# check it worked
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 2", res["display_name"])
@defer.inlineCallbacks
def test_update_unknown_device(self):
with self.assertRaises(synapse.api.errors.StoreError) as cm:
yield self.store.update_device(
"user_id", "unknown_device_id",
new_display_name="display_name 2",
)
self.assertEqual(404, cm.exception.code)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment