Skip to content
Snippets Groups Projects
Commit 96779d24 authored by Erik Johnston's avatar Erik Johnston
Browse files

Fix bug where we did not send the full auth chain to people that joined over federation

parent 88484f68
No related branches found
No related tags found
No related merge requests found
...@@ -285,7 +285,7 @@ class FederationHandler(BaseHandler): ...@@ -285,7 +285,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_event_auth(self, event_id): def on_event_auth(self, event_id):
auth = yield self.store.get_auth_chain(event_id) auth = yield self.store.get_auth_chain([event_id])
for event in auth: for event in auth:
event.signatures.update( event.signatures.update(
...@@ -494,7 +494,10 @@ class FederationHandler(BaseHandler): ...@@ -494,7 +494,10 @@ class FederationHandler(BaseHandler):
yield self.replication_layer.send_pdu(new_pdu) yield self.replication_layer.send_pdu(new_pdu)
auth_chain = yield self.store.get_auth_chain(event.event_id) state_ids = [e.event_id for e in event.state_events.values()]
auth_chain = yield self.store.get_auth_chain(set(
[event.event_id] + state_ids
))
defer.returnValue({ defer.returnValue({
"state": event.state_events.values(), "state": event.state_events.values(),
......
...@@ -32,15 +32,15 @@ class EventFederationStore(SQLBaseStore): ...@@ -32,15 +32,15 @@ class EventFederationStore(SQLBaseStore):
and backfilling from another server respectively. and backfilling from another server respectively.
""" """
def get_auth_chain(self, event_id): def get_auth_chain(self, event_ids):
return self.runInteraction( return self.runInteraction(
"get_auth_chain", "get_auth_chain",
self._get_auth_chain_txn, self._get_auth_chain_txn,
event_id event_ids
) )
def _get_auth_chain_txn(self, txn, event_id): def _get_auth_chain_txn(self, txn, event_ids):
results = self._get_auth_chain_ids_txn(txn, event_id) results = self._get_auth_chain_ids_txn(txn, event_ids)
sql = "SELECT * FROM events WHERE event_id = ?" sql = "SELECT * FROM events WHERE event_id = ?"
rows = [] rows = []
...@@ -50,21 +50,21 @@ class EventFederationStore(SQLBaseStore): ...@@ -50,21 +50,21 @@ class EventFederationStore(SQLBaseStore):
return self._parse_events_txn(txn, rows) return self._parse_events_txn(txn, rows)
def get_auth_chain_ids(self, event_id): def get_auth_chain_ids(self, event_ids):
return self.runInteraction( return self.runInteraction(
"get_auth_chain_ids", "get_auth_chain_ids",
self._get_auth_chain_ids_txn, self._get_auth_chain_ids_txn,
event_id event_ids
) )
def _get_auth_chain_ids_txn(self, txn, event_id): def _get_auth_chain_ids_txn(self, txn, event_ids):
results = set() results = set()
base_sql = ( base_sql = (
"SELECT auth_id FROM event_auth WHERE %s" "SELECT auth_id FROM event_auth WHERE %s"
) )
front = set([event_id]) front = set(event_ids)
while front: while front:
sql = base_sql % ( sql = base_sql % (
" OR ".join(["event_id=?"] * len(front)), " OR ".join(["event_id=?"] * len(front)),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment