Skip to content
Snippets Groups Projects
Commit 05c7cba7 authored by Paul "LeoNerd" Evans's avatar Paul "LeoNerd" Evans
Browse files

Initial trivial implementation of an actual 'Filtering' object; move storage...

Initial trivial implementation of an actual 'Filtering' object; move storage of user filters into there
parent f9958f34
No related branches found
No related tags found
No related merge requests found
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(paul)
_filters_for_user = {}
class Filtering(object):
def __init__(self, hs):
super(Filtering, self).__init__()
self.hs = hs
def get_user_filter(self, user_localpart, filter_id):
filters = _filters_for_user.get(user_localpart, None)
if not filters or filter_id >= len(filters):
raise KeyError()
return filters[filter_id]
def add_user_filter(self, user_localpart, definition):
filters = _filters_for_user.setdefault(user_localpart, [])
filter_id = len(filters)
filters.append(definition)
return filter_id
...@@ -28,10 +28,6 @@ import logging ...@@ -28,10 +28,6 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# TODO(paul)
_filters_for_user = {}
class GetFilterRestServlet(RestServlet): class GetFilterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
...@@ -39,6 +35,7 @@ class GetFilterRestServlet(RestServlet): ...@@ -39,6 +35,7 @@ class GetFilterRestServlet(RestServlet):
super(GetFilterRestServlet, self).__init__() super(GetFilterRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, filter_id): def on_GET(self, request, user_id, filter_id):
...@@ -56,13 +53,14 @@ class GetFilterRestServlet(RestServlet): ...@@ -56,13 +53,14 @@ class GetFilterRestServlet(RestServlet):
except: except:
raise SynapseError(400, "Invalid filter_id") raise SynapseError(400, "Invalid filter_id")
filters = _filters_for_user.get(target_user.localpart, None) try:
defer.returnValue((200, self.filtering.get_user_filter(
if not filters or filter_id >= len(filters): user_localpart=target_user.localpart,
filter_id=filter_id,
)))
except KeyError:
raise SynapseError(400, "No such filter") raise SynapseError(400, "No such filter")
defer.returnValue((200, filters[filter_id]))
class CreateFilterRestServlet(RestServlet): class CreateFilterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter") PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter")
...@@ -71,6 +69,7 @@ class CreateFilterRestServlet(RestServlet): ...@@ -71,6 +69,7 @@ class CreateFilterRestServlet(RestServlet):
super(CreateFilterRestServlet, self).__init__() super(CreateFilterRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id): def on_POST(self, request, user_id):
...@@ -90,10 +89,10 @@ class CreateFilterRestServlet(RestServlet): ...@@ -90,10 +89,10 @@ class CreateFilterRestServlet(RestServlet):
except: except:
raise SynapseError(400, "Invalid filter definition") raise SynapseError(400, "Invalid filter definition")
filters = _filters_for_user.setdefault(target_user.localpart, []) filter_id = self.filtering.add_user_filter(
user_localpart=target_user.localpart,
filter_id = len(filters) definition=content,
filters.append(content) )
defer.returnValue((200, {"filter_id": str(filter_id)})) defer.returnValue((200, {"filter_id": str(filter_id)}))
......
...@@ -32,6 +32,7 @@ from synapse.streams.events import EventSources ...@@ -32,6 +32,7 @@ from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.keyring import Keyring from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
from synapse.api.filtering import Filtering
class BaseHomeServer(object): class BaseHomeServer(object):
...@@ -79,6 +80,7 @@ class BaseHomeServer(object): ...@@ -79,6 +80,7 @@ class BaseHomeServer(object):
'ratelimiter', 'ratelimiter',
'keyring', 'keyring',
'event_builder_factory', 'event_builder_factory',
'filtering',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
...@@ -197,3 +199,6 @@ class HomeServer(BaseHomeServer): ...@@ -197,3 +199,6 @@ class HomeServer(BaseHomeServer):
clock=self.get_clock(), clock=self.get_clock(),
hostname=self.hostname, hostname=self.hostname,
) )
def build_filtering(self):
return Filtering(self)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment