Skip to content
Snippets Groups Projects
Commit 07639c79 authored by Mark Haines's avatar Mark Haines
Browse files

Respond with more helpful error messages for unsigned requests

parent 25d80f35
No related branches found
No related tags found
No related merge requests found
...@@ -19,6 +19,7 @@ import logging ...@@ -19,6 +19,7 @@ import logging
class Codes(object): class Codes(object):
UNAUTHORIZED = "M_UNAUTHORIZED"
FORBIDDEN = "M_FORBIDDEN" FORBIDDEN = "M_FORBIDDEN"
BAD_JSON = "M_BAD_JSON" BAD_JSON = "M_BAD_JSON"
NOT_JSON = "M_NOT_JSON" NOT_JSON = "M_NOT_JSON"
......
...@@ -43,7 +43,7 @@ def fetch_server_key(server_name, ssl_context_factory): ...@@ -43,7 +43,7 @@ def fetch_server_key(server_name, ssl_context_factory):
return return
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
raise IOError("Cannot get key for " % server_name) raise IOError("Cannot get key for %s" % server_name)
class SynapseKeyClientError(Exception): class SynapseKeyClientError(Exception):
...@@ -93,7 +93,7 @@ class SynapseKeyClientProtocol(HTTPClient): ...@@ -93,7 +93,7 @@ class SynapseKeyClientProtocol(HTTPClient):
def on_timeout(self): def on_timeout(self):
logger.debug("Timeout waiting for response from %s", logger.debug("Timeout waiting for response from %s",
self.transport.getHost()) self.transport.getHost())
self.on_remote_key.errback(IOError("Timeout waiting for response")) self.remote_key.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection() self.transport.abortConnection()
......
...@@ -20,6 +20,7 @@ from syutil.crypto.signing_key import ( ...@@ -20,6 +20,7 @@ from syutil.crypto.signing_key import (
is_signing_algorithm_supported, decode_verify_key_bytes is_signing_algorithm_supported, decode_verify_key_bytes
) )
from syutil.base64util import decode_base64, encode_base64 from syutil.base64util import decode_base64, encode_base64
from synapse.api.errors import SynapseError, Codes
from OpenSSL import crypto from OpenSSL import crypto
...@@ -38,8 +39,36 @@ class Keyring(object): ...@@ -38,8 +39,36 @@ class Keyring(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object): def verify_json_for_server(self, server_name, json_object):
key_ids = signature_ids(json_object, server_name) key_ids = signature_ids(json_object, server_name)
verify_key = yield self.get_server_verify_key(server_name, key_ids) if not key_ids:
verify_signed_json(json_object, server_name, verify_key) raise SynapseError(
400,
"No supported algorithms in signing keys",
Codes.UNAUTHORIZED,
)
try:
verify_key = yield self.get_server_verify_key(server_name, key_ids)
except IOError:
raise SynapseError(
502,
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
)
except:
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, key_ids),
Codes.UNAUTHORIZED,
)
try:
verify_signed_json(json_object, server_name, verify_key)
except:
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s" % (
server_name, verify_key.alg, verify_key.version
),
Codes.UNAUTHORIZED,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key(self, server_name, key_ids): def get_server_verify_key(self, server_name, key_ids):
......
...@@ -233,7 +233,7 @@ class TransportLayer(object): ...@@ -233,7 +233,7 @@ class TransportLayer(object):
return (origin, key, sig) return (origin, key, sig)
except: except:
raise SynapseError( raise SynapseError(
400, "Malformed Authorization Header", Codes.FORBIDDEN 400, "Malformed Authorization header", Codes.UNAUTHORIZED
) )
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
...@@ -246,7 +246,7 @@ class TransportLayer(object): ...@@ -246,7 +246,7 @@ class TransportLayer(object):
if not json_request["signatures"]: if not json_request["signatures"]:
raise SynapseError( raise SynapseError(
401, "Missing Authorization headers", Codes.FORBIDDEN, 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
) )
yield self.keyring.verify_json_for_server(origin, json_request) yield self.keyring.verify_json_for_server(origin, json_request)
......
...@@ -121,7 +121,7 @@ class SQLBaseStore(object): ...@@ -121,7 +121,7 @@ class SQLBaseStore(object):
# "Simple" SQL API methods that operate on a single table with no JOINs, # "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns. # no complex WHERE clauses, just a dict of values for columns.
def _simple_insert(self, table, values, or_replace=False): def _simple_insert(self, table, values, or_replace=False, or_ignore=False):
"""Executes an INSERT query on the named table. """Executes an INSERT query on the named table.
Args: Args:
...@@ -130,13 +130,16 @@ class SQLBaseStore(object): ...@@ -130,13 +130,16 @@ class SQLBaseStore(object):
or_replace : bool; if True performs an INSERT OR REPLACE or_replace : bool; if True performs an INSERT OR REPLACE
""" """
return self.runInteraction( return self.runInteraction(
self._simple_insert_txn, table, values, or_replace=or_replace self._simple_insert_txn, table, values, or_replace=or_replace,
or_ignore=or_ignore,
) )
@log_function @log_function
def _simple_insert_txn(self, txn, table, values, or_replace=False): def _simple_insert_txn(self, txn, table, values, or_replace=False,
or_ignore=False):
sql = "%s INTO %s (%s) VALUES(%s)" % ( sql = "%s INTO %s (%s) VALUES(%s)" % (
("INSERT OR REPLACE" if or_replace else "INSERT"), ("INSERT OR REPLACE" if or_replace else
"INSERT OR IGNORE" if or_ignore else "INSERT"),
table, table,
", ".join(k for k in values), ", ".join(k for k in values),
", ".join("?" for k in values) ", ".join("?" for k in values)
......
...@@ -65,6 +65,7 @@ class KeyStore(SQLBaseStore): ...@@ -65,6 +65,7 @@ class KeyStore(SQLBaseStore):
"ts_added_ms": time_now_ms, "ts_added_ms": time_now_ms,
"tls_certificate": buffer(tls_certificate_bytes), "tls_certificate": buffer(tls_certificate_bytes),
}, },
or_ignore=True,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
...@@ -113,4 +114,5 @@ class KeyStore(SQLBaseStore): ...@@ -113,4 +114,5 @@ class KeyStore(SQLBaseStore):
"ts_added_ms": time_now_ms, "ts_added_ms": time_now_ms,
"verify_key": buffer(verify_key.encode()), "verify_key": buffer(verify_key.encode()),
}, },
or_ignore=True,
) )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment