Skip to content
Snippets Groups Projects
Unverified Commit 95e41f36 authored by Patrick Cloke's avatar Patrick Cloke Committed by GitHub
Browse files

Allow local media to be marked as safe from being quarantined. (#7718)

parent e060bf44
No related branches found
No related tags found
No related merge requests found
Media can now be marked as safe from quarantined.
......@@ -89,6 +89,7 @@ BOOLEAN_COLUMNS = {
"account_validity": ["email_sent"],
"redactions": ["have_censored"],
"room_stats_state": ["is_federatable"],
"local_media_repository": ["safe_from_quarantine"],
}
......
......@@ -81,6 +81,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_media",
)
def mark_local_media_as_safe(self, media_id: str):
"""Mark a local media as safe from quarantining."""
return self.db.simple_update_one(
table="local_media_repository",
keyvalues={"media_id": media_id},
updatevalues={"safe_from_quarantine": True},
desc="mark_local_media_as_safe",
)
def get_url_cache(self, url, ts):
"""Get the media_id and ts for a cached URL as of the given timestamp
Returns:
......
......@@ -626,36 +626,10 @@ class RoomWorkerStore(SQLBaseStore):
def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
total_media_quarantined = 0
# Now update all the tables to set the quarantined_by flag
txn.executemany(
"""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
""",
((quarantined_by, media_id) for media_id in local_mxcs),
)
txn.executemany(
"""
UPDATE remote_media_cache
SET quarantined_by = ?
WHERE media_origin = ? AND media_id = ?
""",
(
(quarantined_by, origin, media_id)
for origin, media_id in remote_mxcs
),
return self._quarantine_media_txn(
txn, local_mxcs, remote_mxcs, quarantined_by
)
total_media_quarantined += len(local_mxcs)
total_media_quarantined += len(remote_mxcs)
return total_media_quarantined
return self.db.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
......@@ -805,17 +779,17 @@ class RoomWorkerStore(SQLBaseStore):
Returns:
The total number of media items quarantined
"""
total_media_quarantined = 0
# Update all the tables to set the quarantined_by flag
txn.executemany(
"""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
WHERE media_id = ? AND safe_from_quarantine = ?
""",
((quarantined_by, media_id) for media_id in local_mxcs),
((quarantined_by, media_id, False) for media_id in local_mxcs),
)
# Note that a rowcount of -1 can be used to indicate no rows were affected.
total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
txn.executemany(
"""
......@@ -825,9 +799,7 @@ class RoomWorkerStore(SQLBaseStore):
""",
((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
)
total_media_quarantined += len(local_mxcs)
total_media_quarantined += len(remote_mxcs)
total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
return total_media_quarantined
......
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- The local_media_repository should have files which do not get quarantined,
-- e.g. files from sticker packs.
ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT FALSE;
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- The local_media_repository should have files which do not get quarantined,
-- e.g. files from sticker packs.
ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT 0;
......@@ -220,6 +220,24 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
return hs
def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
"""Ensure a piece of media is quarantined when trying to access it."""
request, channel = self.make_request(
"GET", server_and_media_id, shorthand=False, access_token=admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)
# Should be quarantined
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id
),
)
def test_quarantine_media_requires_admin(self):
self.register_user("nonadmin", "pass", admin=False)
non_admin_user_tok = self.login("nonadmin", "pass")
......@@ -292,24 +310,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
# Attempt to access the media
request, channel = self.make_request(
"GET",
server_name_and_media_id,
shorthand=False,
access_token=admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)
# Should be quarantined
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_name_and_media_id
),
)
self._ensure_quarantined(admin_user_tok, server_name_and_media_id)
def test_quarantine_all_media_in_room(self, override_url_template=None):
self.register_user("room_admin", "pass", admin=True)
......@@ -371,45 +372,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
server_and_media_id_2 = mxc_2[6:]
# Test that we cannot download any of the media anymore
request, channel = self.make_request(
"GET",
server_and_media_id_1,
shorthand=False,
access_token=non_admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)
# Should be quarantined
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id_1
),
)
request, channel = self.make_request(
"GET",
server_and_media_id_2,
shorthand=False,
access_token=non_admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)
# Should be quarantined
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id_2
),
)
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
def test_quaraantine_all_media_in_room_deprecated_api_path(self):
def test_quarantine_all_media_in_room_deprecated_api_path(self):
# Perform the above test with the deprecated API path
self.test_quarantine_all_media_in_room("/_synapse/admin/v1/quarantine_media/%s")
......@@ -449,25 +415,52 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
)
# Attempt to access each piece of media
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
def test_cannot_quarantine_safe_media(self):
self.register_user("user_admin", "pass", admin=True)
admin_user_tok = self.login("user_admin", "pass")
non_admin_user = self.register_user("user_nonadmin", "pass", admin=False)
non_admin_user_tok = self.login("user_nonadmin", "pass")
# Upload some media
response_1 = self.helper.upload_media(
self.upload_resource, self.image_data, tok=non_admin_user_tok
)
response_2 = self.helper.upload_media(
self.upload_resource, self.image_data, tok=non_admin_user_tok
)
# Extract media IDs
server_and_media_id_1 = response_1["content_uri"][6:]
server_and_media_id_2 = response_2["content_uri"][6:]
# Mark the second item as safe from quarantine.
_, media_id_2 = server_and_media_id_2.split("/")
self.get_success(self.store.mark_local_media_as_safe(media_id_2))
# Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
non_admin_user
)
request, channel = self.make_request(
"GET",
server_and_media_id_1,
shorthand=False,
access_token=non_admin_user_tok,
"POST", url.encode("ascii"), access_token=admin_user_tok,
)
request.render(self.download_resource)
self.render(request)
self.pump(1.0)
# Should be quarantined
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id_1,
),
json.loads(channel.result["body"].decode("utf-8")),
{"num_quarantined": 1},
"Expected 1 quarantined item",
)
# Attempt to access each piece of media, the first should fail, the
# second should succeed.
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
# Attempt to access each piece of media
request, channel = self.make_request(
"GET",
......@@ -478,12 +471,12 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
request.render(self.download_resource)
self.pump(1.0)
# Should be quarantined
# Shouldn't be quarantined
self.assertEqual(
404,
200,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
"Expected to receive a 200 on accessing not-quarantined media: %s"
% server_and_media_id_2
),
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment