diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 3471afd7e744dc0018e70b5793b3904a2df6a8c9..7105ee21dc2997c9f1b2245e79deb37ef51a0887 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -358,7 +358,7 @@ class Auth(object):
     def add_auth_events(self, builder, context):
         yield run_on_reactor()
 
-        auth_ids = self.compute_auth_events(builder, context)
+        auth_ids = self.compute_auth_events(builder, context.current_state)
 
         auth_events_entries = yield self.store.add_event_hashes(
             auth_ids
@@ -372,26 +372,26 @@ class Auth(object):
             if v.event_id in auth_ids
         }
 
-    def compute_auth_events(self, event, context):
+    def compute_auth_events(self, event, current_state):
         if event.type == EventTypes.Create:
             return []
 
         auth_ids = []
 
         key = (EventTypes.PowerLevels, "", )
-        power_level_event = context.current_state.get(key)
+        power_level_event = current_state.get(key)
 
         if power_level_event:
             auth_ids.append(power_level_event.event_id)
 
         key = (EventTypes.JoinRules, "", )
-        join_rule_event = context.current_state.get(key)
+        join_rule_event = current_state.get(key)
 
         key = (EventTypes.Member, event.user_id, )
-        member_event = context.current_state.get(key)
+        member_event = current_state.get(key)
 
         key = (EventTypes.Create, "", )
-        create_event = context.current_state.get(key)
+        create_event = current_state.get(key)
         if create_event:
             auth_ids.append(create_event.event_id)
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index a968a8736056a62d79064ee079b9f8c536302cbe..04a4689483270511cf5a6449e27b9d620cf93faf 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -842,7 +842,9 @@ class FederationHandler(BaseHandler):
             logger.debug("Different auth: %s", different_auth)
 
             # 1. Get what we think is the auth chain.
-            auth_ids = self.auth.compute_auth_events(event, context)
+            auth_ids = self.auth.compute_auth_events(
+                event, context.current_state
+            )
             local_auth_chain = yield self.store.get_auth_chain(auth_ids)
 
             try:
diff --git a/synapse/state.py b/synapse/state.py
index 6a6fb8aea04955587b1e858a6d26071e65523b27..695a5e7ac434d7148fc45fb51207c618b6b17f68 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -103,7 +103,9 @@ class StateHandler(object):
             context.state_group = None
 
             if hasattr(event, "auth_events") and event.auth_events:
-                auth_ids = zip(*event.auth_events)[0]
+                auth_ids = self.hs.get_auth().compute_auth_events(
+                    event, context.current_state
+                )
                 context.auth_events = {
                     k: v
                     for k, v in context.current_state.items()
@@ -149,7 +151,9 @@ class StateHandler(object):
                 event.unsigned["replaces_state"] = replaces.event_id
 
         if hasattr(event, "auth_events") and event.auth_events:
-            auth_ids = zip(*event.auth_events)[0]
+            auth_ids = self.hs.get_auth().compute_auth_events(
+                event, context.current_state
+            )
             context.auth_events = {
                 k: v
                 for k, v in context.current_state.items()