Skip to content
Snippets Groups Projects
Commit af4a1bac authored by Kegan Dougal's avatar Kegan Dougal
Browse files

Move .observe() up to the cache to make things neater

parent 8ecaff51
No related branches found
No related tags found
No related merge requests found
...@@ -67,19 +67,18 @@ class HttpTransactionCache(object): ...@@ -67,19 +67,18 @@ class HttpTransactionCache(object):
txn_key (str): A key to ensure idempotency should fetch_or_execute be txn_key (str): A key to ensure idempotency should fetch_or_execute be
called again at a later point in time. called again at a later point in time.
fn (function): A function which returns a tuple of fn (function): A function which returns a tuple of
(response_code, response_dict)d (response_code, response_dict).
*args: Arguments to pass to fn. *args: Arguments to pass to fn.
**kwargs: Keyword arguments to pass to fn. **kwargs: Keyword arguments to pass to fn.
Returns: Returns:
synapse.util.async.ObservableDeferred which resolves to a tuple Deferred which resolves to a tuple of (response_code, response_dict).
of (response_code, response_dict).
""" """
try: try:
return self.transactions[txn_key] return self.transactions[txn_key].observe()
except KeyError: except KeyError:
pass # execute the function instead. pass # execute the function instead.
deferred = fn(*args, **kwargs) deferred = fn(*args, **kwargs)
observable = ObservableDeferred(deferred) observable = ObservableDeferred(deferred)
self.transactions[txn_key] = observable self.transactions[txn_key] = observable
return observable return observable.observe()
...@@ -53,13 +53,10 @@ class RoomCreateRestServlet(ClientV1RestServlet): ...@@ -53,13 +53,10 @@ class RoomCreateRestServlet(ClientV1RestServlet):
client_path_patterns("/createRoom(?:/.*)?$"), client_path_patterns("/createRoom(?:/.*)?$"),
self.on_OPTIONS) self.on_OPTIONS)
@defer.inlineCallbacks
def on_PUT(self, request, txn_id): def on_PUT(self, request, txn_id):
observable = self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
request, self.on_POST, request request, self.on_POST, request
) )
res = yield observable.observe()
defer.returnValue(res)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
...@@ -208,13 +205,10 @@ class RoomSendEventRestServlet(ClientV1RestServlet): ...@@ -208,13 +205,10 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def on_GET(self, request, room_id, event_type, txn_id): def on_GET(self, request, room_id, event_type, txn_id):
return (200, "Not implemented") return (200, "Not implemented")
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, txn_id): def on_PUT(self, request, room_id, event_type, txn_id):
observable = self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_type, txn_id request, self.on_POST, request, room_id, event_type, txn_id
) )
res = yield observable.observe()
defer.returnValue(res)
# TODO: Needs unit testing for room ID + alias joins # TODO: Needs unit testing for room ID + alias joins
...@@ -271,13 +265,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet): ...@@ -271,13 +265,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
defer.returnValue((200, {"room_id": room_id})) defer.returnValue((200, {"room_id": room_id}))
@defer.inlineCallbacks
def on_PUT(self, request, room_identifier, txn_id): def on_PUT(self, request, room_identifier, txn_id):
observable = self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_identifier, txn_id request, self.on_POST, request, room_identifier, txn_id
) )
res = yield observable.observe()
defer.returnValue(res)
# TODO: Needs unit testing # TODO: Needs unit testing
...@@ -519,13 +510,10 @@ class RoomForgetRestServlet(ClientV1RestServlet): ...@@ -519,13 +510,10 @@ class RoomForgetRestServlet(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, txn_id): def on_PUT(self, request, room_id, txn_id):
observable = self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, txn_id request, self.on_POST, request, room_id, txn_id
) )
res = yield observable.observe()
defer.returnValue(res)
# TODO: Needs unit testing # TODO: Needs unit testing
...@@ -597,13 +585,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet): ...@@ -597,13 +585,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
return False return False
return True return True
@defer.inlineCallbacks
def on_PUT(self, request, room_id, membership_action, txn_id): def on_PUT(self, request, room_id, membership_action, txn_id):
observable = self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, membership_action, txn_id request, self.on_POST, request, room_id, membership_action, txn_id
) )
res = yield observable.observe()
defer.returnValue(res)
class RoomRedactEventRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(ClientV1RestServlet):
...@@ -635,13 +620,10 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): ...@@ -635,13 +620,10 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
defer.returnValue((200, {"event_id": event.event_id})) defer.returnValue((200, {"event_id": event.event_id}))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_id, txn_id): def on_PUT(self, request, room_id, event_id, txn_id):
observable = self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_id, txn_id request, self.on_POST, request, room_id, event_id, txn_id
) )
res = yield observable.observe()
defer.returnValue(res)
class RoomTypingRestServlet(ClientV1RestServlet): class RoomTypingRestServlet(ClientV1RestServlet):
......
...@@ -43,13 +43,10 @@ class SendToDeviceRestServlet(servlet.RestServlet): ...@@ -43,13 +43,10 @@ class SendToDeviceRestServlet(servlet.RestServlet):
self.txns = HttpTransactionCache() self.txns = HttpTransactionCache()
self.device_message_handler = hs.get_device_message_handler() self.device_message_handler = hs.get_device_message_handler()
@defer.inlineCallbacks
def on_PUT(self, request, message_type, txn_id): def on_PUT(self, request, message_type, txn_id):
observable = self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
request, self._put, request, message_type, txn_id request, self._put, request, message_type, txn_id
) )
res = yield observable.observe()
defer.returnValue(res)
@defer.inlineCallbacks @defer.inlineCallbacks
def _put(self, request, message_type, txn_id): def _put(self, request, message_type, txn_id):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment