Skip to content
Snippets Groups Projects
Commit 8ecaff51 authored by Kegan Dougal's avatar Kegan Dougal
Browse files

Review comments

parent f6c48802
Branches
Tags
Loading
...@@ -18,58 +18,68 @@ to ensure idempotency when performing PUTs using the REST API.""" ...@@ -18,58 +18,68 @@ to ensure idempotency when performing PUTs using the REST API."""
import logging import logging
from synapse.api.auth import get_access_token_from_request from synapse.api.auth import get_access_token_from_request
from synapse.util.async import ObservableDeferred
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class HttpTransactionCache(object): def get_transaction_key(request):
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
def __init__(self): Idempotency is based on the returned key being the same for separate
# { key : (txn_id, res_observ_defer) } requests to the same endpoint. The key is formed from the HTTP request
self.transactions = {} path and the access_token for the requesting user.
def _get_response(self, key, txn_id): Args:
try: request (twisted.web.http.Request): The incoming request. Must
(last_txn_id, res_observ_defer) = self.transactions[key] contain an access_token.
if txn_id == last_txn_id: Returns:
logger.info("get_response: Returning a response for %s", txn_id) str: A transaction key
return res_observ_defer """
except KeyError: token = get_access_token_from_request(request)
pass return request.path + "/" + token
return None
def _store_response(self, key, txn_id, res_observ_defer):
self.transactions[key] = (txn_id, res_observ_defer)
def store_client_transaction(self, request, txn_id, res_observ_defer): class HttpTransactionCache(object):
"""Stores the request/Promise<response> pair of an HTTP transaction.
Args: def __init__(self):
request (twisted.web.http.Request): The twisted HTTP request. This self.transactions = {
request must have the transaction ID as the last path segment. # $txn_key: ObservableDeferred<(res_code, res_json_body)>
res_observ_defer (Promise<tuple>): A tuple of (response code, response dict) }
txn_id (str): The transaction ID for this request.
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
See:
fetch_or_execute
""" """
self._store_response(self._get_key(request), txn_id, res_observ_defer) return self.fetch_or_execute(
get_transaction_key(request), fn, *args, **kwargs
)
def get_client_transaction(self, request, txn_id): def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
"""Retrieves a stored response if there was one. """Fetches the response for this transaction, or executes the given function
to produce a response for this transaction.
Args: Args:
request (twisted.web.http.Request): The twisted HTTP request. This txn_key (str): A key to ensure idempotency should fetch_or_execute be
request must have the transaction ID as the last path segment. called again at a later point in time.
txn_id (str): The transaction ID for this request. fn (function): A function which returns a tuple of
(response_code, response_dict)d
*args: Arguments to pass to fn.
**kwargs: Keyword arguments to pass to fn.
Returns: Returns:
Promise: Resolves to the response tuple. synapse.util.async.ObservableDeferred which resolves to a tuple
Raises: of (response_code, response_dict).
KeyError if the transaction was not found.
""" """
res_observ_defer = self._get_response(self._get_key(request), txn_id) try:
if res_observ_defer is None: return self.transactions[txn_key]
raise KeyError("Transaction not found.") except KeyError:
return res_observ_defer pass # execute the function instead.
def _get_key(self, request): deferred = fn(*args, **kwargs)
token = get_access_token_from_request(request) observable = ObservableDeferred(deferred)
path_without_txn_id = request.path.rsplit("/", 1)[0] self.transactions[txn_key] = observable
return path_without_txn_id + "/" + token return observable
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.api.urls import CLIENT_PREFIX from synapse.api.urls import CLIENT_PREFIX
from .transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
import re import re
import logging import logging
......
...@@ -22,7 +22,6 @@ from synapse.streams.config import PaginationConfig ...@@ -22,7 +22,6 @@ from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias from synapse.types import UserID, RoomID, RoomAlias
from synapse.util.async import ObservableDeferred
from synapse.events.utils import serialize_event, format_event_for_client_v2 from synapse.events.utils import serialize_event, format_event_for_client_v2
from synapse.http.servlet import ( from synapse.http.servlet import (
parse_json_object_from_request, parse_string, parse_integer parse_json_object_from_request, parse_string, parse_integer
...@@ -56,17 +55,11 @@ class RoomCreateRestServlet(ClientV1RestServlet): ...@@ -56,17 +55,11 @@ class RoomCreateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, txn_id): def on_PUT(self, request, txn_id):
try: observable = self.txns.fetch_or_execute_request(
res_deferred = self.txns.get_client_transaction(request, txn_id) request, self.on_POST, request
res = yield res_deferred.observe() )
defer.returnValue(res) res = yield observable.observe()
except KeyError: defer.returnValue(res)
pass
res_deferred = ObservableDeferred(self.on_POST(request))
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred.observe()
defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
...@@ -217,19 +210,11 @@ class RoomSendEventRestServlet(ClientV1RestServlet): ...@@ -217,19 +210,11 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, txn_id): def on_PUT(self, request, room_id, event_type, txn_id):
try: observable = self.txns.fetch_or_execute_request(
res_deferred = self.txns.get_client_transaction(request, txn_id) request, self.on_POST, request, room_id, event_type, txn_id
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = ObservableDeferred(
self.on_POST(request, room_id, event_type, txn_id)
) )
self.txns.store_client_transaction(request, txn_id, res_deferred) res = yield observable.observe()
response = yield res_deferred.observe() defer.returnValue(res)
defer.returnValue(response)
# TODO: Needs unit testing for room ID + alias joins # TODO: Needs unit testing for room ID + alias joins
...@@ -288,17 +273,11 @@ class JoinRoomAliasServlet(ClientV1RestServlet): ...@@ -288,17 +273,11 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_identifier, txn_id): def on_PUT(self, request, room_identifier, txn_id):
try: observable = self.txns.fetch_or_execute_request(
res_deferred = self.txns.get_client_transaction(request, txn_id) request, self.on_POST, request, room_identifier, txn_id
res = yield res_deferred.observe() )
defer.returnValue(res) res = yield observable.observe()
except KeyError: defer.returnValue(res)
pass
res_deferred = ObservableDeferred(self.on_POST(request, room_identifier, txn_id))
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred.observe()
defer.returnValue(response)
# TODO: Needs unit testing # TODO: Needs unit testing
...@@ -542,17 +521,11 @@ class RoomForgetRestServlet(ClientV1RestServlet): ...@@ -542,17 +521,11 @@ class RoomForgetRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, txn_id): def on_PUT(self, request, room_id, txn_id):
try: observable = self.txns.fetch_or_execute_request(
res_deferred = self.txns.get_client_transaction(request, txn_id) request, self.on_POST, request, room_id, txn_id
res = yield res_deferred.observe() )
defer.returnValue(res) res = yield observable.observe()
except KeyError: defer.returnValue(res)
pass
res_deferred = ObservableDeferred(self.on_POST(request, room_id, txn_id))
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred.observe()
defer.returnValue(response)
# TODO: Needs unit testing # TODO: Needs unit testing
...@@ -626,19 +599,11 @@ class RoomMembershipRestServlet(ClientV1RestServlet): ...@@ -626,19 +599,11 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, membership_action, txn_id): def on_PUT(self, request, room_id, membership_action, txn_id):
try: observable = self.txns.fetch_or_execute_request(
res_deferred = self.txns.get_client_transaction(request, txn_id) request, self.on_POST, request, room_id, membership_action, txn_id
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = ObservableDeferred(
self.on_POST(request, room_id, membership_action, txn_id)
) )
self.txns.store_client_transaction(request, txn_id, res_deferred) res = yield observable.observe()
response = yield res_deferred.observe() defer.returnValue(res)
defer.returnValue(response)
class RoomRedactEventRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(ClientV1RestServlet):
...@@ -672,19 +637,11 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): ...@@ -672,19 +637,11 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, event_id, txn_id): def on_PUT(self, request, room_id, event_id, txn_id):
try: observable = self.txns.fetch_or_execute_request(
res_deferred = self.txns.get_client_transaction(request, txn_id) request, self.on_POST, request, room_id, event_id, txn_id
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = ObservableDeferred(
self.on_POST(request, room_id, event_id, txn_id)
) )
self.txns.store_client_transaction(request, txn_id, res_deferred) res = yield observable.observe()
response = yield res_deferred.observe() defer.returnValue(res)
defer.returnValue(response)
class RoomTypingRestServlet(ClientV1RestServlet): class RoomTypingRestServlet(ClientV1RestServlet):
......
...@@ -19,8 +19,7 @@ from twisted.internet import defer ...@@ -19,8 +19,7 @@ from twisted.internet import defer
from synapse.http import servlet from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.rest.client.v1.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from synapse.util.async import ObservableDeferred
from ._base import client_v2_patterns from ._base import client_v2_patterns
...@@ -46,16 +45,10 @@ class SendToDeviceRestServlet(servlet.RestServlet): ...@@ -46,16 +45,10 @@ class SendToDeviceRestServlet(servlet.RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, message_type, txn_id): def on_PUT(self, request, message_type, txn_id):
try: observable = self.txns.fetch_or_execute_request(
res_deferred = self.txns.get_client_transaction(request, txn_id) request, self._put, request, message_type, txn_id
res = yield res_deferred.observe() )
defer.returnValue(res) res = yield observable.observe()
except KeyError:
pass
res_deferred = ObservableDeferred(self._put(request, message_type, txn_id))
self.txns.store_client_transaction(request, txn_id, res_deferred)
res = yield res_deferred.observe()
defer.returnValue(res) defer.returnValue(res)
@defer.inlineCallbacks @defer.inlineCallbacks
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment