Skip to content
Snippets Groups Projects
Commit be47cfa9 authored by Erik Johnston's avatar Erik Johnston
Browse files

Refactor event building into EventBuilder

This is so that everything is done in one place, making it easier to
change the event format based on room version
parent 554ca58e
Branches
Tags
No related merge requests found
...@@ -13,79 +13,156 @@ ...@@ -13,79 +13,156 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy import attr
from synapse.api.constants import RoomVersions from twisted.internet import defer
from synapse.api.constants import (
KNOWN_EVENT_FORMAT_VERSIONS,
KNOWN_ROOM_VERSIONS,
MAX_DEPTH,
)
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.types import EventID from synapse.types import EventID
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from . import EventBase, FrozenEvent, _event_dict_property from . import (
_EventInternalMetadata,
event_type_from_format_version,
room_version_to_event_format,
)
def get_event_builder(room_version, key_values={}, internal_metadata_dict={}): @attr.s(slots=True, cmp=False, frozen=True)
"""Generate an event builder appropriate for the given room version class EventBuilder(object):
"""A format independent event builder used to build up the event content
before signing the event.
Args: (Note that while objects of this class are frozen, the
room_version (str): Version of the room that we're creating an content/unsigned/internal_metadata fields are still mutable)
event builder for
key_values (dict): Fields used as the basis of the new event
internal_metadata_dict (dict): Used to create the `_EventInternalMetadata`
object.
Returns: Attributes:
EventBuilder format_version (int): Event format version
room_id (str)
type (str)
sender (str)
content (dict)
unsigned (dict)
internal_metadata (_EventInternalMetadata)
_state (StateHandler)
_auth (synapse.api.Auth)
_store (DataStore)
_clock (Clock)
_hostname (str): The hostname of the server creating the event
_signing_key: The signing key to use to sign the event as the server
""" """
if room_version in {
RoomVersions.V1,
RoomVersions.V2,
RoomVersions.VDH_TEST,
RoomVersions.STATE_V2_TEST,
}:
return EventBuilder(key_values, internal_metadata_dict)
else:
raise Exception(
"No event format defined for version %r" % (room_version,)
)
_state = attr.ib()
_auth = attr.ib()
_store = attr.ib()
_clock = attr.ib()
_hostname = attr.ib()
_signing_key = attr.ib()
format_version = attr.ib()
room_id = attr.ib()
type = attr.ib()
sender = attr.ib()
content = attr.ib(default=attr.Factory(dict))
unsigned = attr.ib(default=attr.Factory(dict))
# These only exist on a subset of events, so they raise AttributeError if
# someone tries to get them when they don't exist.
_state_key = attr.ib(default=None)
_redacts = attr.ib(default=None)
internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({})))
@property
def state_key(self):
if self._state_key is not None:
return self._state_key
raise AttributeError("state_key")
def is_state(self):
return self._state_key is not None
class EventBuilder(EventBase): @defer.inlineCallbacks
def __init__(self, key_values={}, internal_metadata_dict={}): def build(self, prev_event_ids):
signatures = copy.deepcopy(key_values.pop("signatures", {})) """Transform into a fully signed and hashed event
unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
super(EventBuilder, self).__init__( Args:
key_values, prev_event_ids (list[str]): The event IDs to use as the prev events
signatures=signatures,
unsigned=unsigned, Returns:
internal_metadata_dict=internal_metadata_dict, Deferred[FrozenEvent]
"""
state_ids = yield self._state.get_current_state_ids(
self.room_id, prev_event_ids,
)
auth_ids = yield self._auth.compute_auth_events(
self, state_ids,
) )
event_id = _event_dict_property("event_id") auth_events = yield self._store.add_event_hashes(auth_ids)
state_key = _event_dict_property("state_key") prev_events = yield self._store.add_event_hashes(prev_event_ids)
type = _event_dict_property("type")
def build(self): old_depth = yield self._store.get_max_depth_of(
return FrozenEvent.from_event(self) prev_event_ids,
)
depth = old_depth + 1
# we cap depth of generated events, to ensure that they are not
# rejected by other servers (and so that they can be persisted in
# the db)
depth = min(depth, MAX_DEPTH)
class EventBuilderFactory(object): event_dict = {
def __init__(self, clock, hostname): "auth_events": auth_events,
self.clock = clock "prev_events": prev_events,
self.hostname = hostname "type": self.type,
"room_id": self.room_id,
"sender": self.sender,
"content": self.content,
"unsigned": self.unsigned,
"depth": depth,
"prev_state": [],
}
if self.is_state():
event_dict["state_key"] = self._state_key
self.event_id_count = 0 if self._redacts is not None:
event_dict["redacts"] = self._redacts
def create_event_id(self): defer.returnValue(
i = str(self.event_id_count) create_local_event_from_event_dict(
self.event_id_count += 1 clock=self._clock,
hostname=self._hostname,
signing_key=self._signing_key,
format_version=self.format_version,
event_dict=event_dict,
internal_metadata_dict=self.internal_metadata.get_dict(),
)
)
local_part = str(int(self.clock.time())) + i + random_string(5)
e_id = EventID(local_part, self.hostname) class EventBuilderFactory(object):
def __init__(self, hs):
self.clock = hs.get_clock()
self.hostname = hs.hostname
self.signing_key = hs.config.signing_key[0]
return e_id.to_string() self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
def new(self, room_version, key_values={}): def new(self, room_version, key_values):
"""Generate an event builder appropriate for the given room version """Generate an event builder appropriate for the given room version
Args: Args:
...@@ -98,27 +175,104 @@ class EventBuilderFactory(object): ...@@ -98,27 +175,104 @@ class EventBuilderFactory(object):
""" """
# There's currently only the one event version defined # There's currently only the one event version defined
if room_version not in { if room_version not in KNOWN_ROOM_VERSIONS:
RoomVersions.V1,
RoomVersions.V2,
RoomVersions.VDH_TEST,
RoomVersions.STATE_V2_TEST,
}:
raise Exception( raise Exception(
"No event format defined for version %r" % (room_version,) "No event format defined for version %r" % (room_version,)
) )
key_values["event_id"] = self.create_event_id() key_values["event_id"] = _create_event_id(self.clock, self.hostname)
return EventBuilder(
store=self.store,
state=self.state,
auth=self.auth,
clock=self.clock,
hostname=self.hostname,
signing_key=self.signing_key,
format_version=room_version_to_event_format(room_version),
type=key_values["type"],
state_key=key_values.get("state_key"),
room_id=key_values["room_id"],
sender=key_values["sender"],
content=key_values.get("content", {}),
unsigned=key_values.get("unsigned", {}),
redacts=key_values.get("redacts", None),
)
def create_local_event_from_event_dict(clock, hostname, signing_key,
format_version, event_dict,
internal_metadata_dict=None):
"""Takes a fully formed event dict, ensuring that fields like `origin`
and `origin_server_ts` have correct values for a locally produced event,
then signs and hashes it.
Args:
clock (Clock)
hostname (str)
signing_key
format_version (int)
event_dict (dict)
internal_metadata_dict (dict|None)
Returns:
FrozenEvent
"""
# There's currently only the one event version defined
if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
raise Exception(
"No event format defined for version %r" % (format_version,)
)
if internal_metadata_dict is None:
internal_metadata_dict = {}
time_now = int(clock.time_msec())
event_dict["event_id"] = _create_event_id(clock, hostname)
event_dict["origin"] = hostname
event_dict["origin_server_ts"] = time_now
event_dict.setdefault("unsigned", {})
age = event_dict["unsigned"].pop("age", 0)
event_dict["unsigned"].setdefault("age_ts", time_now - age)
event_dict.setdefault("signatures", {})
add_hashes_and_signatures(
event_dict,
hostname,
signing_key,
)
return event_type_from_format_version(format_version)(
event_dict, internal_metadata_dict=internal_metadata_dict,
)
# A counter used when generating new event IDs
_event_id_counter = 0
def _create_event_id(clock, hostname):
"""Create a new event ID
Args:
clock (Clock)
hostname (str): The server name for the event ID
Returns:
str
"""
time_now = int(self.clock.time_msec()) global _event_id_counter
key_values.setdefault("origin", self.hostname) i = str(_event_id_counter)
key_values.setdefault("origin_server_ts", time_now) _event_id_counter += 1
key_values.setdefault("unsigned", {}) local_part = str(int(clock.time())) + i + random_string(5)
age = key_values["unsigned"].pop("age", 0)
key_values["unsigned"].setdefault("age_ts", time_now - age)
key_values["signatures"] = {} e_id = EventID(local_part, hostname)
return EventBuilder(key_values=key_values,) return e_id.to_string()
...@@ -37,8 +37,7 @@ from synapse.api.errors import ( ...@@ -37,8 +37,7 @@ from synapse.api.errors import (
HttpResponseException, HttpResponseException,
SynapseError, SynapseError,
) )
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events import builder, room_version_to_event_format
from synapse.events import room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.util import logcontext, unwrapFirstError from synapse.util import logcontext, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
...@@ -72,7 +71,8 @@ class FederationClient(FederationBase): ...@@ -72,7 +71,8 @@ class FederationClient(FederationBase):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client() self.transport_layer = hs.get_federation_transport_client()
self.event_builder_factory = hs.get_event_builder_factory() self.hostname = hs.hostname
self.signing_key = hs.config.signing_key[0]
self._get_pdu_cache = ExpiringCache( self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache", cache_name="get_pdu_cache",
...@@ -608,18 +608,10 @@ class FederationClient(FederationBase): ...@@ -608,18 +608,10 @@ class FederationClient(FederationBase):
if "prev_state" not in pdu_dict: if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = [] pdu_dict["prev_state"] = []
# Strip off the fields that we want to clobber. ev = builder.create_local_event_from_event_dict(
pdu_dict.pop("origin", None) self._clock, self.hostname, self.signing_key,
pdu_dict.pop("origin_server_ts", None) format_version=event_format, event_dict=pdu_dict,
pdu_dict.pop("unsigned", None)
builder = self.event_builder_factory.new(room_version, pdu_dict)
add_hashes_and_signatures(
builder,
self.hs.hostname,
self.hs.config.signing_key[0]
) )
ev = builder.build()
defer.returnValue( defer.returnValue(
(destination, ev, event_format) (destination, ev, event_format)
......
...@@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json ...@@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import succeed from twisted.internet.defer import succeed
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership, RoomVersions from synapse.api.constants import EventTypes, Membership, RoomVersions
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
...@@ -31,7 +31,6 @@ from synapse.api.errors import ( ...@@ -31,7 +31,6 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.api.urls import ConsentURIBuilder from synapse.api.urls import ConsentURIBuilder
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_event import ReplicationSendEventRestServlet
...@@ -545,40 +544,17 @@ class EventCreationHandler(object): ...@@ -545,40 +544,17 @@ class EventCreationHandler(object):
prev_events_and_hashes = \ prev_events_and_hashes = \
yield self.store.get_prev_events_for_room(builder.room_id) yield self.store.get_prev_events_for_room(builder.room_id)
if prev_events_and_hashes:
depth = max([d for _, _, d in prev_events_and_hashes]) + 1
# we cap depth of generated events, to ensure that they are not
# rejected by other servers (and so that they can be persisted in
# the db)
depth = min(depth, MAX_DEPTH)
else:
depth = 1
prev_events = [ prev_events = [
(event_id, prev_hashes) (event_id, prev_hashes)
for event_id, prev_hashes, _ in prev_events_and_hashes for event_id, prev_hashes, _ in prev_events_and_hashes
] ]
builder.prev_events = prev_events event = yield builder.build(
builder.depth = depth prev_event_ids=[p for p, _ in prev_events],
context = yield self.state.compute_event_context(builder)
if requester:
context.app_service = requester.app_service
if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events
)
yield self.auth.add_auth_events(builder, context)
signing_key = self.hs.config.signing_key[0]
add_hashes_and_signatures(
builder, self.server_name, signing_key
) )
context = yield self.state.compute_event_context(event)
event = builder.build() self.validator.validate_new(event)
logger.debug( logger.debug(
"Created event %s", "Created event %s",
......
...@@ -355,10 +355,7 @@ class HomeServer(object): ...@@ -355,10 +355,7 @@ class HomeServer(object):
return Keyring(self) return Keyring(self)
def build_event_builder_factory(self): def build_event_builder_factory(self):
return EventBuilderFactory( return EventBuilderFactory(self)
clock=self.get_clock(),
hostname=self.hostname,
)
def build_filtering(self): def build_filtering(self):
return Filtering(self) return Filtering(self)
......
...@@ -125,6 +125,29 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, ...@@ -125,6 +125,29 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
return dict(txn) return dict(txn)
@defer.inlineCallbacks
def get_max_depth_of(self, event_ids):
"""Returns the max depth of a set of event IDs
Args:
event_ids (list[str])
Returns
Deferred[int]
"""
rows = yield self._simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
retcols=("depth",),
desc="get_max_depth_of",
)
if not rows:
defer.returnValue(0)
else:
defer.returnValue(max(row["depth"] for row in rows))
def _get_oldest_events_in_room_txn(self, txn, room_id): def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn( return self._simple_select_onecol_txn(
txn, txn,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment