Skip to content
Snippets Groups Projects
Commit 83172487 authored by Kegan Dougal's avatar Kegan Dougal
Browse files

Add basic filtering public API unit tests. Use defers in the right places.

parent 5561a879
No related branches found
No related tags found
No related merge requests found
......@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.types import UserID, RoomID
......@@ -59,19 +60,21 @@ class Filtering(object):
# replace_user_filter at some point? There's no REST API specified for
# them however
@defer.inlineCallbacks
def _filter_on_key(self, events, user, filter_id, keys):
filter_json = self.get_user_filter(user.localpart, filter_id)
filter_json = yield self.get_user_filter(user.localpart, filter_id)
if not filter_json:
return events
defer.returnValue(events)
try:
# extract the right definition from the filter
definition = filter_json
for key in keys:
definition = definition[key]
return self._filter_with_definition(events, definition)
defer.returnValue(self._filter_with_definition(events, definition))
except KeyError:
return events # return all events if definition isn't specified.
# return all events if definition isn't specified.
defer.returnValue(events)
def _filter_with_definition(self, events, definition):
return [e for e in events if self._passes_definition(definition, e)]
......
......@@ -24,7 +24,7 @@ from tests.utils import (
)
from synapse.server import HomeServer
from synapse.types import UserID
user_localpart = "test_user"
MockEvent = namedtuple("MockEvent", "sender type room_id")
......@@ -352,6 +352,58 @@ class FilteringTestCase(unittest.TestCase):
self.filtering._passes_definition(definition, event)
)
@defer.inlineCallbacks
def test_filter_public_user_data_match(self):
user_filter = {
"public_user_data": {
"types": ["m.*"]
}
}
user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
user_filter=user_filter,
)
event = MockEvent(
sender="@foo:bar",
type="m.profile",
room_id="!foo:bar"
)
events = [event]
results = yield self.filtering.filter_public_user_data(
events=events,
user=user,
filter_id=filter_id
)
self.assertEquals(events, results)
@defer.inlineCallbacks
def test_filter_public_user_data_no_match(self):
user_filter = {
"public_user_data": {
"types": ["m.*"]
}
}
user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
user_filter=user_filter,
)
event = MockEvent(
sender="@foo:bar",
type="custom.avatar.3d.crazy",
room_id="!foo:bar"
)
events = [event]
results = yield self.filtering.filter_public_user_data(
events=events,
user=user,
filter_id=filter_id
)
self.assertEquals([], results)
@defer.inlineCallbacks
def test_add_filter(self):
user_filter = {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment