Skip to content
Snippets Groups Projects
Commit b515f844 authored by Paul "LeoNerd" Evans's avatar Paul "LeoNerd" Evans
Browse files

Avoid so much copypasta between 3PU and 3PL query by unifying around a...

Avoid so much copypasta between 3PU and 3PL query by unifying around a ThirdPartyEntityKind enumeration
parent 2a91799f
Branches
Tags
No related merge requests found
...@@ -17,6 +17,7 @@ from twisted.internet import defer ...@@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.types import ThirdPartyEntityKind
import logging import logging
import urllib import urllib
...@@ -72,25 +73,21 @@ class ApplicationServiceApi(SimpleHttpClient): ...@@ -72,25 +73,21 @@ class ApplicationServiceApi(SimpleHttpClient):
defer.returnValue(False) defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def query_3pu(self, service, protocol, fields): def query_3pe(self, service, kind, protocol, fields):
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol)) if kind == ThirdPartyEntityKind.USER:
response = None uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
try: elif kind == ThirdPartyEntityKind.LOCATION:
response = yield self.get_json(uri, fields) uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
defer.returnValue(response) else:
except Exception as ex: raise ValueError(
logger.warning("query_3pu to %s threw exception %s", uri, ex) "Unrecognised 'kind' argument %r to query_3pe()", kind
defer.returnValue([]) )
@defer.inlineCallbacks
def query_3pl(self, service, protocol, fields):
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
response = None
try: try:
response = yield self.get_json(uri, fields) response = yield self.get_json(uri, fields)
defer.returnValue(response) defer.returnValue(response)
except Exception as ex: except Exception as ex:
logger.warning("query_3pl to %s threw exception %s", uri, ex) logger.warning("query_3pe to %s threw exception %s", uri, ex)
defer.returnValue([]) defer.returnValue([])
@defer.inlineCallbacks @defer.inlineCallbacks
......
...@@ -18,6 +18,7 @@ from twisted.internet import defer ...@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn
from synapse.types import ThirdPartyEntityKind
import logging import logging
...@@ -169,37 +170,19 @@ class ApplicationServicesHandler(object): ...@@ -169,37 +170,19 @@ class ApplicationServicesHandler(object):
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def query_3pu(self, protocol, fields): def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol) services = yield self._get_services_for_3pn(protocol)
results = yield defer.DeferredList([ results = yield defer.DeferredList([
self.appservice_api.query_3pu(service, protocol, fields) self.appservice_api.query_3pe(service, kind, protocol, fields)
for service in services for service in services
], consumeErrors=True) ], consumeErrors=True)
ret = [] required_field = (
for (success, result) in results: "userid" if kind == ThirdPartyEntityKind.USER else
if not success: "alias" if kind == ThirdPartyEntityKind.LOCATION else
continue None
if not isinstance(result, list): )
continue
for r in result:
if _is_valid_3pentity_result(r, field="userid"):
ret.append(r)
else:
logger.warn("Application service returned an " +
"invalid result %r", r)
defer.returnValue(ret)
@defer.inlineCallbacks
def query_3pl(self, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
results = yield defer.DeferredList([
self.appservice_api.query_3pl(service, protocol, fields)
for service in services
], consumeErrors=True)
ret = [] ret = []
for (success, result) in results: for (success, result) in results:
...@@ -208,7 +191,7 @@ class ApplicationServicesHandler(object): ...@@ -208,7 +191,7 @@ class ApplicationServicesHandler(object):
if not isinstance(result, list): if not isinstance(result, list):
continue continue
for r in result: for r in result:
if _is_valid_3pentity_result(r, field="alias"): if _is_valid_3pentity_result(r, field=required_field):
ret.append(r) ret.append(r)
else: else:
logger.warn("Application service returned an " + logger.warn("Application service returned an " +
......
...@@ -19,6 +19,7 @@ import logging ...@@ -19,6 +19,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.types import ThirdPartyEntityKind
from ._base import client_v2_patterns from ._base import client_v2_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -41,7 +42,9 @@ class ThirdPartyUserServlet(RestServlet): ...@@ -41,7 +42,9 @@ class ThirdPartyUserServlet(RestServlet):
fields = request.args fields = request.args
del fields["access_token"] del fields["access_token"]
results = yield self.appservice_handler.query_3pu(protocol, fields) results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields
)
defer.returnValue((200, results)) defer.returnValue((200, results))
...@@ -63,7 +66,9 @@ class ThirdPartyLocationServlet(RestServlet): ...@@ -63,7 +66,9 @@ class ThirdPartyLocationServlet(RestServlet):
fields = request.args fields = request.args
del fields["access_token"] del fields["access_token"]
results = yield self.appservice_handler.query_3pl(protocol, fields) results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields
)
defer.returnValue((200, results)) defer.returnValue((200, results))
......
...@@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): ...@@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "t%d-%d" % (self.topological, self.stream) return "t%d-%d" % (self.topological, self.stream)
else: else:
return "s%d" % (self.stream,) return "s%d" % (self.stream,)
# Some arbitrary constants used for internal API enumerations. Don't rely on
# exact values; always pass or compare symbolically
class ThirdPartyEntityKind(object):
USER = 'user'
LOCATION = 'location'
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment