Skip to content
Snippets Groups Projects
Commit c23e3db5 authored by Kegan Dougal's avatar Kegan Dougal
Browse files

Add filter JSON sanity checks.

parent 8398f19b
No related branches found
No related tags found
No related merge requests found
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from synapse.api.errors import SynapseError
from synapse.types import UserID, RoomID
class Filtering(object): class Filtering(object):
...@@ -25,10 +26,110 @@ class Filtering(object): ...@@ -25,10 +26,110 @@ class Filtering(object):
def get_user_filter(self, user_localpart, filter_id): def get_user_filter(self, user_localpart, filter_id):
return self.store.get_user_filter(user_localpart, filter_id) return self.store.get_user_filter(user_localpart, filter_id)
def add_user_filter(self, user_localpart, definition): def add_user_filter(self, user_localpart, user_filter):
# TODO(paul): implement sanity checking of the definition self._check_valid_filter(user_filter)
return self.store.add_user_filter(user_localpart, definition) return self.store.add_user_filter(user_localpart, user_filter)
# TODO(paul): surely we should probably add a delete_user_filter or # TODO(paul): surely we should probably add a delete_user_filter or
# replace_user_filter at some point? There's no REST API specified for # replace_user_filter at some point? There's no REST API specified for
# them however # them however
def _check_valid_filter(self, user_filter):
"""Check if the provided filter is valid.
This inspects all definitions contained within the filter.
Args:
user_filter(dict): The filter
Raises:
SynapseError: If the filter is not valid.
"""
# NB: Filters are the complete json blobs. "Definitions" are an
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
top_level_definitions = [
"public_user_data", "private_user_data", "server_data"
]
room_level_definitions = [
"state", "events", "ephemeral"
]
for key in top_level_definitions:
if key in user_filter:
self._check_definition(user_filter[key])
if "room" in user_filter:
for key in room_level_definitions:
if key in user_filter["room"]:
self._check_definition(user_filter["room"][key])
def _check_definition(self, definition):
"""Check if the provided definition is valid.
This inspects not only the types but also the values to make sure they
make sense.
Args:
definition(dict): The filter definition
Raises:
SynapseError: If there was a problem with this definition.
"""
# NB: Filters are the complete json blobs. "Definitions" are an
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
if type(definition) != dict:
raise SynapseError(
400, "Expected JSON object, not %s" % (definition,)
)
# check rooms are valid room IDs
room_id_keys = ["rooms", "not_rooms"]
for key in room_id_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for room_id in definition[key]:
RoomID.from_string(room_id)
# check senders are valid user IDs
user_id_keys = ["senders", "not_senders"]
for key in user_id_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for user_id in definition[key]:
UserID.from_string(user_id)
# TODO: We don't limit event type values but we probably should...
# check types are valid event types
event_keys = ["types", "not_types"]
for key in event_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for event_type in definition[key]:
if not isinstance(event_type, basestring):
raise SynapseError(400, "Event type should be a string")
try:
event_format = definition["format"]
if event_format not in ["federation", "events"]:
raise SynapseError(400, "Invalid format: %s" % (event_format,))
except KeyError:
pass # format is optional
try:
event_select_list = definition["select"]
for select_key in event_select_list:
if select_key not in ["event_id", "origin_server_ts",
"thread_id", "content", "content.body"]:
raise SynapseError(400, "Bad select: %s" % (select_key,))
except KeyError:
pass # select is optional
if ("bundle_updates" in definition and
type(definition["bundle_updates"]) != bool):
raise SynapseError(400, "Bad bundle_updates: expected bool.")
...@@ -93,7 +93,7 @@ class CreateFilterRestServlet(RestServlet): ...@@ -93,7 +93,7 @@ class CreateFilterRestServlet(RestServlet):
filter_id = yield self.filtering.add_user_filter( filter_id = yield self.filtering.add_user_filter(
user_localpart=target_user.localpart, user_localpart=target_user.localpart,
definition=content, user_filter=content,
) )
defer.returnValue((200, {"filter_id": str(filter_id)})) defer.returnValue((200, {"filter_id": str(filter_id)}))
......
...@@ -39,8 +39,8 @@ class FilteringStore(SQLBaseStore): ...@@ -39,8 +39,8 @@ class FilteringStore(SQLBaseStore):
defer.returnValue(json.loads(def_json)) defer.returnValue(json.loads(def_json))
def add_user_filter(self, user_localpart, definition): def add_user_filter(self, user_localpart, user_filter):
def_json = json.dumps(definition) def_json = json.dumps(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then # Need an atomic transaction to SELECT the maximal ID so far then
# INSERT a new one # INSERT a new one
......
...@@ -57,13 +57,21 @@ class FilteringTestCase(unittest.TestCase): ...@@ -57,13 +57,21 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_add_filter(self): def test_add_filter(self):
user_filter = {
"room": {
"state": {
"types": ["m.*"]
}
}
}
filter_id = yield self.filtering.add_user_filter( filter_id = yield self.filtering.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart,
definition={"type": ["m.*"]}, user_filter=user_filter,
) )
self.assertEquals(filter_id, 0) self.assertEquals(filter_id, 0)
self.assertEquals({"type": ["m.*"]}, self.assertEquals(user_filter,
(yield self.datastore.get_user_filter( (yield self.datastore.get_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart,
filter_id=0, filter_id=0,
...@@ -72,9 +80,17 @@ class FilteringTestCase(unittest.TestCase): ...@@ -72,9 +80,17 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_filter(self): def test_get_filter(self):
user_filter = {
"room": {
"state": {
"types": ["m.*"]
}
}
}
filter_id = yield self.datastore.add_user_filter( filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart,
definition={"type": ["m.*"]}, user_filter=user_filter,
) )
filter = yield self.filtering.get_user_filter( filter = yield self.filtering.get_user_filter(
...@@ -82,4 +98,4 @@ class FilteringTestCase(unittest.TestCase): ...@@ -82,4 +98,4 @@ class FilteringTestCase(unittest.TestCase):
filter_id=filter_id, filter_id=filter_id,
) )
self.assertEquals(filter, {"type": ["m.*"]}) self.assertEquals(filter, user_filter)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment