Skip to content
Snippets Groups Projects
Commit 2af40cfa authored by Mark Haines's avatar Mark Haines
Browse files

Merge pull request #25 from matrix-org/events_refactor

Event refactor
parents 400327d1 5a465b67
No related branches found
No related tags found
No related merge requests found
Showing
with 1222 additions and 686 deletions
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sqlite3
import pydot
import cgi
import json
import datetime
import argparse
from synapse.events import FrozenEvent
def make_graph(db_name, room_id, file_prefix):
conn = sqlite3.connect(db_name)
c = conn.execute(
"SELECT json FROM event_json where room_id = ?",
(room_id,)
)
events = [FrozenEvent(json.loads(e[0])) for e in c.fetchall()]
events.sort(key=lambda e: e.depth)
node_map = {}
state_groups = {}
graph = pydot.Dot(graph_name="Test")
for event in events:
c = conn.execute(
"SELECT state_group FROM event_to_state_groups "
"WHERE event_id = ?",
(event.event_id,)
)
res = c.fetchone()
state_group = res[0] if res else None
if state_group is not None:
state_groups.setdefault(state_group, []).append(event.event_id)
t = datetime.datetime.fromtimestamp(
float(event.origin_server_ts) / 1000
).strftime('%Y-%m-%d %H:%M:%S,%f')
content = json.dumps(event.get_dict()["content"])
label = (
"<"
"<b>%(name)s </b><br/>"
"Type: <b>%(type)s </b><br/>"
"State key: <b>%(state_key)s </b><br/>"
"Content: <b>%(content)s </b><br/>"
"Time: <b>%(time)s </b><br/>"
"Depth: <b>%(depth)s </b><br/>"
"State group: %(state_group)s<br/>"
">"
) % {
"name": event.event_id,
"type": event.type,
"state_key": event.get("state_key", None),
"content": cgi.escape(content, quote=True),
"time": t,
"depth": event.depth,
"state_group": state_group,
}
node = pydot.Node(
name=event.event_id,
label=label,
)
node_map[event.event_id] = node
graph.add_node(node)
for event in events:
for prev_id, _ in event.prev_events:
try:
end_node = node_map[prev_id]
except:
end_node = pydot.Node(
name=prev_id,
label="<<b>%s</b>>" % (prev_id,),
)
node_map[prev_id] = end_node
graph.add_node(end_node)
edge = pydot.Edge(node_map[event.event_id], end_node)
graph.add_edge(edge)
for group, event_ids in state_groups.items():
if len(event_ids) <= 1:
continue
cluster = pydot.Cluster(
str(group),
label="<State Group: %s>" % (str(group),)
)
for event_id in event_ids:
cluster.add_node(node_map[event_id])
graph.add_subgraph(cluster)
graph.write('%s.dot' % file_prefix, format='raw', prog='dot')
graph.write_svg("%s.svg" % file_prefix, prog='dot')
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate a PDU graph for a given room by talking "
"to the given homeserver to get the list of PDUs. \n"
"Requires pydot."
)
parser.add_argument(
"-p", "--prefix", dest="prefix",
help="String to prefix output files with"
)
parser.add_argument('db')
parser.add_argument('room')
args = parser.parse_args()
make_graph(args.db, args.room, args.prefix)
...@@ -18,6 +18,9 @@ class dictobj(dict): ...@@ -18,6 +18,9 @@ class dictobj(dict):
def get_full_dict(self): def get_full_dict(self):
return dict(self) return dict(self)
def get_pdu_json(self):
return dict(self)
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
......
from synapse.storage._base import SQLBaseStore
from synapse.storage.signatures import SignatureStore
from synapse.storage.event_federation import EventFederationStore
from syutil.base64util import encode_base64, decode_base64
from synapse.crypto.event_signing import compute_event_signature
from synapse.events.builder import EventBuilder
from synapse.events.utils import prune_event
from synapse.crypto.event_signing import check_event_content_hash
from syutil.crypto.jsonsign import (
verify_signed_json, SignatureVerifyException,
)
from syutil.crypto.signing_key import decode_verify_key_bytes
from syutil.jsonutil import encode_canonical_json
import argparse
import dns.resolver
import hashlib
import json
import sqlite3
import syutil
import urllib2
delta_sql = """
CREATE TABLE IF NOT EXISTS event_json(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
internal_metadata NOT NULL,
json BLOB NOT NULL,
CONSTRAINT ev_j_uniq UNIQUE (event_id)
);
CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id);
CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id);
"""
class Store(object):
_get_event_signatures_txn = SignatureStore.__dict__["_get_event_signatures_txn"]
_get_event_content_hashes_txn = SignatureStore.__dict__["_get_event_content_hashes_txn"]
_get_event_reference_hashes_txn = SignatureStore.__dict__["_get_event_reference_hashes_txn"]
_get_prev_event_hashes_txn = SignatureStore.__dict__["_get_prev_event_hashes_txn"]
_get_prev_events_and_state = EventFederationStore.__dict__["_get_prev_events_and_state"]
_get_auth_events = EventFederationStore.__dict__["_get_auth_events"]
cursor_to_dict = SQLBaseStore.__dict__["cursor_to_dict"]
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
_simple_select_list_txn = SQLBaseStore.__dict__["_simple_select_list_txn"]
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
def _generate_event_json(self, txn, rows):
events = []
for row in rows:
d = dict(row)
d.pop("stream_ordering", None)
d.pop("topological_ordering", None)
d.pop("processed", None)
if "origin_server_ts" not in d:
d["origin_server_ts"] = d.pop("ts", 0)
else:
d.pop("ts", 0)
d.pop("prev_state", None)
d.update(json.loads(d.pop("unrecognized_keys")))
d["sender"] = d.pop("user_id")
d["content"] = json.loads(d["content"])
if "age_ts" not in d:
# For compatibility
d["age_ts"] = d.get("origin_server_ts", 0)
d.setdefault("unsigned", {})["age_ts"] = d.pop("age_ts")
outlier = d.pop("outlier", False)
# d.pop("membership", None)
d.pop("state_hash", None)
d.pop("replaces_state", None)
b = EventBuilder(d)
b.internal_metadata.outlier = outlier
events.append(b)
for i, ev in enumerate(events):
signatures = self._get_event_signatures_txn(
txn, ev.event_id,
)
ev.signatures = {
n: {
k: encode_base64(v) for k, v in s.items()
}
for n, s in signatures.items()
}
hashes = self._get_event_content_hashes_txn(
txn, ev.event_id,
)
ev.hashes = {
k: encode_base64(v) for k, v in hashes.items()
}
prevs = self._get_prev_events_and_state(txn, ev.event_id)
ev.prev_events = [
(e_id, h)
for e_id, h, is_state in prevs
if is_state == 0
]
# ev.auth_events = self._get_auth_events(txn, ev.event_id)
hashes = dict(ev.auth_events)
for e_id, hash in ev.prev_events:
if e_id in hashes and not hash:
hash.update(hashes[e_id])
#
# if hasattr(ev, "state_key"):
# ev.prev_state = [
# (e_id, h)
# for e_id, h, is_state in prevs
# if is_state == 1
# ]
return [e.build() for e in events]
store = Store()
def get_key(server_name):
print "Getting keys for: %s" % (server_name,)
targets = []
if ":" in server_name:
target, port = server_name.split(":")
targets.append((target, int(port)))
return
try:
answers = dns.resolver.query("_matrix._tcp." + server_name, "SRV")
for srv in answers:
targets.append((srv.target, srv.port))
except dns.resolver.NXDOMAIN:
targets.append((server_name, 8448))
except:
print "Failed to lookup keys for %s" % (server_name,)
return {}
for target, port in targets:
url = "https://%s:%i/_matrix/key/v1" % (target, port)
try:
keys = json.load(urllib2.urlopen(url, timeout=2))
verify_keys = {}
for key_id, key_base64 in keys["verify_keys"].items():
verify_key = decode_verify_key_bytes(
key_id, decode_base64(key_base64)
)
verify_signed_json(keys, server_name, verify_key)
verify_keys[key_id] = verify_key
print "Got keys for: %s" % (server_name,)
return verify_keys
except urllib2.URLError:
pass
print "Failed to get keys for %s" % (server_name,)
return {}
def reinsert_events(cursor, server_name, signing_key):
cursor.executescript(delta_sql)
cursor.execute(
"SELECT * FROM events ORDER BY rowid ASC"
)
rows = store.cursor_to_dict(cursor)
events = store._generate_event_json(cursor, rows)
print "Got events from DB."
algorithms = {
"sha256": hashlib.sha256,
}
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
verify_key = signing_key.verify_key
verify_key.alg = signing_key.alg
verify_key.version = signing_key.version
server_keys = {
server_name: {
key_id: verify_key
}
}
for event in events:
for alg_name in event.hashes:
if check_event_content_hash(event, algorithms[alg_name]):
pass
else:
pass
print "FAIL content hash %s %s" % (alg_name, event.event_id, )
have_own_correctly_signed = False
for host, sigs in event.signatures.items():
pruned = prune_event(event)
for key_id in sigs:
if host not in server_keys:
server_keys[host] = get_key(host)
if key_id in server_keys[host]:
try:
verify_signed_json(
pruned.get_pdu_json(),
host,
server_keys[host][key_id]
)
if host == server_name:
have_own_correctly_signed = True
except SignatureVerifyException:
print "FAIL signature check %s %s" % (
key_id, event.event_id
)
# TODO: Re sign with our own server key
if not have_own_correctly_signed:
sigs = compute_event_signature(event, server_name, signing_key)
event.signatures.update(sigs)
pruned = prune_event(event)
for key_id in event.signatures[server_name]:
verify_signed_json(
pruned.get_pdu_json(),
server_name,
server_keys[server_name][key_id]
)
event_json = encode_canonical_json(
event.get_dict()
).decode("UTF-8")
metadata_json = encode_canonical_json(
event.internal_metadata.get_dict()
).decode("UTF-8")
store._simple_insert_txn(
cursor,
table="event_json",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": metadata_json,
"json": event_json,
},
or_replace=True,
)
def main(database, server_name, signing_key):
conn = sqlite3.connect(database)
cursor = conn.cursor()
reinsert_events(cursor, server_name, signing_key)
conn.commit()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("database")
parser.add_argument("server_name")
parser.add_argument(
"signing_key", type=argparse.FileType('r'),
)
args = parser.parse_args()
signing_key = syutil.crypto.signing_key.read_signing_keys(
args.signing_key
)
main(args.database, args.server_name, signing_key[0])
...@@ -41,6 +41,7 @@ setup( ...@@ -41,6 +41,7 @@ setup(
"pynacl", "pynacl",
"daemonize", "daemonize",
"py-bcrypt", "py-bcrypt",
"frozendict>=0.4",
"pillow", "pillow",
], ],
dependency_links=[ dependency_links=[
......
...@@ -17,14 +17,10 @@ ...@@ -17,14 +17,10 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
from synapse.api.events.room import (
RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent,
RoomJoinRulesEvent, RoomCreateEvent, RoomAliasesEvent,
)
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from syutil.base64util import encode_base64 from synapse.util.async import run_on_reactor
import logging import logging
...@@ -53,15 +49,17 @@ class Auth(object): ...@@ -53,15 +49,17 @@ class Auth(object):
logger.warn("Trusting event: %s", event.event_id) logger.warn("Trusting event: %s", event.event_id)
return True return True
if event.type == RoomCreateEvent.TYPE: if event.type == EventTypes.Create:
# FIXME # FIXME
return True return True
# FIXME: Temp hack # FIXME: Temp hack
if event.type == RoomAliasesEvent.TYPE: if event.type == EventTypes.Aliases:
return True return True
if event.type == RoomMemberEvent.TYPE: logger.debug("Auth events: %s", auth_events)
if event.type == EventTypes.Member:
allowed = self.is_membership_change_allowed( allowed = self.is_membership_change_allowed(
event, auth_events event, auth_events
) )
...@@ -74,10 +72,10 @@ class Auth(object): ...@@ -74,10 +72,10 @@ class Auth(object):
self.check_event_sender_in_room(event, auth_events) self.check_event_sender_in_room(event, auth_events)
self._can_send_event(event, auth_events) self._can_send_event(event, auth_events)
if event.type == RoomPowerLevelsEvent.TYPE: if event.type == EventTypes.PowerLevels:
self._check_power_levels(event, auth_events) self._check_power_levels(event, auth_events)
if event.type == RoomRedactionEvent.TYPE: if event.type == EventTypes.Redaction:
self._check_redaction(event, auth_events) self._check_redaction(event, auth_events)
logger.debug("Allowing! %s", event) logger.debug("Allowing! %s", event)
...@@ -93,7 +91,7 @@ class Auth(object): ...@@ -93,7 +91,7 @@ class Auth(object):
def check_joined_room(self, room_id, user_id): def check_joined_room(self, room_id, user_id):
member = yield self.state.get_current_state( member = yield self.state.get_current_state(
room_id=room_id, room_id=room_id,
event_type=RoomMemberEvent.TYPE, event_type=EventTypes.Member,
state_key=user_id state_key=user_id
) )
self._check_joined_room(member, user_id, room_id) self._check_joined_room(member, user_id, room_id)
...@@ -104,7 +102,7 @@ class Auth(object): ...@@ -104,7 +102,7 @@ class Auth(object):
curr_state = yield self.state.get_current_state(room_id) curr_state = yield self.state.get_current_state(room_id)
for event in curr_state: for event in curr_state:
if event.type == RoomMemberEvent.TYPE: if event.type == EventTypes.Member:
try: try:
if self.hs.parse_userid(event.state_key).domain != host: if self.hs.parse_userid(event.state_key).domain != host:
continue continue
...@@ -118,7 +116,7 @@ class Auth(object): ...@@ -118,7 +116,7 @@ class Auth(object):
defer.returnValue(False) defer.returnValue(False)
def check_event_sender_in_room(self, event, auth_events): def check_event_sender_in_room(self, event, auth_events):
key = (RoomMemberEvent.TYPE, event.user_id, ) key = (EventTypes.Member, event.user_id, )
member_event = auth_events.get(key) member_event = auth_events.get(key)
return self._check_joined_room( return self._check_joined_room(
...@@ -140,7 +138,7 @@ class Auth(object): ...@@ -140,7 +138,7 @@ class Auth(object):
# Check if this is the room creator joining: # Check if this is the room creator joining:
if len(event.prev_events) == 1 and Membership.JOIN == membership: if len(event.prev_events) == 1 and Membership.JOIN == membership:
# Get room creation event: # Get room creation event:
key = (RoomCreateEvent.TYPE, "", ) key = (EventTypes.Create, "", )
create = auth_events.get(key) create = auth_events.get(key)
if create and event.prev_events[0][0] == create.event_id: if create and event.prev_events[0][0] == create.event_id:
if create.content["creator"] == event.state_key: if create.content["creator"] == event.state_key:
...@@ -149,19 +147,19 @@ class Auth(object): ...@@ -149,19 +147,19 @@ class Auth(object):
target_user_id = event.state_key target_user_id = event.state_key
# get info about the caller # get info about the caller
key = (RoomMemberEvent.TYPE, event.user_id, ) key = (EventTypes.Member, event.user_id, )
caller = auth_events.get(key) caller = auth_events.get(key)
caller_in_room = caller and caller.membership == Membership.JOIN caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE caller_invited = caller and caller.membership == Membership.INVITE
# get info about the target # get info about the target
key = (RoomMemberEvent.TYPE, target_user_id, ) key = (EventTypes.Member, target_user_id, )
target = auth_events.get(key) target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN target_in_room = target and target.membership == Membership.JOIN
key = (RoomJoinRulesEvent.TYPE, "", ) key = (EventTypes.JoinRules, "", )
join_rule_event = auth_events.get(key) join_rule_event = auth_events.get(key)
if join_rule_event: if join_rule_event:
join_rule = join_rule_event.content.get( join_rule = join_rule_event.content.get(
...@@ -256,7 +254,7 @@ class Auth(object): ...@@ -256,7 +254,7 @@ class Auth(object):
return True return True
def _get_power_level_from_event_state(self, event, user_id, auth_events): def _get_power_level_from_event_state(self, event, user_id, auth_events):
key = (RoomPowerLevelsEvent.TYPE, "", ) key = (EventTypes.PowerLevels, "", )
power_level_event = auth_events.get(key) power_level_event = auth_events.get(key)
level = None level = None
if power_level_event: if power_level_event:
...@@ -264,7 +262,7 @@ class Auth(object): ...@@ -264,7 +262,7 @@ class Auth(object):
if not level: if not level:
level = power_level_event.content.get("users_default", 0) level = power_level_event.content.get("users_default", 0)
else: else:
key = (RoomCreateEvent.TYPE, "", ) key = (EventTypes.Create, "", )
create_event = auth_events.get(key) create_event = auth_events.get(key)
if (create_event is not None and if (create_event is not None and
create_event.content["creator"] == user_id): create_event.content["creator"] == user_id):
...@@ -273,7 +271,7 @@ class Auth(object): ...@@ -273,7 +271,7 @@ class Auth(object):
return level return level
def _get_ops_level_from_event_state(self, event, auth_events): def _get_ops_level_from_event_state(self, event, auth_events):
key = (RoomPowerLevelsEvent.TYPE, "", ) key = (EventTypes.PowerLevels, "", )
power_level_event = auth_events.get(key) power_level_event = auth_events.get(key)
if power_level_event: if power_level_event:
...@@ -351,29 +349,31 @@ class Auth(object): ...@@ -351,29 +349,31 @@ class Auth(object):
return self.store.is_server_admin(user) return self.store.is_server_admin(user)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_auth_events(self, event): def add_auth_events(self, builder, context):
if event.type == RoomCreateEvent.TYPE: yield run_on_reactor()
event.auth_events = []
if builder.type == EventTypes.Create:
builder.auth_events = []
return return
auth_events = [] auth_ids = []
key = (RoomPowerLevelsEvent.TYPE, "", ) key = (EventTypes.PowerLevels, "", )
power_level_event = event.old_state_events.get(key) power_level_event = context.current_state.get(key)
if power_level_event: if power_level_event:
auth_events.append(power_level_event.event_id) auth_ids.append(power_level_event.event_id)
key = (RoomJoinRulesEvent.TYPE, "", ) key = (EventTypes.JoinRules, "", )
join_rule_event = event.old_state_events.get(key) join_rule_event = context.current_state.get(key)
key = (RoomMemberEvent.TYPE, event.user_id, ) key = (EventTypes.Member, builder.user_id, )
member_event = event.old_state_events.get(key) member_event = context.current_state.get(key)
key = (RoomCreateEvent.TYPE, "", ) key = (EventTypes.Create, "", )
create_event = event.old_state_events.get(key) create_event = context.current_state.get(key)
if create_event: if create_event:
auth_events.append(create_event.event_id) auth_ids.append(create_event.event_id)
if join_rule_event: if join_rule_event:
join_rule = join_rule_event.content.get("join_rule") join_rule = join_rule_event.content.get("join_rule")
...@@ -381,33 +381,37 @@ class Auth(object): ...@@ -381,33 +381,37 @@ class Auth(object):
else: else:
is_public = False is_public = False
if event.type == RoomMemberEvent.TYPE: if builder.type == EventTypes.Member:
e_type = event.content["membership"] e_type = builder.content["membership"]
if e_type in [Membership.JOIN, Membership.INVITE]: if e_type in [Membership.JOIN, Membership.INVITE]:
if join_rule_event: if join_rule_event:
auth_events.append(join_rule_event.event_id) auth_ids.append(join_rule_event.event_id)
if e_type == Membership.JOIN:
if member_event and not is_public: if member_event and not is_public:
auth_events.append(member_event.event_id) auth_ids.append(member_event.event_id)
else:
if member_event:
auth_ids.append(member_event.event_id)
elif member_event: elif member_event:
if member_event.content["membership"] == Membership.JOIN: if member_event.content["membership"] == Membership.JOIN:
auth_events.append(member_event.event_id) auth_ids.append(member_event.event_id)
hashes = yield self.store.get_event_reference_hashes( auth_events_entries = yield self.store.add_event_hashes(
auth_events auth_ids
) )
hashes = [
{ builder.auth_events = auth_events_entries
k: encode_base64(v) for k, v in h.items()
if k == "sha256" context.auth_events = {
} k: v
for h in hashes for k, v in context.current_state.items()
] if v.event_id in auth_ids
event.auth_events = zip(auth_events, hashes) }
@log_function @log_function
def _can_send_event(self, event, auth_events): def _can_send_event(self, event, auth_events):
key = (RoomPowerLevelsEvent.TYPE, "", ) key = (EventTypes.PowerLevels, "", )
send_level_event = auth_events.get(key) send_level_event = auth_events.get(key)
send_level = None send_level = None
if send_level_event: if send_level_event:
......
...@@ -59,3 +59,18 @@ class LoginType(object): ...@@ -59,3 +59,18 @@ class LoginType(object):
EMAIL_URL = u"m.login.email.url" EMAIL_URL = u"m.login.email.url"
EMAIL_IDENTITY = u"m.login.email.identity" EMAIL_IDENTITY = u"m.login.email.identity"
RECAPTCHA = u"m.login.recaptcha" RECAPTCHA = u"m.login.recaptcha"
class EventTypes(object):
Member = "m.room.member"
Create = "m.room.create"
JoinRules = "m.room.join_rules"
PowerLevels = "m.room.power_levels"
Aliases = "m.room.aliases"
Redaction = "m.room.redaction"
Feedback = "m.room.message.feedback"
# These are used for validation
Message = "m.room.message"
Topic = "m.room.topic"
Name = "m.room.name"
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.jsonobject import JsonEncodedObject
def serialize_event(hs, e):
# FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, SynapseEvent):
return e
# Should this strip out None's?
d = {k: v for k, v in e.get_dict().items()}
if "age_ts" in d:
d["age"] = int(hs.get_clock().time_msec()) - d["age_ts"]
del d["age_ts"]
return d
class SynapseEvent(JsonEncodedObject):
"""Base class for Synapse events. These are JSON objects which must abide
by a certain well-defined structure.
"""
# Attributes that are currently assumed by the federation side:
# Mandatory:
# - event_id
# - room_id
# - type
# - is_state
#
# Optional:
# - state_key (mandatory when is_state is True)
# - prev_events (these can be filled out by the federation layer itself.)
# - prev_state
valid_keys = [
"event_id",
"type",
"room_id",
"user_id", # sender/initiator
"content", # HTTP body, JSON
"state_key",
"age_ts",
"prev_content",
"replaces_state",
"redacted_because",
"origin_server_ts",
]
internal_keys = [
"is_state",
"depth",
"destinations",
"origin",
"outlier",
"redacted",
"prev_events",
"hashes",
"signatures",
"prev_state",
"auth_events",
"state_hash",
]
required_keys = [
"event_id",
"room_id",
"content",
]
outlier = False
def __init__(self, raises=True, **kwargs):
super(SynapseEvent, self).__init__(**kwargs)
# if "content" in kwargs:
# self.check_json(self.content, raises=raises)
def get_content_template(self):
""" Retrieve the JSON template for this event as a dict.
The template must be a dict representing the JSON to match. Only
required keys should be present. The values of the keys in the template
are checked via type() to the values of the same keys in the actual
event JSON.
NB: If loading content via json.loads, you MUST define strings as
unicode.
For example:
Content:
{
"name": u"bob",
"age": 18,
"friends": [u"mike", u"jill"]
}
Template:
{
"name": u"string",
"age": 0,
"friends": [u"string"]
}
The values "string" and 0 could be anything, so long as the types
are the same as the content.
"""
raise NotImplementedError("get_content_template not implemented.")
def get_pdu_json(self, time_now=None):
pdu_json = self.get_full_dict()
pdu_json.pop("destinations", None)
pdu_json.pop("outlier", None)
pdu_json.pop("replaces_state", None)
pdu_json.pop("redacted", None)
pdu_json.pop("prev_content", None)
state_hash = pdu_json.pop("state_hash", None)
if state_hash is not None:
pdu_json.setdefault("unsigned", {})["state_hash"] = state_hash
content = pdu_json.get("content", {})
content.pop("prev", None)
if time_now is not None and "age_ts" in pdu_json:
age = time_now - pdu_json["age_ts"]
pdu_json.setdefault("unsigned", {})["age"] = int(age)
del pdu_json["age_ts"]
user_id = pdu_json.pop("user_id")
pdu_json["sender"] = user_id
return pdu_json
class SynapseStateEvent(SynapseEvent):
def __init__(self, **kwargs):
if "state_key" not in kwargs:
kwargs["state_key"] = ""
super(SynapseStateEvent, self).__init__(**kwargs)
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.constants import Feedback, Membership
from synapse.api.errors import SynapseError
from . import SynapseEvent, SynapseStateEvent
class GenericEvent(SynapseEvent):
def get_content_template(self):
return {}
class RoomTopicEvent(SynapseEvent):
TYPE = "m.room.topic"
internal_keys = SynapseEvent.internal_keys + [
"topic",
]
def __init__(self, **kwargs):
kwargs["state_key"] = ""
if "topic" in kwargs["content"]:
kwargs["topic"] = kwargs["content"]["topic"]
super(RoomTopicEvent, self).__init__(**kwargs)
def get_content_template(self):
return {"topic": u"string"}
class RoomNameEvent(SynapseEvent):
TYPE = "m.room.name"
internal_keys = SynapseEvent.internal_keys + [
"name",
]
def __init__(self, **kwargs):
kwargs["state_key"] = ""
if "name" in kwargs["content"]:
kwargs["name"] = kwargs["content"]["name"]
super(RoomNameEvent, self).__init__(**kwargs)
def get_content_template(self):
return {"name": u"string"}
class RoomMemberEvent(SynapseEvent):
TYPE = "m.room.member"
valid_keys = SynapseEvent.valid_keys + [
# target is the state_key
"membership", # action
]
def __init__(self, **kwargs):
if "membership" not in kwargs:
kwargs["membership"] = kwargs.get("content", {}).get("membership")
if not kwargs["membership"] in Membership.LIST:
raise SynapseError(400, "Bad membership value.")
super(RoomMemberEvent, self).__init__(**kwargs)
def get_content_template(self):
return {"membership": u"string"}
class MessageEvent(SynapseEvent):
TYPE = "m.room.message"
valid_keys = SynapseEvent.valid_keys + [
"msg_id", # unique per room + user combo
]
def __init__(self, **kwargs):
super(MessageEvent, self).__init__(**kwargs)
def get_content_template(self):
return {"msgtype": u"string"}
class FeedbackEvent(SynapseEvent):
TYPE = "m.room.message.feedback"
valid_keys = SynapseEvent.valid_keys
def __init__(self, **kwargs):
super(FeedbackEvent, self).__init__(**kwargs)
if not kwargs["content"]["type"] in Feedback.LIST:
raise SynapseError(400, "Bad feedback value.")
def get_content_template(self):
return {
"type": u"string",
"target_event_id": u"string"
}
class InviteJoinEvent(SynapseEvent):
TYPE = "m.room.invite_join"
valid_keys = SynapseEvent.valid_keys + [
# target_user_id is the state_key
"target_host",
]
def __init__(self, **kwargs):
super(InviteJoinEvent, self).__init__(**kwargs)
def get_content_template(self):
return {}
class RoomConfigEvent(SynapseEvent):
TYPE = "m.room.config"
def __init__(self, **kwargs):
kwargs["state_key"] = ""
super(RoomConfigEvent, self).__init__(**kwargs)
def get_content_template(self):
return {}
class RoomCreateEvent(SynapseStateEvent):
TYPE = "m.room.create"
def get_content_template(self):
return {}
class RoomJoinRulesEvent(SynapseStateEvent):
TYPE = "m.room.join_rules"
def get_content_template(self):
return {}
class RoomPowerLevelsEvent(SynapseStateEvent):
TYPE = "m.room.power_levels"
def get_content_template(self):
return {}
class RoomAliasesEvent(SynapseStateEvent):
TYPE = "m.room.aliases"
def get_content_template(self):
return {}
class RoomRedactionEvent(SynapseEvent):
TYPE = "m.room.redaction"
valid_keys = SynapseEvent.valid_keys + ["redacts"]
def get_content_template(self):
return {}
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.errors import SynapseError, Codes
class EventValidator(object):
def __init__(self, hs):
pass
def validate(self, event):
"""Checks the given JSON content abides by the rules of the template.
Args:
content : A JSON object to check.
raises: True to raise a SynapseError if the check fails.
Returns:
True if the content passes the template. Returns False if the check
fails and raises=False.
Raises:
SynapseError if the check fails and raises=True.
"""
# recursively call to inspect each layer
err_msg = self._check_json_template(
event.content,
event.get_content_template()
)
if err_msg:
raise SynapseError(400, err_msg, Codes.BAD_JSON)
else:
return True
def _check_json_template(self, content, template):
"""Check content and template matches.
If the template is a dict, each key in the dict will be validated with
the content, else it will just compare the types of content and
template. This basic type check is required because this function will
be recursively called and could be called with just strs or ints.
Args:
content: The content to validate.
template: The validation template.
Returns:
str: An error message if the validation fails, else None.
"""
if type(content) != type(template):
return "Mismatched types: %s" % template
if type(template) == dict:
for key in template:
if key not in content:
return "Missing %s key" % key
if type(content[key]) != type(template[key]):
return "Key %s is of the wrong type (got %s, want %s)" % (
key, type(content[key]), type(template[key]))
if type(content[key]) == dict:
# we must go deeper
msg = self._check_json_template(
content[key],
template[key]
)
if msg:
return msg
elif type(content[key]) == list:
# make sure each item type in content matches the template
for entry in content[key]:
msg = self._check_json_template(
entry,
template[key][0]
)
if msg:
return msg
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
from synapse.api.events.utils import prune_event from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
from syutil.base64util import encode_base64, decode_base64 from syutil.base64util import encode_base64, decode_base64
from syutil.crypto.jsonsign import sign_json from syutil.crypto.jsonsign import sign_json
...@@ -29,17 +29,17 @@ logger = logging.getLogger(__name__) ...@@ -29,17 +29,17 @@ logger = logging.getLogger(__name__)
def check_event_content_hash(event, hash_algorithm=hashlib.sha256): def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
"""Check whether the hash for this PDU matches the contents""" """Check whether the hash for this PDU matches the contents"""
computed_hash = _compute_content_hash(event, hash_algorithm) name, expected_hash = compute_content_hash(event, hash_algorithm)
logger.debug("Expecting hash: %s", encode_base64(computed_hash.digest())) logger.debug("Expecting hash: %s", encode_base64(expected_hash))
if computed_hash.name not in event.hashes: if name not in event.hashes:
raise SynapseError( raise SynapseError(
400, 400,
"Algorithm %s not in hashes %s" % ( "Algorithm %s not in hashes %s" % (
computed_hash.name, list(event.hashes), name, list(event.hashes),
), ),
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
message_hash_base64 = event.hashes[computed_hash.name] message_hash_base64 = event.hashes[name]
try: try:
message_hash_bytes = decode_base64(message_hash_base64) message_hash_bytes = decode_base64(message_hash_base64)
except: except:
...@@ -48,10 +48,10 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256): ...@@ -48,10 +48,10 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
"Invalid base64: %s" % (message_hash_base64,), "Invalid base64: %s" % (message_hash_base64,),
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
return message_hash_bytes == computed_hash.digest() return message_hash_bytes == expected_hash
def _compute_content_hash(event, hash_algorithm): def compute_content_hash(event, hash_algorithm):
event_json = event.get_pdu_json() event_json = event.get_pdu_json()
event_json.pop("age_ts", None) event_json.pop("age_ts", None)
event_json.pop("unsigned", None) event_json.pop("unsigned", None)
...@@ -59,8 +59,11 @@ def _compute_content_hash(event, hash_algorithm): ...@@ -59,8 +59,11 @@ def _compute_content_hash(event, hash_algorithm):
event_json.pop("hashes", None) event_json.pop("hashes", None)
event_json.pop("outlier", None) event_json.pop("outlier", None)
event_json.pop("destinations", None) event_json.pop("destinations", None)
event_json_bytes = encode_canonical_json(event_json) event_json_bytes = encode_canonical_json(event_json)
return hash_algorithm(event_json_bytes)
hashed = hash_algorithm(event_json_bytes)
return (hashed.name, hashed.digest())
def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256): def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
...@@ -79,27 +82,28 @@ def compute_event_signature(event, signature_name, signing_key): ...@@ -79,27 +82,28 @@ def compute_event_signature(event, signature_name, signing_key):
redact_json = tmp_event.get_pdu_json() redact_json = tmp_event.get_pdu_json()
redact_json.pop("age_ts", None) redact_json.pop("age_ts", None)
redact_json.pop("unsigned", None) redact_json.pop("unsigned", None)
logger.debug("Signing event: %s", redact_json) logger.debug("Signing event: %s", encode_canonical_json(redact_json))
redact_json = sign_json(redact_json, signature_name, signing_key) redact_json = sign_json(redact_json, signature_name, signing_key)
logger.debug("Signed event: %s", encode_canonical_json(redact_json))
return redact_json["signatures"] return redact_json["signatures"]
def add_hashes_and_signatures(event, signature_name, signing_key, def add_hashes_and_signatures(event, signature_name, signing_key,
hash_algorithm=hashlib.sha256): hash_algorithm=hashlib.sha256):
if hasattr(event, "old_state_events"): # if hasattr(event, "old_state_events"):
state_json_bytes = encode_canonical_json( # state_json_bytes = encode_canonical_json(
[e.event_id for e in event.old_state_events.values()] # [e.event_id for e in event.old_state_events.values()]
) # )
hashed = hash_algorithm(state_json_bytes) # hashed = hash_algorithm(state_json_bytes)
event.state_hash = { # event.state_hash = {
hashed.name: encode_base64(hashed.digest()) # hashed.name: encode_base64(hashed.digest())
} # }
hashed = _compute_content_hash(event, hash_algorithm=hash_algorithm) name, digest = compute_content_hash(event, hash_algorithm=hash_algorithm)
if not hasattr(event, "hashes"): if not hasattr(event, "hashes"):
event.hashes = {} event.hashes = {}
event.hashes[hashed.name] = encode_base64(hashed.digest()) event.hashes[name] = encode_base64(digest)
event.signatures = compute_event_signature( event.signatures = compute_event_signature(
event, event,
......
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.frozenutils import freeze, unfreeze
import copy
class _EventInternalMetadata(object):
def __init__(self, internal_metadata_dict):
self.__dict__ = copy.deepcopy(internal_metadata_dict)
def get_dict(self):
return dict(self.__dict__)
def is_outlier(self):
return hasattr(self, "outlier") and self.outlier
def _event_dict_property(key):
def getter(self):
return self._event_dict[key]
def setter(self, v):
self._event_dict[key] = v
def delete(self):
del self._event_dict[key]
return property(
getter,
setter,
delete,
)
class EventBase(object):
def __init__(self, event_dict, signatures={}, unsigned={},
internal_metadata_dict={}):
self.signatures = copy.deepcopy(signatures)
self.unsigned = copy.deepcopy(unsigned)
self._event_dict = copy.deepcopy(event_dict)
self.internal_metadata = _EventInternalMetadata(
internal_metadata_dict
)
auth_events = _event_dict_property("auth_events")
depth = _event_dict_property("depth")
content = _event_dict_property("content")
event_id = _event_dict_property("event_id")
hashes = _event_dict_property("hashes")
origin = _event_dict_property("origin")
origin_server_ts = _event_dict_property("origin_server_ts")
prev_events = _event_dict_property("prev_events")
prev_state = _event_dict_property("prev_state")
redacts = _event_dict_property("redacts")
room_id = _event_dict_property("room_id")
sender = _event_dict_property("sender")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
user_id = _event_dict_property("sender")
@property
def membership(self):
return self.content["membership"]
def is_state(self):
return hasattr(self, "state_key")
def get_dict(self):
d = dict(self._event_dict)
d.update({
"signatures": self.signatures,
"unsigned": self.unsigned,
})
return d
def get(self, key, default):
return self._event_dict.get(key, default)
def get_internal_metadata_dict(self):
return self.internal_metadata.get_dict()
def get_pdu_json(self, time_now=None):
pdu_json = self.get_dict()
if time_now is not None and "age_ts" in pdu_json["unsigned"]:
age = time_now - pdu_json["unsigned"]["age_ts"]
pdu_json.setdefault("unsigned", {})["age"] = int(age)
del pdu_json["unsigned"]["age_ts"]
return pdu_json
def __set__(self, instance, value):
raise AttributeError("Unrecognized attribute %s" % (instance,))
class FrozenEvent(EventBase):
def __init__(self, event_dict, internal_metadata_dict={}):
event_dict = copy.deepcopy(event_dict)
signatures = copy.deepcopy(event_dict.pop("signatures", {}))
unsigned = copy.deepcopy(event_dict.pop("unsigned", {}))
frozen_dict = freeze(event_dict)
super(FrozenEvent, self).__init__(
frozen_dict,
signatures=signatures,
unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict,
)
@staticmethod
def from_event(event):
e = FrozenEvent(
event.get_pdu_json()
)
e.internal_metadata = event.internal_metadata
return e
def get_dict(self):
# We need to unfreeze what we return
return unfreeze(super(FrozenEvent, self).get_dict())
def __str__(self):
return self.__repr__()
def __repr__(self):
return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % (
self.event_id, self.type, self.get("state_key", None),
)
...@@ -13,42 +13,40 @@ ...@@ -13,42 +13,40 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.api.events.room import ( from . import EventBase, FrozenEvent
RoomTopicEvent, MessageEvent, RoomMemberEvent, FeedbackEvent,
InviteJoinEvent, RoomConfigEvent, RoomNameEvent, GenericEvent,
RoomPowerLevelsEvent, RoomJoinRulesEvent,
RoomCreateEvent,
RoomRedactionEvent,
)
from synapse.types import EventID from synapse.types import EventID
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
import copy
class EventFactory(object):
_event_classes = [ class EventBuilder(EventBase):
RoomTopicEvent, def __init__(self, key_values={}):
RoomNameEvent, signatures = copy.deepcopy(key_values.pop("signatures", {}))
MessageEvent, unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
RoomMemberEvent,
FeedbackEvent,
InviteJoinEvent,
RoomConfigEvent,
RoomPowerLevelsEvent,
RoomJoinRulesEvent,
RoomCreateEvent,
RoomRedactionEvent,
]
def __init__(self, hs): super(EventBuilder, self).__init__(
self._event_list = {} # dict of TYPE to event class key_values,
for event_class in EventFactory._event_classes: signatures=signatures,
self._event_list[event_class.TYPE] = event_class unsigned=unsigned
)
self.clock = hs.get_clock() def update_event_key(self, key, value):
self.hs = hs self._event_dict[key] = value
def update_event_keys(self, other_dict):
self._event_dict.update(other_dict)
def build(self):
return FrozenEvent.from_event(self)
class EventBuilderFactory(object):
def __init__(self, clock, hostname):
self.clock = clock
self.hostname = hostname
self.event_id_count = 0 self.event_id_count = 0
...@@ -58,33 +56,22 @@ class EventFactory(object): ...@@ -58,33 +56,22 @@ class EventFactory(object):
local_part = str(int(self.clock.time())) + i + random_string(5) local_part = str(int(self.clock.time())) + i + random_string(5)
e_id = EventID.create_local(local_part, self.hs) e_id = EventID.create(local_part, self.hostname)
return e_id.to_string() return e_id.to_string()
def create_event(self, etype=None, **kwargs): def new(self, key_values={}):
kwargs["type"] = etype key_values["event_id"] = self.create_event_id()
if "event_id" not in kwargs:
kwargs["event_id"] = self.create_event_id() time_now = int(self.clock.time_msec())
kwargs["origin"] = self.hs.hostname
else: key_values.setdefault("origin", self.hostname)
ev_id = self.hs.parse_eventid(kwargs["event_id"]) key_values.setdefault("origin_server_ts", time_now)
kwargs["origin"] = ev_id.domain
key_values.setdefault("unsigned", {})
if "origin_server_ts" not in kwargs: age = key_values["unsigned"].pop("age", 0)
kwargs["origin_server_ts"] = int(self.clock.time_msec()) key_values["unsigned"].setdefault("age_ts", time_now - age)
# The "age" key is a delta timestamp that should be converted into an key_values["signatures"] = {}
# absolute timestamp the minute we see it.
if "age" in kwargs: return EventBuilder(key_values=key_values,)
kwargs["age_ts"] = int(self.clock.time_msec()) - int(kwargs["age"])
del kwargs["age"]
elif "age_ts" not in kwargs:
kwargs["age_ts"] = int(self.clock.time_msec())
if etype in self._event_list:
handler = self._event_list[etype]
else:
handler = GenericEvent
return handler(**kwargs)
...@@ -13,3 +13,10 @@ ...@@ -13,3 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
class EventContext(object):
def __init__(self, current_state=None, auth_events=None):
self.current_state = current_state
self.auth_events = auth_events
self.state_group = None
...@@ -13,10 +13,8 @@ ...@@ -13,10 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .room import ( from synapse.api.constants import EventTypes
RoomMemberEvent, RoomJoinRulesEvent, RoomPowerLevelsEvent, from . import EventBase
RoomAliasesEvent, RoomCreateEvent,
)
def prune_event(event): def prune_event(event):
...@@ -31,7 +29,7 @@ def prune_event(event): ...@@ -31,7 +29,7 @@ def prune_event(event):
allowed_keys = [ allowed_keys = [
"event_id", "event_id",
"user_id", "sender",
"room_id", "room_id",
"hashes", "hashes",
"signatures", "signatures",
...@@ -44,6 +42,7 @@ def prune_event(event): ...@@ -44,6 +42,7 @@ def prune_event(event):
"auth_events", "auth_events",
"origin", "origin",
"origin_server_ts", "origin_server_ts",
"membership",
] ]
new_content = {} new_content = {}
...@@ -53,13 +52,13 @@ def prune_event(event): ...@@ -53,13 +52,13 @@ def prune_event(event):
if field in event.content: if field in event.content:
new_content[field] = event.content[field] new_content[field] = event.content[field]
if event_type == RoomMemberEvent.TYPE: if event_type == EventTypes.Member:
add_fields("membership") add_fields("membership")
elif event_type == RoomCreateEvent.TYPE: elif event_type == EventTypes.Create:
add_fields("creator") add_fields("creator")
elif event_type == RoomJoinRulesEvent.TYPE: elif event_type == EventTypes.JoinRules:
add_fields("join_rule") add_fields("join_rule")
elif event_type == RoomPowerLevelsEvent.TYPE: elif event_type == EventTypes.PowerLevels:
add_fields( add_fields(
"users", "users",
"users_default", "users_default",
...@@ -71,15 +70,61 @@ def prune_event(event): ...@@ -71,15 +70,61 @@ def prune_event(event):
"kick", "kick",
"redact", "redact",
) )
elif event_type == RoomAliasesEvent.TYPE: elif event_type == EventTypes.Aliases:
add_fields("aliases") add_fields("aliases")
allowed_fields = { allowed_fields = {
k: v k: v
for k, v in event.get_full_dict().items() for k, v in event.get_dict().items()
if k in allowed_keys if k in allowed_keys
} }
allowed_fields["content"] = new_content allowed_fields["content"] = new_content
return type(event)(**allowed_fields) allowed_fields["unsigned"] = {}
if "age_ts" in event.unsigned:
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
return type(event)(allowed_fields)
def serialize_event(hs, e):
# FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase):
return e
# Should this strip out None's?
d = {k: v for k, v in e.get_dict().items()}
if "age_ts" in d["unsigned"]:
now = int(hs.get_clock().time_msec())
d["unsigned"]["age"] = now - d["unsigned"]["age_ts"]
del d["unsigned"]["age_ts"]
d["user_id"] = d.pop("sender", None)
if "redacted_because" in e.unsigned:
d["redacted_because"] = serialize_event(
hs, e.unsigned["redacted_because"]
)
del d["unsigned"]["redacted_because"]
if "redacted_by" in e.unsigned:
d["redacted_by"] = e.unsigned["redacted_by"]
del d["unsigned"]["redacted_by"]
if "replaces_state" in e.unsigned:
d["replaces_state"] = e.unsigned["replaces_state"]
del d["unsigned"]["replaces_state"]
if "prev_content" in e.unsigned:
d["prev_content"] = e.unsigned["prev_content"]
del d["unsigned"]["prev_content"]
del d["auth_events"]
del d["prev_events"]
del d["hashes"]
del d["signatures"]
return d
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.types import EventID, RoomID, UserID
from synapse.api.errors import SynapseError
from synapse.api.constants import EventTypes, Membership
class EventValidator(object):
def validate(self, event):
EventID.from_string(event.event_id)
RoomID.from_string(event.room_id)
required = [
# "auth_events",
"content",
# "hashes",
"origin",
# "prev_events",
"sender",
"type",
]
for k in required:
if not hasattr(event, k):
raise SynapseError(400, "Event does not have key %s" % (k,))
# Check that the following keys have string values
strings = [
"origin",
"sender",
"type",
]
if hasattr(event, "state_key"):
strings.append("state_key")
for s in strings:
if not isinstance(getattr(event, s), basestring):
raise SynapseError(400, "Not '%s' a string type" % (s,))
if event.type == EventTypes.Member:
if "membership" not in event.content:
raise SynapseError(400, "Content has not membership key")
if event.content["membership"] not in Membership.LIST:
raise SynapseError(400, "Invalid membership key")
# Check that the following keys have dictionary values
# TODO
# Check that the following keys have the correct format for DAGs
# TODO
def validate_new(self, event):
self.validate(event)
UserID.from_string(event.sender)
if event.type == EventTypes.Message:
strings = [
"body",
"msgtype",
]
self._ensure_strings(event.content, strings)
elif event.type == EventTypes.Topic:
self._ensure_strings(event.content, ["topic"])
elif event.type == EventTypes.Name:
self._ensure_strings(event.content, ["name"])
def _ensure_strings(self, d, keys):
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
if not isinstance(d[s], basestring):
raise SynapseError(400, "Not '%s' a string type" % (s,))
...@@ -25,6 +25,7 @@ from .persistence import TransactionActions ...@@ -25,6 +25,7 @@ from .persistence import TransactionActions
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent
import logging import logging
...@@ -73,7 +74,7 @@ class ReplicationLayer(object): ...@@ -73,7 +74,7 @@ class ReplicationLayer(object):
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.event_factory = hs.get_event_factory() self.event_builder_factory = hs.get_event_builder_factory()
def set_handler(self, handler): def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate """Sets the handler that the replication layer will use to communicate
...@@ -112,7 +113,7 @@ class ReplicationLayer(object): ...@@ -112,7 +113,7 @@ class ReplicationLayer(object):
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
@log_function @log_function
def send_pdu(self, pdu): def send_pdu(self, pdu, destinations):
"""Informs the replication layer about a new PDU generated within the """Informs the replication layer about a new PDU generated within the
home server that should be transmitted to others. home server that should be transmitted to others.
...@@ -131,7 +132,7 @@ class ReplicationLayer(object): ...@@ -131,7 +132,7 @@ class ReplicationLayer(object):
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id) logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# TODO, add errback, etc. # TODO, add errback, etc.
self._transaction_queue.enqueue_pdu(pdu, order) self._transaction_queue.enqueue_pdu(pdu, destinations, order)
logger.debug( logger.debug(
"[%s] transaction_layer.enqueue_pdu... done", "[%s] transaction_layer.enqueue_pdu... done",
...@@ -438,7 +439,9 @@ class ReplicationLayer(object): ...@@ -438,7 +439,9 @@ class ReplicationLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_send_join_request(self, origin, content): def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content)
pdu = self.event_from_pdu_json(content) pdu = self.event_from_pdu_json(content)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu) res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
defer.returnValue((200, { defer.returnValue((200, {
...@@ -557,7 +560,13 @@ class ReplicationLayer(object): ...@@ -557,7 +560,13 @@ class ReplicationLayer(object):
origin, pdu.event_id, do_auth=False origin, pdu.event_id, do_auth=False
) )
if existing and (not existing.outlier or pdu.outlier): already_seen = (
existing and (
not existing.internal_metadata.outlier
or pdu.internal_metadata.outlier
)
)
if already_seen:
logger.debug("Already seen pdu %s", pdu.event_id) logger.debug("Already seen pdu %s", pdu.event_id)
defer.returnValue({}) defer.returnValue({})
return return
...@@ -595,7 +604,7 @@ class ReplicationLayer(object): ...@@ -595,7 +604,7 @@ class ReplicationLayer(object):
# ) # )
# Get missing pdus if necessary. # Get missing pdus if necessary.
if not pdu.outlier: if not pdu.internal_metadata.outlier:
# We only backfill backwards to the min depth. # We only backfill backwards to the min depth.
min_depth = yield self.handler.get_min_depth_for_context( min_depth = yield self.handler.get_min_depth_for_context(
pdu.room_id pdu.room_id
...@@ -658,19 +667,14 @@ class ReplicationLayer(object): ...@@ -658,19 +667,14 @@ class ReplicationLayer(object):
return "<ReplicationLayer(%s)>" % self.server_name return "<ReplicationLayer(%s)>" % self.server_name
def event_from_pdu_json(self, pdu_json, outlier=False): def event_from_pdu_json(self, pdu_json, outlier=False):
#TODO: Check we have all the PDU keys here event = FrozenEvent(
pdu_json.setdefault("hashes", {}) pdu_json
pdu_json.setdefault("signatures", {})
sender = pdu_json.pop("sender", None)
if sender is not None:
pdu_json["user_id"] = sender
state_hash = pdu_json.get("unsigned", {}).pop("state_hash", None)
if state_hash is not None:
pdu_json["state_hash"] = state_hash
return self.event_factory.create_event(
pdu_json["type"], outlier=outlier, **pdu_json
) )
event.internal_metadata.outlier = outlier
return event
class _TransactionQueue(object): class _TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at """This class makes sure we only have one transaction in flight at
...@@ -706,15 +710,13 @@ class _TransactionQueue(object): ...@@ -706,15 +710,13 @@ class _TransactionQueue(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def enqueue_pdu(self, pdu, order): def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have # We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus # a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later. # table and we'll get back to it later.
destinations = set([ destinations = set(destinations)
d for d in pdu.destinations destinations.discard(self.server_name)
if d != self.server_name
])
logger.debug("Sending to: %s", str(destinations)) logger.debug("Sending to: %s", str(destinations))
......
...@@ -15,11 +15,12 @@ ...@@ -15,11 +15,12 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import LimitExceededError from synapse.api.errors import LimitExceededError, SynapseError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.events.room import RoomMemberEvent from synapse.api.constants import Membership, EventTypes
from synapse.api.constants import Membership
from synapse.events.snapshot import EventContext
import logging import logging
...@@ -31,10 +32,8 @@ class BaseHandler(object): ...@@ -31,10 +32,8 @@ class BaseHandler(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_factory = hs.get_event_factory()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.room_lock = hs.get_room_lock_manager()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
...@@ -44,6 +43,8 @@ class BaseHandler(object): ...@@ -44,6 +43,8 @@ class BaseHandler(object):
self.signing_key = hs.config.signing_key[0] self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname self.server_name = hs.hostname
self.event_builder_factory = hs.get_event_builder_factory()
def ratelimit(self, user_id): def ratelimit(self, user_id):
time_now = self.clock.time() time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message( allowed, time_allowed = self.ratelimiter.send_message(
...@@ -57,62 +58,100 @@ class BaseHandler(object): ...@@ -57,62 +58,100 @@ class BaseHandler(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _on_new_room_event(self, event, snapshot, extra_destinations=[], def _create_new_client_event(self, builder):
extra_users=[], suppress_auth=False,
do_invite_host=None):
yield run_on_reactor() yield run_on_reactor()
snapshot.fill_out_prev_events(event) context = EventContext()
latest_ret = yield self.store.get_latest_events_in_room(
builder.room_id,
)
if latest_ret:
depth = max([d for _, _, d in latest_ret]) + 1
else:
depth = 1
prev_events = [(e, h) for e, h, _ in latest_ret]
builder.prev_events = prev_events
builder.depth = depth
yield self.state_handler.annotate_event_with_state(event) state_handler = self.state_handler
ret = yield state_handler.annotate_context_with_state(
builder,
context,
)
prev_state = ret
yield self.auth.add_auth_events(event) if builder.is_state():
builder.prev_state = prev_state
logger.debug("Signing event...") yield self.auth.add_auth_events(builder, context)
add_hashes_and_signatures( add_hashes_and_signatures(
event, self.server_name, self.signing_key builder, self.server_name, self.signing_key
)
event = builder.build()
logger.debug(
"Created event %s with auth_events: %s, current state: %s",
event.event_id, context.auth_events, context.current_state,
) )
logger.debug("Signed event.") defer.returnValue(
(event, context,)
)
@defer.inlineCallbacks
def handle_new_client_event(self, event, context, extra_destinations=[],
extra_users=[], suppress_auth=False):
yield run_on_reactor()
# We now need to go and hit out to wherever we need to hit out to.
if not suppress_auth: if not suppress_auth:
logger.debug("Authing...") self.auth.check(event, auth_events=context.auth_events)
self.auth.check(event, auth_events=event.old_state_events)
logger.debug("Authed")
else:
logger.debug("Suppressed auth.")
if do_invite_host: yield self.store.persist_event(event, context=context)
federation_handler = self.hs.get_handlers().federation_handler
invite_event = yield federation_handler.send_invite(
do_invite_host,
event
)
# FIXME: We need to check if the remote changed anything else federation_handler = self.hs.get_handlers().federation_handler
event.signatures = invite_event.signatures
yield self.store.persist_event(event) if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
invitee = self.hs.parse_userid(event.state_key)
if not self.hs.is_mine(invitee):
# TODO: Can we add signature from remote server in a nicer
# way? If we have been invited by a remote server, we need
# to get them to sign the event.
returned_invite = yield federation_handler.send_invite(
invitee.domain,
event,
)
# TODO: Make sure the signatures actually are correct.
event.signatures.update(
returned_invite.signatures
)
destinations = set(extra_destinations) destinations = set(extra_destinations)
# Send a PDU to all hosts who have joined the room. for k, s in context.current_state.items():
for k, s in event.state_events.items():
try: try:
if k[0] == RoomMemberEvent.TYPE: if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN: if s.content["membership"] == Membership.JOIN:
destinations.add( destinations.add(
self.hs.parse_userid(s.state_key).domain self.hs.parse_userid(s.state_key).domain
) )
except: except SynapseError:
logger.warn( logger.warn(
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id
) )
event.destinations = list(destinations)
yield self.notifier.on_new_room_event(event, extra_users=extra_users) yield self.notifier.on_new_room_event(event, extra_users=extra_users)
federation_handler = self.hs.get_handlers().federation_handler yield federation_handler.handle_new_event(
yield federation_handler.handle_new_event(event, snapshot) event,
None,
destinations=destinations,
)
...@@ -18,7 +18,7 @@ from twisted.internet import defer ...@@ -18,7 +18,7 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.errors import SynapseError, Codes, CodeMessageException from synapse.api.errors import SynapseError, Codes, CodeMessageException
from synapse.api.events.room import RoomAliasesEvent from synapse.api.constants import EventTypes
import logging import logging
...@@ -40,7 +40,7 @@ class DirectoryHandler(BaseHandler): ...@@ -40,7 +40,7 @@ class DirectoryHandler(BaseHandler):
# TODO(erikj): Do auth. # TODO(erikj): Do auth.
if not room_alias.is_mine: if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local") raise SynapseError(400, "Room alias must be local")
# TODO(erikj): Change this. # TODO(erikj): Change this.
...@@ -64,7 +64,7 @@ class DirectoryHandler(BaseHandler): ...@@ -64,7 +64,7 @@ class DirectoryHandler(BaseHandler):
def delete_association(self, user_id, room_alias): def delete_association(self, user_id, room_alias):
# TODO Check if server admin # TODO Check if server admin
if not room_alias.is_mine: if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local") raise SynapseError(400, "Room alias must be local")
room_id = yield self.store.delete_room_alias(room_alias) room_id = yield self.store.delete_room_alias(room_alias)
...@@ -75,7 +75,7 @@ class DirectoryHandler(BaseHandler): ...@@ -75,7 +75,7 @@ class DirectoryHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_association(self, room_alias): def get_association(self, room_alias):
room_id = None room_id = None
if room_alias.is_mine: if self.hs.is_mine(room_alias):
result = yield self.store.get_association_from_room_alias( result = yield self.store.get_association_from_room_alias(
room_alias room_alias
) )
...@@ -123,7 +123,7 @@ class DirectoryHandler(BaseHandler): ...@@ -123,7 +123,7 @@ class DirectoryHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_directory_query(self, args): def on_directory_query(self, args):
room_alias = self.hs.parse_roomalias(args["room_alias"]) room_alias = self.hs.parse_roomalias(args["room_alias"])
if not room_alias.is_mine: if not self.hs.is_mine(room_alias):
raise SynapseError( raise SynapseError(
400, "Room Alias is not hosted on this Home Server" 400, "Room Alias is not hosted on this Home Server"
) )
...@@ -148,16 +148,12 @@ class DirectoryHandler(BaseHandler): ...@@ -148,16 +148,12 @@ class DirectoryHandler(BaseHandler):
def send_room_alias_update_event(self, user_id, room_id): def send_room_alias_update_event(self, user_id, room_id):
aliases = yield self.store.get_aliases_for_room(room_id) aliases = yield self.store.get_aliases_for_room(room_id)
event = self.event_factory.create_event( msg_handler = self.hs.get_handlers().message_handler
etype=RoomAliasesEvent.TYPE, yield msg_handler.create_and_send_event({
state_key=self.hs.hostname, "type": EventTypes.Aliases,
room_id=room_id, "state_key": self.hs.hostname,
user_id=user_id, "room_id": room_id,
content={"aliases": aliases}, "sender": user_id,
) "content": {"aliases": aliases},
})
snapshot = yield self.store.snapshot_room(event)
yield self._on_new_room_event(
event, snapshot, extra_users=[user_id], suppress_auth=True
)
...@@ -17,12 +17,12 @@ ...@@ -17,12 +17,12 @@
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.events.utils import prune_event from synapse.events.snapshot import EventContext
from synapse.events.utils import prune_event
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, FederationError, SynapseError, StoreError, AuthError, FederationError, SynapseError, StoreError,
) )
from synapse.api.events.room import RoomMemberEvent, RoomCreateEvent from synapse.api.constants import EventTypes, Membership
from synapse.api.constants import Membership
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import ( from synapse.crypto.event_signing import (
...@@ -76,7 +76,7 @@ class FederationHandler(BaseHandler): ...@@ -76,7 +76,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_new_event(self, event, snapshot): def handle_new_event(self, event, snapshot, destinations):
""" Takes in an event from the client to server side, that has already """ Takes in an event from the client to server side, that has already
been authed and handled by the state module, and sends it to any been authed and handled by the state module, and sends it to any
remote home servers that may be interested. remote home servers that may be interested.
...@@ -92,12 +92,7 @@ class FederationHandler(BaseHandler): ...@@ -92,12 +92,7 @@ class FederationHandler(BaseHandler):
yield run_on_reactor() yield run_on_reactor()
pdu = event yield self.replication_layer.send_pdu(event, destinations)
if not hasattr(pdu, "destinations") or not pdu.destinations:
pdu.destinations = []
yield self.replication_layer.send_pdu(pdu)
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
...@@ -140,7 +135,7 @@ class FederationHandler(BaseHandler): ...@@ -140,7 +135,7 @@ class FederationHandler(BaseHandler):
if not check_event_content_hash(event): if not check_event_content_hash(event):
logger.warn( logger.warn(
"Event content has been tampered, redacting %s, %s", "Event content has been tampered, redacting %s, %s",
event.event_id, encode_canonical_json(event.get_full_dict()) event.event_id, encode_canonical_json(event.get_dict())
) )
event = redacted_event event = redacted_event
...@@ -153,7 +148,7 @@ class FederationHandler(BaseHandler): ...@@ -153,7 +148,7 @@ class FederationHandler(BaseHandler):
event.room_id, event.room_id,
self.server_name self.server_name
) )
if not is_in_room and not event.outlier: if not is_in_room and not event.internal_metadata.outlier:
logger.debug("Got event for room we're not in.") logger.debug("Got event for room we're not in.")
replication_layer = self.replication_layer replication_layer = self.replication_layer
...@@ -164,7 +159,7 @@ class FederationHandler(BaseHandler): ...@@ -164,7 +159,7 @@ class FederationHandler(BaseHandler):
) )
for e in auth_chain: for e in auth_chain:
e.outlier = True e.internal_metadata.outlier = True
try: try:
yield self._handle_new_event(e, fetch_missing=False) yield self._handle_new_event(e, fetch_missing=False)
except: except:
...@@ -184,7 +179,7 @@ class FederationHandler(BaseHandler): ...@@ -184,7 +179,7 @@ class FederationHandler(BaseHandler):
if state: if state:
for e in state: for e in state:
e.outlier = True e.internal_metadata.outlier = True
try: try:
yield self._handle_new_event(e) yield self._handle_new_event(e)
except: except:
...@@ -229,7 +224,7 @@ class FederationHandler(BaseHandler): ...@@ -229,7 +224,7 @@ class FederationHandler(BaseHandler):
if not backfilled: if not backfilled:
extra_users = [] extra_users = []
if event.type == RoomMemberEvent.TYPE: if event.type == EventTypes.Member:
target_user_id = event.state_key target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id) target_user = self.hs.parse_userid(target_user_id)
extra_users.append(target_user) extra_users.append(target_user)
...@@ -238,7 +233,7 @@ class FederationHandler(BaseHandler): ...@@ -238,7 +233,7 @@ class FederationHandler(BaseHandler):
event, extra_users=extra_users event, extra_users=extra_users
) )
if event.type == RoomMemberEvent.TYPE: if event.type == EventTypes.Member:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
user = self.hs.parse_userid(event.state_key) user = self.hs.parse_userid(event.state_key)
yield self.distributor.fire( yield self.distributor.fire(
...@@ -265,11 +260,18 @@ class FederationHandler(BaseHandler): ...@@ -265,11 +260,18 @@ class FederationHandler(BaseHandler):
event = pdu event = pdu
# FIXME (erikj): Not sure this actually works :/ # FIXME (erikj): Not sure this actually works :/
yield self.state_handler.annotate_event_with_state(event) context = EventContext()
yield self.state_handler.annotate_context_with_state(event, context)
events.append(event) events.append(
(event, context)
)
yield self.store.persist_event(event, backfilled=True) yield self.store.persist_event(
event,
context=context,
backfilled=True
)
defer.returnValue(events) defer.returnValue(events)
...@@ -286,8 +288,6 @@ class FederationHandler(BaseHandler): ...@@ -286,8 +288,6 @@ class FederationHandler(BaseHandler):
pdu=event pdu=event
) )
defer.returnValue(pdu) defer.returnValue(pdu)
@defer.inlineCallbacks @defer.inlineCallbacks
...@@ -332,42 +332,55 @@ class FederationHandler(BaseHandler): ...@@ -332,42 +332,55 @@ class FederationHandler(BaseHandler):
event = pdu event = pdu
# We should assert some things. # We should assert some things.
assert(event.type == RoomMemberEvent.TYPE) # FIXME: Do this in a nicer way
assert(event.type == EventTypes.Member)
assert(event.user_id == joinee) assert(event.user_id == joinee)
assert(event.state_key == joinee) assert(event.state_key == joinee)
assert(event.room_id == room_id) assert(event.room_id == room_id)
event.outlier = False event.internal_metadata.outlier = False
self.room_queues[room_id] = [] self.room_queues[room_id] = []
builder = self.event_builder_factory.new(
event.get_pdu_json()
)
handled_events = set()
try: try:
event.event_id = self.event_factory.create_event_id() builder.event_id = self.event_builder_factory.create_event_id()
event.origin = self.hs.hostname builder.origin = self.hs.hostname
event.content = content builder.content = content
if not hasattr(event, "signatures"): if not hasattr(event, "signatures"):
event.signatures = {} builder.signatures = {}
add_hashes_and_signatures( add_hashes_and_signatures(
event, builder,
self.hs.hostname, self.hs.hostname,
self.hs.config.signing_key[0], self.hs.config.signing_key[0],
) )
new_event = builder.build()
ret = yield self.replication_layer.send_join( ret = yield self.replication_layer.send_join(
target_host, target_host,
event new_event
) )
state = ret["state"] state = ret["state"]
auth_chain = ret["auth_chain"] auth_chain = ret["auth_chain"]
auth_chain.sort(key=lambda e: e.depth) auth_chain.sort(key=lambda e: e.depth)
handled_events.update([s.event_id for s in state])
handled_events.update([a.event_id for a in auth_chain])
handled_events.add(new_event.event_id)
logger.debug("do_invite_join auth_chain: %s", auth_chain) logger.debug("do_invite_join auth_chain: %s", auth_chain)
logger.debug("do_invite_join state: %s", state) logger.debug("do_invite_join state: %s", state)
logger.debug("do_invite_join event: %s", event) logger.debug("do_invite_join event: %s", new_event)
try: try:
yield self.store.store_room( yield self.store.store_room(
...@@ -380,7 +393,7 @@ class FederationHandler(BaseHandler): ...@@ -380,7 +393,7 @@ class FederationHandler(BaseHandler):
pass pass
for e in auth_chain: for e in auth_chain:
e.outlier = True e.internal_metadata.outlier = True
try: try:
yield self._handle_new_event(e, fetch_missing=False) yield self._handle_new_event(e, fetch_missing=False)
except: except:
...@@ -391,7 +404,7 @@ class FederationHandler(BaseHandler): ...@@ -391,7 +404,7 @@ class FederationHandler(BaseHandler):
for e in state: for e in state:
# FIXME: Auth these. # FIXME: Auth these.
e.outlier = True e.internal_metadata.outlier = True
try: try:
yield self._handle_new_event( yield self._handle_new_event(
e, e,
...@@ -404,13 +417,13 @@ class FederationHandler(BaseHandler): ...@@ -404,13 +417,13 @@ class FederationHandler(BaseHandler):
) )
yield self._handle_new_event( yield self._handle_new_event(
event, new_event,
state=state, state=state,
current_state=state, current_state=state,
) )
yield self.notifier.on_new_room_event( yield self.notifier.on_new_room_event(
event, extra_users=[joinee] new_event, extra_users=[joinee]
) )
logger.debug("Finished joining %s to %s", joinee, room_id) logger.debug("Finished joining %s to %s", joinee, room_id)
...@@ -419,6 +432,9 @@ class FederationHandler(BaseHandler): ...@@ -419,6 +432,9 @@ class FederationHandler(BaseHandler):
del self.room_queues[room_id] del self.room_queues[room_id]
for p, origin in room_queue: for p, origin in room_queue:
if p.event_id in handled_events:
continue
try: try:
self.on_receive_pdu(origin, p, backfilled=False) self.on_receive_pdu(origin, p, backfilled=False)
except: except:
...@@ -428,25 +444,24 @@ class FederationHandler(BaseHandler): ...@@ -428,25 +444,24 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_make_join_request(self, context, user_id): def on_make_join_request(self, room_id, user_id):
""" We've received a /make_join/ request, so we create a partial """ We've received a /make_join/ request, so we create a partial
join event for the room and return that. We don *not* persist or join event for the room and return that. We don *not* persist or
process it until the other server has signed it and sent it back. process it until the other server has signed it and sent it back.
""" """
event = self.event_factory.create_event( builder = self.event_builder_factory.new({
etype=RoomMemberEvent.TYPE, "type": EventTypes.Member,
content={"membership": Membership.JOIN}, "content": {"membership": Membership.JOIN},
room_id=context, "room_id": room_id,
user_id=user_id, "sender": user_id,
state_key=user_id, "state_key": user_id,
) })
snapshot = yield self.store.snapshot_room(event) event, context = yield self._create_new_client_event(
snapshot.fill_out_prev_events(event) builder=builder,
)
yield self.state_handler.annotate_event_with_state(event) self.auth.check(event, auth_events=context.auth_events)
yield self.auth.add_auth_events(event)
self.auth.check(event, auth_events=event.old_state_events)
pdu = event pdu = event
...@@ -460,12 +475,24 @@ class FederationHandler(BaseHandler): ...@@ -460,12 +475,24 @@ class FederationHandler(BaseHandler):
""" """
event = pdu event = pdu
event.outlier = False logger.debug(
"on_send_join_request: Got event: %s, signatures: %s",
event.event_id,
event.signatures,
)
event.internal_metadata.outlier = False
yield self._handle_new_event(event) context = yield self._handle_new_event(event)
logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
event.event_id,
event.signatures,
)
extra_users = [] extra_users = []
if event.type == RoomMemberEvent.TYPE: if event.type == EventTypes.Member:
target_user_id = event.state_key target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id) target_user = self.hs.parse_userid(target_user_id)
extra_users.append(target_user) extra_users.append(target_user)
...@@ -474,7 +501,7 @@ class FederationHandler(BaseHandler): ...@@ -474,7 +501,7 @@ class FederationHandler(BaseHandler):
event, extra_users=extra_users event, extra_users=extra_users
) )
if event.type == RoomMemberEvent.TYPE: if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN: if event.content["membership"] == Membership.JOIN:
user = self.hs.parse_userid(event.state_key) user = self.hs.parse_userid(event.state_key)
yield self.distributor.fire( yield self.distributor.fire(
...@@ -485,9 +512,9 @@ class FederationHandler(BaseHandler): ...@@ -485,9 +512,9 @@ class FederationHandler(BaseHandler):
destinations = set() destinations = set()
for k, s in event.state_events.items(): for k, s in context.current_state.items():
try: try:
if k[0] == RoomMemberEvent.TYPE: if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN: if s.content["membership"] == Membership.JOIN:
destinations.add( destinations.add(
self.hs.parse_userid(s.state_key).domain self.hs.parse_userid(s.state_key).domain
...@@ -497,14 +524,18 @@ class FederationHandler(BaseHandler): ...@@ -497,14 +524,18 @@ class FederationHandler(BaseHandler):
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id
) )
new_pdu.destinations = list(destinations) logger.debug(
"on_send_join_request: Sending event: %s, signatures: %s",
event.event_id,
event.signatures,
)
yield self.replication_layer.send_pdu(new_pdu) yield self.replication_layer.send_pdu(new_pdu, destinations)
auth_chain = yield self.store.get_auth_chain(event.event_id) auth_chain = yield self.store.get_auth_chain(event.event_id)
defer.returnValue({ defer.returnValue({
"state": event.state_events.values(), "state": context.current_state.values(),
"auth_chain": auth_chain, "auth_chain": auth_chain,
}) })
...@@ -516,7 +547,9 @@ class FederationHandler(BaseHandler): ...@@ -516,7 +547,9 @@ class FederationHandler(BaseHandler):
""" """
event = pdu event = pdu
event.outlier = True context = EventContext()
event.internal_metadata.outlier = True
event.signatures.update( event.signatures.update(
compute_event_signature( compute_event_signature(
...@@ -526,10 +559,11 @@ class FederationHandler(BaseHandler): ...@@ -526,10 +559,11 @@ class FederationHandler(BaseHandler):
) )
) )
yield self.state_handler.annotate_event_with_state(event) yield self.state_handler.annotate_context_with_state(event, context)
yield self.store.persist_event( yield self.store.persist_event(
event, event,
context=context,
backfilled=False, backfilled=False,
) )
...@@ -559,13 +593,13 @@ class FederationHandler(BaseHandler): ...@@ -559,13 +593,13 @@ class FederationHandler(BaseHandler):
} }
event = yield self.store.get_event(event_id) event = yield self.store.get_event(event_id)
if hasattr(event, "state_key"): if event and event.is_state():
# Get previous state # Get previous state
if hasattr(event, "replaces_state") and event.replaces_state: if "replaces_state" in event.unsigned:
prev_event = yield self.store.get_event( prev_id = event.unsigned["replaces_state"]
event.replaces_state if prev_id != event.event_id:
) prev_event = yield self.store.get_event(prev_id)
results[(event.type, event.state_key)] = prev_event results[(event.type, event.state_key)] = prev_event
else: else:
del results[(event.type, event.state_key)] del results[(event.type, event.state_key)]
...@@ -651,74 +685,81 @@ class FederationHandler(BaseHandler): ...@@ -651,74 +685,81 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_new_event(self, event, state=None, backfilled=False, def _handle_new_event(self, event, state=None, backfilled=False,
current_state=None, fetch_missing=True): current_state=None, fetch_missing=True):
is_new_state = yield self.state_handler.annotate_event_with_state( context = EventContext()
logger.debug(
"_handle_new_event: Before annotate: %s, sigs: %s",
event.event_id, event.signatures,
)
yield self.state_handler.annotate_context_with_state(
event, event,
context,
old_state=state old_state=state
) )
if event.old_state_events: logger.debug(
known_ids = set( "_handle_new_event: Before auth fetch: %s, sigs: %s",
[s.event_id for s in event.old_state_events.values()] event.event_id, event.signatures,
) )
for e_id, _ in event.auth_events:
if e_id not in known_ids:
e = yield self.store.get_event(
e_id,
allow_none=True,
)
if not e:
# TODO: Do some conflict res to make sure that we're
# not the ones who are wrong.
logger.info(
"Rejecting %s as %s not in %s",
event.event_id, e_id, known_ids,
)
raise AuthError(403, "Auth events are stale")
auth_events = event.old_state_events
else:
# We need to get the auth events from somewhere.
# TODO: Don't just hit the DBs? is_new_state = not event.internal_metadata.is_outlier()
auth_events = {} known_ids = set(
for e_id, _ in event.auth_events: [s.event_id for s in context.auth_events.values()]
)
for e_id, _ in event.auth_events:
if e_id not in known_ids:
e = yield self.store.get_event( e = yield self.store.get_event(
e_id, e_id, allow_none=True,
allow_none=True,
) )
if not e: if not e:
e = yield self.replication_layer.get_pdu( # TODO: Do some conflict res to make sure that we're
event.origin, e_id, outlier=True # not the ones who are wrong.
logger.info(
"Rejecting %s as %s not in db or %s",
event.event_id, e_id, known_ids,
) )
# FIXME: How does raising AuthError work with federation?
raise AuthError(403, "Auth events are stale")
if e and fetch_missing: context.auth_events[(e.type, e.state_key)] = e
try:
yield self.on_receive_pdu(event.origin, e, False)
except:
logger.exception(
"Failed to parse auth event %s",
e_id,
)
if not e: logger.debug(
logger.warn("Can't find auth event %s.", e_id) "_handle_new_event: Before hack: %s, sigs: %s",
event.event_id, event.signatures,
)
auth_events[(e.type, e.state_key)] = e if event.type == EventTypes.Member and not event.auth_events:
if len(event.prev_events) == 1:
c = yield self.store.get_event(event.prev_events[0][0])
if c.type == EventTypes.Create:
context.auth_events[(c.type, c.state_key)] = c
if event.type == RoomMemberEvent.TYPE and not event.auth_events: logger.debug(
if len(event.prev_events) == 1: "_handle_new_event: Before auth check: %s, sigs: %s",
c = yield self.store.get_event(event.prev_events[0][0]) event.event_id, event.signatures,
if c.type == RoomCreateEvent.TYPE: )
auth_events[(c.type, c.state_key)] = c
self.auth.check(event, auth_events=auth_events) self.auth.check(event, auth_events=context.auth_events)
logger.debug(
"_handle_new_event: Before persist_event: %s, sigs: %s",
event.event_id, event.signatures,
)
yield self.store.persist_event( yield self.store.persist_event(
event, event,
context=context,
backfilled=backfilled, backfilled=backfilled,
is_new_state=(is_new_state and not backfilled), is_new_state=(is_new_state and not backfilled),
current_state=current_state, current_state=current_state,
) )
logger.debug(
"_handle_new_event: After persist_event: %s, sigs: %s",
event.event_id, event.signatures,
)
defer.returnValue(context)
...@@ -15,10 +15,13 @@ ...@@ -15,10 +15,13 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import RoomError from synapse.api.errors import RoomError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.events.validator import EventValidator
from ._base import BaseHandler from ._base import BaseHandler
import logging import logging
...@@ -32,7 +35,7 @@ class MessageHandler(BaseHandler): ...@@ -32,7 +35,7 @@ class MessageHandler(BaseHandler):
super(MessageHandler, self).__init__(hs) super(MessageHandler, self).__init__(hs)
self.hs = hs self.hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.event_factory = hs.get_event_factory() self.validator = EventValidator()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_message(self, msg_id=None, room_id=None, sender_id=None, def get_message(self, msg_id=None, room_id=None, sender_id=None,
...@@ -79,7 +82,7 @@ class MessageHandler(BaseHandler): ...@@ -79,7 +82,7 @@ class MessageHandler(BaseHandler):
self.ratelimit(event.user_id) self.ratelimit(event.user_id)
# TODO(paul): Why does 'event' not have a 'user' object? # TODO(paul): Why does 'event' not have a 'user' object?
user = self.hs.parse_userid(event.user_id) user = self.hs.parse_userid(event.user_id)
assert user.is_mine, "User must be our own: %s" % (user,) assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
snapshot = yield self.store.snapshot_room(event) snapshot = yield self.store.snapshot_room(event)
...@@ -134,19 +137,48 @@ class MessageHandler(BaseHandler): ...@@ -134,19 +137,48 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
@defer.inlineCallbacks @defer.inlineCallbacks
def store_room_data(self, event=None): def create_and_send_event(self, event_dict):
""" Stores data for a room. """ Given a dict from a client, create and handle a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events,
etc.
Adds display names to Join membership events.
Persists and notifies local clients and federation.
Args: Args:
event : The room path event event_dict (dict): An entire event
stamp_event (bool) : True to stamp event content with server keys.
Raises:
SynapseError if something went wrong.
""" """
builder = self.event_builder_factory.new(event_dict)
self.validator.validate_new(builder)
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
if membership == Membership.JOIN:
joinee = self.hs.parse_userid(builder.state_key)
# If event doesn't include a display name, add one.
yield self.distributor.fire(
"collect_presencelike_data",
joinee,
builder.content
)
snapshot = yield self.store.snapshot_room(event) event, context = yield self._create_new_client_event(
builder=builder,
)
yield self._on_new_room_event(event, snapshot) if event.type == EventTypes.Member:
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context)
else:
yield self.handle_new_client_event(
event=event,
context=context,
)
defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None, def get_room_data(self, user_id=None, room_id=None,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment