Skip to content
Snippets Groups Projects
Commit 42c43cfa authored by Kegan Dougal's avatar Kegan Dougal
Browse files

Use ObservableDeferreds instead of Deferreds as they behave as intended

parent c7daf313
No related branches found
No related tags found
No related merge requests found
......@@ -22,6 +22,7 @@ from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias
from synapse.util.async import ObservableDeferred
from synapse.events.utils import serialize_event, format_event_for_client_v2
from synapse.http.servlet import (
parse_json_object_from_request, parse_string, parse_integer
......@@ -57,14 +58,14 @@ class RoomCreateRestServlet(ClientV1RestServlet):
def on_PUT(self, request, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = self.on_POST(request)
res_deferred = ObservableDeferred(self.on_POST(request))
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred
response = yield res_deferred.observe()
defer.returnValue(response)
@defer.inlineCallbacks
......@@ -218,14 +219,14 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def on_PUT(self, request, room_id, event_type, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = self.on_POST(request, room_id, event_type, txn_id)
res_deferred = ObservableDeferred(self.on_POST(request, room_id, event_type, txn_id))
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred
response = yield res_deferred.observe()
defer.returnValue(response)
......@@ -287,14 +288,14 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
def on_PUT(self, request, room_identifier, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = self.on_POST(request, room_identifier, txn_id)
res_deferred = ObservableDeferred(self.on_POST(request, room_identifier, txn_id))
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred
response = yield res_deferred.observe()
defer.returnValue(response)
......@@ -541,14 +542,14 @@ class RoomForgetRestServlet(ClientV1RestServlet):
def on_PUT(self, request, room_id, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = self.on_POST(request, room_id, txn_id)
res_deferred = ObservableDeferred(self.on_POST(request, room_id, txn_id))
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred
response = yield res_deferred.observe()
defer.returnValue(response)
......@@ -624,15 +625,15 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, membership_action, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred
res_deferred = ObservableDeferred(self.txns.get_client_transaction(request, txn_id))
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = self.on_POST(request, room_id, membership_action, txn_id)
res_deferred = ObservableDeffself.on_POST(request, room_id, membership_action, txn_id)
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred
response = yield res_deferred.observe()
defer.returnValue(response)
......@@ -669,14 +670,14 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
def on_PUT(self, request, room_id, event_id, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = self.on_POST(request, room_id, event_id, txn_id)
res_deferred = ObservableDeferred(self.on_POST(request, room_id, event_id, txn_id))
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred
response = yield res_deferred.observe()
defer.returnValue(response)
......
......@@ -25,32 +25,32 @@ logger = logging.getLogger(__name__)
class HttpTransactionCache(object):
def __init__(self):
# { key : (txn_id, response_deferred) }
# { key : (txn_id, res_observ_defer) }
self.transactions = {}
def _get_response(self, key, txn_id):
try:
(last_txn_id, response_deferred) = self.transactions[key]
(last_txn_id, res_observ_defer) = self.transactions[key]
if txn_id == last_txn_id:
logger.info("get_response: Returning a response for %s", txn_id)
return response_deferred
return res_observ_defer
except KeyError:
pass
return None
def _store_response(self, key, txn_id, response_deferred):
self.transactions[key] = (txn_id, response_deferred)
def _store_response(self, key, txn_id, res_observ_defer):
self.transactions[key] = (txn_id, res_observ_defer)
def store_client_transaction(self, request, txn_id, response_deferred):
def store_client_transaction(self, request, txn_id, res_observ_defer):
"""Stores the request/Promise<response> pair of an HTTP transaction.
Args:
request (twisted.web.http.Request): The twisted HTTP request. This
request must have the transaction ID as the last path segment.
response_deferred (Promise<tuple>): A tuple of (response code, response dict)
res_observ_defer (Promise<tuple>): A tuple of (response code, response dict)
txn_id (str): The transaction ID for this request.
"""
self._store_response(self._get_key(request), txn_id, response_deferred)
self._store_response(self._get_key(request), txn_id, res_observ_defer)
def get_client_transaction(self, request, txn_id):
"""Retrieves a stored response if there was one.
......@@ -64,10 +64,10 @@ class HttpTransactionCache(object):
Raises:
KeyError if the transaction was not found.
"""
response_deferred = self._get_response(self._get_key(request), txn_id)
if response_deferred is None:
res_observ_defer = self._get_response(self._get_key(request), txn_id)
if res_observ_defer is None:
raise KeyError("Transaction not found.")
return response_deferred
return res_observ_defer
def _get_key(self, request):
token = get_access_token_from_request(request)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment