Skip to content
Snippets Groups Projects
Commit 8ecaff51 authored by Kegan Dougal's avatar Kegan Dougal
Browse files

Review comments

parent f6c48802
No related branches found
No related tags found
No related merge requests found
......@@ -18,58 +18,68 @@ to ensure idempotency when performing PUTs using the REST API."""
import logging
from synapse.api.auth import get_access_token_from_request
from synapse.util.async import ObservableDeferred
logger = logging.getLogger(__name__)
class HttpTransactionCache(object):
def get_transaction_key(request):
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
def __init__(self):
# { key : (txn_id, res_observ_defer) }
self.transactions = {}
Idempotency is based on the returned key being the same for separate
requests to the same endpoint. The key is formed from the HTTP request
path and the access_token for the requesting user.
def _get_response(self, key, txn_id):
try:
(last_txn_id, res_observ_defer) = self.transactions[key]
if txn_id == last_txn_id:
logger.info("get_response: Returning a response for %s", txn_id)
return res_observ_defer
except KeyError:
pass
return None
Args:
request (twisted.web.http.Request): The incoming request. Must
contain an access_token.
Returns:
str: A transaction key
"""
token = get_access_token_from_request(request)
return request.path + "/" + token
def _store_response(self, key, txn_id, res_observ_defer):
self.transactions[key] = (txn_id, res_observ_defer)
def store_client_transaction(self, request, txn_id, res_observ_defer):
"""Stores the request/Promise<response> pair of an HTTP transaction.
class HttpTransactionCache(object):
Args:
request (twisted.web.http.Request): The twisted HTTP request. This
request must have the transaction ID as the last path segment.
res_observ_defer (Promise<tuple>): A tuple of (response code, response dict)
txn_id (str): The transaction ID for this request.
def __init__(self):
self.transactions = {
# $txn_key: ObservableDeferred<(res_code, res_json_body)>
}
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
See:
fetch_or_execute
"""
self._store_response(self._get_key(request), txn_id, res_observ_defer)
return self.fetch_or_execute(
get_transaction_key(request), fn, *args, **kwargs
)
def get_client_transaction(self, request, txn_id):
"""Retrieves a stored response if there was one.
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
"""Fetches the response for this transaction, or executes the given function
to produce a response for this transaction.
Args:
request (twisted.web.http.Request): The twisted HTTP request. This
request must have the transaction ID as the last path segment.
txn_id (str): The transaction ID for this request.
txn_key (str): A key to ensure idempotency should fetch_or_execute be
called again at a later point in time.
fn (function): A function which returns a tuple of
(response_code, response_dict)d
*args: Arguments to pass to fn.
**kwargs: Keyword arguments to pass to fn.
Returns:
Promise: Resolves to the response tuple.
Raises:
KeyError if the transaction was not found.
synapse.util.async.ObservableDeferred which resolves to a tuple
of (response_code, response_dict).
"""
res_observ_defer = self._get_response(self._get_key(request), txn_id)
if res_observ_defer is None:
raise KeyError("Transaction not found.")
return res_observ_defer
try:
return self.transactions[txn_key]
except KeyError:
pass # execute the function instead.
def _get_key(self, request):
token = get_access_token_from_request(request)
path_without_txn_id = request.path.rsplit("/", 1)[0]
return path_without_txn_id + "/" + token
deferred = fn(*args, **kwargs)
observable = ObservableDeferred(deferred)
self.transactions[txn_key] = observable
return observable
......@@ -18,7 +18,8 @@
from synapse.http.servlet import RestServlet
from synapse.api.urls import CLIENT_PREFIX
from .transactions import HttpTransactionCache
from synapse.rest.client.transactions import HttpTransactionCache
import re
import logging
......
......@@ -22,7 +22,6 @@ from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias
from synapse.util.async import ObservableDeferred
from synapse.events.utils import serialize_event, format_event_for_client_v2
from synapse.http.servlet import (
parse_json_object_from_request, parse_string, parse_integer
......@@ -56,17 +55,11 @@ class RoomCreateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = ObservableDeferred(self.on_POST(request))
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred.observe()
defer.returnValue(response)
observable = self.txns.fetch_or_execute_request(
request, self.on_POST, request
)
res = yield observable.observe()
defer.returnValue(res)
@defer.inlineCallbacks
def on_POST(self, request):
......@@ -217,19 +210,11 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = ObservableDeferred(
self.on_POST(request, room_id, event_type, txn_id)
observable = self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_type, txn_id
)
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred.observe()
defer.returnValue(response)
res = yield observable.observe()
defer.returnValue(res)
# TODO: Needs unit testing for room ID + alias joins
......@@ -288,17 +273,11 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_identifier, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = ObservableDeferred(self.on_POST(request, room_identifier, txn_id))
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred.observe()
defer.returnValue(response)
observable = self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_identifier, txn_id
)
res = yield observable.observe()
defer.returnValue(res)
# TODO: Needs unit testing
......@@ -542,17 +521,11 @@ class RoomForgetRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = ObservableDeferred(self.on_POST(request, room_id, txn_id))
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred.observe()
defer.returnValue(response)
observable = self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, txn_id
)
res = yield observable.observe()
defer.returnValue(res)
# TODO: Needs unit testing
......@@ -626,19 +599,11 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, membership_action, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = ObservableDeferred(
self.on_POST(request, room_id, membership_action, txn_id)
observable = self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, membership_action, txn_id
)
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred.observe()
defer.returnValue(response)
res = yield observable.observe()
defer.returnValue(res)
class RoomRedactEventRestServlet(ClientV1RestServlet):
......@@ -672,19 +637,11 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_id, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = ObservableDeferred(
self.on_POST(request, room_id, event_id, txn_id)
observable = self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_id, txn_id
)
self.txns.store_client_transaction(request, txn_id, res_deferred)
response = yield res_deferred.observe()
defer.returnValue(response)
res = yield observable.observe()
defer.returnValue(res)
class RoomTypingRestServlet(ClientV1RestServlet):
......
......@@ -19,8 +19,7 @@ from twisted.internet import defer
from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request
from synapse.rest.client.v1.transactions import HttpTransactionCache
from synapse.util.async import ObservableDeferred
from synapse.rest.client.transactions import HttpTransactionCache
from ._base import client_v2_patterns
......@@ -46,16 +45,10 @@ class SendToDeviceRestServlet(servlet.RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, message_type, txn_id):
try:
res_deferred = self.txns.get_client_transaction(request, txn_id)
res = yield res_deferred.observe()
defer.returnValue(res)
except KeyError:
pass
res_deferred = ObservableDeferred(self._put(request, message_type, txn_id))
self.txns.store_client_transaction(request, txn_id, res_deferred)
res = yield res_deferred.observe()
observable = self.txns.fetch_or_execute_request(
request, self._put, request, message_type, txn_id
)
res = yield observable.observe()
defer.returnValue(res)
@defer.inlineCallbacks
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment