Skip to content
Snippets Groups Projects
Commit 30961182 authored by Erik Johnston's avatar Erik Johnston
Browse files

Merge branch 'develop' of github.com:matrix-org/synapse into erikj/state_ids

parents 778fa85f c1a133a6
Branches
Tags
No related merge requests found
...@@ -85,3 +85,8 @@ class RoomCreationPreset(object): ...@@ -85,3 +85,8 @@ class RoomCreationPreset(object):
PRIVATE_CHAT = "private_chat" PRIVATE_CHAT = "private_chat"
PUBLIC_CHAT = "public_chat" PUBLIC_CHAT = "public_chat"
TRUSTED_PRIVATE_CHAT = "trusted_private_chat" TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
class ThirdPartyEntityKind(object):
USER = "user"
LOCATION = "location"
...@@ -25,4 +25,3 @@ SERVER_KEY_PREFIX = "/_matrix/key/v1" ...@@ -25,4 +25,3 @@ SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/r0" MEDIA_PREFIX = "/_matrix/media/r0"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"
...@@ -88,6 +88,8 @@ class ApplicationService(object): ...@@ -88,6 +88,8 @@ class ApplicationService(object):
self.sender = sender self.sender = sender
self.namespaces = self._check_namespaces(namespaces) self.namespaces = self._check_namespaces(namespaces)
self.id = id self.id = id
# .protocols is a publicly visible field
if protocols: if protocols:
self.protocols = set(protocols) self.protocols = set(protocols)
else: else:
......
...@@ -14,10 +14,11 @@ ...@@ -14,10 +14,11 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.types import ThirdPartyEntityKind from synapse.util.caches.response_cache import ResponseCache
import logging import logging
import urllib import urllib
...@@ -25,6 +26,12 @@ import urllib ...@@ -25,6 +26,12 @@ import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
HOUR_IN_MS = 60 * 60 * 1000
APP_SERVICE_PREFIX = "/_matrix/app/unstable"
def _is_valid_3pe_result(r, field): def _is_valid_3pe_result(r, field):
if not isinstance(r, dict): if not isinstance(r, dict):
return False return False
...@@ -56,6 +63,8 @@ class ApplicationServiceApi(SimpleHttpClient): ...@@ -56,6 +63,8 @@ class ApplicationServiceApi(SimpleHttpClient):
super(ApplicationServiceApi, self).__init__(hs) super(ApplicationServiceApi, self).__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS)
@defer.inlineCallbacks @defer.inlineCallbacks
def query_user(self, service, user_id): def query_user(self, service, user_id):
uri = service.url + ("/users/%s" % urllib.quote(user_id)) uri = service.url + ("/users/%s" % urllib.quote(user_id))
...@@ -97,16 +106,20 @@ class ApplicationServiceApi(SimpleHttpClient): ...@@ -97,16 +106,20 @@ class ApplicationServiceApi(SimpleHttpClient):
@defer.inlineCallbacks @defer.inlineCallbacks
def query_3pe(self, service, kind, protocol, fields): def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER: if kind == ThirdPartyEntityKind.USER:
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
required_field = "userid" required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION: elif kind == ThirdPartyEntityKind.LOCATION:
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
required_field = "alias" required_field = "alias"
else: else:
raise ValueError( raise ValueError(
"Unrecognised 'kind' argument %r to query_3pe()", kind "Unrecognised 'kind' argument %r to query_3pe()", kind
) )
uri = "%s%s/thirdparty/%s/%s" % (
service.url,
APP_SERVICE_PREFIX,
kind,
urllib.quote(protocol)
)
try: try:
response = yield self.get_json(uri, fields) response = yield self.get_json(uri, fields)
if not isinstance(response, list): if not isinstance(response, list):
...@@ -131,6 +144,26 @@ class ApplicationServiceApi(SimpleHttpClient): ...@@ -131,6 +144,26 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_3pe to %s threw exception %s", uri, ex) logger.warning("query_3pe to %s threw exception %s", uri, ex)
defer.returnValue([]) defer.returnValue([])
def get_3pe_protocol(self, service, protocol):
@defer.inlineCallbacks
def _get():
uri = "%s%s/thirdparty/protocol/%s" % (
service.url,
APP_SERVICE_PREFIX,
urllib.quote(protocol)
)
try:
defer.returnValue((yield self.get_json(uri, {})))
except Exception as ex:
logger.warning("query_3pe_protocol to %s threw exception %s",
uri, ex)
defer.returnValue({})
key = (service.id, protocol)
return self.protocol_meta_cache.get(key) or (
self.protocol_meta_cache.set(key, _get())
)
@defer.inlineCallbacks @defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None): def push_bulk(self, service, events, txn_id=None):
events = self._serialize(events) events = self._serialize(events)
......
...@@ -175,6 +175,16 @@ class ApplicationServicesHandler(object): ...@@ -175,6 +175,16 @@ class ApplicationServicesHandler(object):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks
def get_3pe_protocols(self):
services = yield self.store.get_app_services()
protocols = {}
for s in services:
for p in s.protocols:
protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p)
defer.returnValue(protocols)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_services_for_event(self, event): def _get_services_for_event(self, event):
"""Retrieve a list of application services interested in this event. """Retrieve a list of application services interested in this event.
......
...@@ -387,7 +387,9 @@ class FederationHandler(BaseHandler): ...@@ -387,7 +387,9 @@ class FederationHandler(BaseHandler):
)).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results if a}) auth_events.update({a.event_id: a for a in results if a})
required_auth.update( required_auth.update(
a_id for event in results for a_id, _ in event.auth_events if event a_id
for event in results if event
for a_id, _ in event.auth_events
) )
missing_auth = required_auth - set(auth_events) missing_auth = required_auth - set(auth_events)
......
...@@ -18,15 +18,32 @@ import logging ...@@ -18,15 +18,32 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.types import ThirdPartyEntityKind
from ._base import client_v2_patterns from ._base import client_v2_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=())
def __init__(self, hs):
super(ThirdPartyProtocolsServlet, self).__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_GET(self, request):
yield self.auth.get_user_by_req(request)
protocols = yield self.appservice_handler.get_3pe_protocols()
defer.returnValue((200, protocols))
class ThirdPartyUserServlet(RestServlet): class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$", PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
releases=()) releases=())
def __init__(self, hs): def __init__(self, hs):
...@@ -50,7 +67,7 @@ class ThirdPartyUserServlet(RestServlet): ...@@ -50,7 +67,7 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet): class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$", PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$",
releases=()) releases=())
def __init__(self, hs): def __init__(self, hs):
...@@ -74,5 +91,6 @@ class ThirdPartyLocationServlet(RestServlet): ...@@ -74,5 +91,6 @@ class ThirdPartyLocationServlet(RestServlet):
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
ThirdPartyProtocolsServlet(hs).register(http_server)
ThirdPartyUserServlet(hs).register(http_server) ThirdPartyUserServlet(hs).register(http_server)
ThirdPartyLocationServlet(hs).register(http_server) ThirdPartyLocationServlet(hs).register(http_server)
...@@ -269,10 +269,3 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): ...@@ -269,10 +269,3 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "t%d-%d" % (self.topological, self.stream) return "t%d-%d" % (self.topological, self.stream)
else: else:
return "s%d" % (self.stream,) return "s%d" % (self.stream,)
# Some arbitrary constants used for internal API enumerations. Don't rely on
# exact values; always pass or compare symbolically
class ThirdPartyEntityKind(object):
USER = 'user'
LOCATION = 'location'
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment