Skip to content
Snippets Groups Projects
Unverified Commit 0284d2a2 authored by Dirk Klimpel's avatar Dirk Klimpel Committed by GitHub
Browse files

Add new admin APIs to remove media by media ID from quarantine. (#10044)


Related to: #6681, #5956, #10040

Signed-off-by: default avatarDirk Klimpel <dirk@klimpel.org>
parent bf6fd9f4
No related branches found
No related tags found
No related merge requests found
Add new admin APIs to remove media by media ID from quarantine. Contributed by @dkimpel.
......@@ -4,6 +4,7 @@
* [List all media uploaded by a user](#list-all-media-uploaded-by-a-user)
- [Quarantine media](#quarantine-media)
* [Quarantining media by ID](#quarantining-media-by-id)
* [Remove media from quarantine by ID](#remove-media-from-quarantine-by-id)
* [Quarantining media in a room](#quarantining-media-in-a-room)
* [Quarantining all media of a user](#quarantining-all-media-of-a-user)
* [Protecting media from being quarantined](#protecting-media-from-being-quarantined)
......@@ -77,6 +78,27 @@ Response:
{}
```
## Remove media from quarantine by ID
This API removes a single piece of local or remote media from quarantine.
Request:
```
POST /_synapse/admin/v1/media/unquarantine/<server_name>/<media_id>
{}
```
Where `server_name` is in the form of `example.org`, and `media_id` is in the
form of `abcdefg12345...`.
Response:
```json
{}
```
## Quarantining media in a room
This API quarantines all local and remote media in a room.
......
......@@ -120,6 +120,35 @@ class QuarantineMediaByID(RestServlet):
return 200, {}
class UnquarantineMediaByID(RestServlet):
"""Quarantines local or remote media by a given ID so that no one can download
it via this server.
"""
PATTERNS = admin_patterns(
"/media/unquarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
)
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(
self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
logging.info(
"Remove from quarantine local media by ID: %s/%s", server_name, media_id
)
# Remove from quarantine this media id
await self.store.quarantine_media_by_id(server_name, media_id, None)
return 200, {}
class ProtectMediaByID(RestServlet):
"""Protect local media from being quarantined."""
......@@ -290,6 +319,7 @@ def register_servlets_for_media_repo(hs: "HomeServer", http_server):
PurgeMediaCacheRestServlet(hs).register(http_server)
QuarantineMediaInRoom(hs).register(http_server)
QuarantineMediaByID(hs).register(http_server)
UnquarantineMediaByID(hs).register(http_server)
QuarantineMediaByUser(hs).register(http_server)
ProtectMediaByID(hs).register(http_server)
UnprotectMediaByID(hs).register(http_server)
......
......@@ -764,14 +764,15 @@ class RoomWorkerStore(SQLBaseStore):
self,
server_name: str,
media_id: str,
quarantined_by: str,
quarantined_by: Optional[str],
) -> int:
"""quarantines a single local or remote media id
"""quarantines or unquarantines a single local or remote media id
Args:
server_name: The name of the server that holds this media
media_id: The ID of the media to be quarantined
quarantined_by: The user ID that initiated the quarantine request
If it is `None` media will be removed from quarantine
"""
logger.info("Quarantining media: %s/%s", server_name, media_id)
is_local = server_name == self.config.server_name
......@@ -838,9 +839,9 @@ class RoomWorkerStore(SQLBaseStore):
txn,
local_mxcs: List[str],
remote_mxcs: List[Tuple[str, str]],
quarantined_by: str,
quarantined_by: Optional[str],
) -> int:
"""Quarantine local and remote media items
"""Quarantine and unquarantine local and remote media items
Args:
txn (cursor)
......@@ -848,18 +849,27 @@ class RoomWorkerStore(SQLBaseStore):
remote_mxcs: A list of (remote server, media id) tuples representing
remote mxc URLs
quarantined_by: The ID of the user who initiated the quarantine request
If it is `None` media will be removed from quarantine
Returns:
The total number of media items quarantined
"""
# Update all the tables to set the quarantined_by flag
txn.executemany(
"""
sql = """
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ? AND safe_from_quarantine = ?
""",
((quarantined_by, media_id, False) for media_id in local_mxcs),
)
WHERE media_id = ?
"""
# set quarantine
if quarantined_by is not None:
sql += "AND safe_from_quarantine = ?"
rows = [(quarantined_by, media_id, False) for media_id in local_mxcs]
# remove from quarantine
else:
rows = [(quarantined_by, media_id) for media_id in local_mxcs]
txn.executemany(sql, rows)
# Note that a rowcount of -1 can be used to indicate no rows were affected.
total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
......
......@@ -566,6 +566,134 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertFalse(os.path.exists(local_path))
class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
login.register_servlets,
]
def prepare(self, reactor, clock, hs):
media_repo = hs.get_media_repository_resource()
self.store = hs.get_datastore()
self.server_name = hs.hostname
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
# Create media
upload_resource = media_repo.children[b"upload"]
# file size is 67 Byte
image_data = unhexlify(
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
b"0000001f15c4890000000a49444154789c63000100000500010d"
b"0a2db40000000049454e44ae426082"
)
# Upload some media into the room
response = self.helper.upload_media(
upload_resource, image_data, tok=self.admin_user_tok, expect_code=200
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
self.media_id = server_and_media_id.split("/")[1]
self.url = "/_synapse/admin/v1/media/%s/%s/%s"
@parameterized.expand(["quarantine", "unquarantine"])
def test_no_auth(self, action: str):
"""
Try to protect media without authentication.
"""
channel = self.make_request(
"POST",
self.url % (action, self.server_name, self.media_id),
b"{}",
)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["quarantine", "unquarantine"])
def test_requester_is_no_admin(self, action: str):
"""
If the user is not a server admin, an error is returned.
"""
self.other_user = self.register_user("user", "pass")
self.other_user_token = self.login("user", "pass")
channel = self.make_request(
"POST",
self.url % (action, self.server_name, self.media_id),
access_token=self.other_user_token,
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_quarantine_media(self):
"""
Tests that quarantining and remove from quarantine a media is successfully
"""
media_info = self.get_success(self.store.get_local_media(self.media_id))
self.assertFalse(media_info["quarantined_by"])
# quarantining
channel = self.make_request(
"POST",
self.url % ("quarantine", self.server_name, self.media_id),
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
self.assertTrue(media_info["quarantined_by"])
# remove from quarantine
channel = self.make_request(
"POST",
self.url % ("unquarantine", self.server_name, self.media_id),
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
media_info = self.get_success(self.store.get_local_media(self.media_id))
self.assertFalse(media_info["quarantined_by"])
def test_quarantine_protected_media(self):
"""
Tests that quarantining from protected media fails
"""
# protect
self.get_success(self.store.mark_local_media_as_safe(self.media_id, safe=True))
# verify protection
media_info = self.get_success(self.store.get_local_media(self.media_id))
self.assertTrue(media_info["safe_from_quarantine"])
# quarantining
channel = self.make_request(
"POST",
self.url % ("quarantine", self.server_name, self.media_id),
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertFalse(channel.json_body)
# verify that is not in quarantine
media_info = self.get_success(self.store.get_local_media(self.media_id))
self.assertFalse(media_info["quarantined_by"])
class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
servlets = [
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment