Skip to content
Snippets Groups Projects
Unverified Commit 618bd1ee authored by Richard van der Hoff's avatar Richard van der Hoff Committed by GitHub
Browse files

Fix some error cases in the caching layer. (#5749)

There was some inconsistent behaviour in the caching layer around how
exceptions were handled - particularly synchronously-thrown ones.

This seems to be most easily handled by pushing the creation of
ObservableDeferreds down from CacheDescriptor to the Cache.
parent f16aa3a4
No related branches found
No related tags found
No related merge requests found
Fix some error cases in the caching layer.
......@@ -19,8 +19,7 @@ import logging
import threading
from collections import namedtuple
import six
from six import itervalues, string_types
from six import itervalues
from prometheus_client import Gauge
......@@ -32,7 +31,6 @@ from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.stringutils import to_ascii
from . import register_cache
......@@ -124,7 +122,7 @@ class Cache(object):
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
Either a Deferred or the raw result
Either an ObservableDeferred or the raw result
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
......@@ -148,9 +146,14 @@ class Cache(object):
return default
def set(self, key, value, callback=None):
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
callbacks = [callback] if callback else []
self.check_thread()
entry = CacheEntry(deferred=value, callbacks=callbacks)
observable = ObservableDeferred(value, consumeErrors=True)
observer = defer.maybeDeferred(observable.observe)
entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
......@@ -158,20 +161,31 @@ class Cache(object):
self._pending_deferred_cache[key] = entry
def shuffle(result):
def compare_and_pop():
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.
Returns true if the entries matched.
"""
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
return True
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry
return False
def cb(result):
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
......@@ -179,9 +193,16 @@ class Cache(object):
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
return result
entry.deferred.addCallback(shuffle)
def eb(_fail):
compare_and_pop()
entry.invalidate()
# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
return observable
def prefill(self, key, value, callback=None):
callbacks = [callback] if callback else []
......@@ -414,20 +435,10 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr)
# If our cache_key is a string on py2, try to convert to ascii
# to save a bit of space in large caches. Py3 does this
# internally automatically.
if six.PY2 and isinstance(cache_key, string_types):
cache_key = to_ascii(cache_key)
result_d = ObservableDeferred(ret, consumeErrors=True)
cache.set(cache_key, result_d, callback=invalidate_callback)
result_d = cache.set(cache_key, ret, callback=invalidate_callback)
observer = result_d.observe()
if isinstance(observer, defer.Deferred):
return make_deferred_yieldable(observer)
else:
return observer
return make_deferred_yieldable(observer)
if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
......@@ -543,7 +554,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
missing.add(arg)
if missing:
# we need an observable deferred for each entry in the list,
# we need a deferred for each entry in the list,
# which we put in the cache. Each deferred resolves with the
# relevant result for that key.
deferreds_map = {}
......@@ -551,8 +562,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
deferred = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
observable = ObservableDeferred(deferred)
cache.set(key, observable, callback=invalidate_callback)
cache.set(key, deferred, callback=invalidate_callback)
def complete_all(res):
# the wrapped function has completed. It returns a
......
......@@ -27,6 +27,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
from synapse.util.caches.descriptors import cached
from tests import unittest
......@@ -55,12 +56,15 @@ class CacheTestCase(unittest.TestCase):
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
# lookup should return the deferreds
self.assertIs(cache.get("key1"), d1)
self.assertIs(cache.get("key2"), d2)
# lookup should return observable deferreds
self.assertFalse(cache.get("key1").has_called())
self.assertFalse(cache.get("key2").has_called())
# let one of the lookups complete
d2.callback("result2")
# for now at least, the cache will return real results rather than an
# observabledeferred
self.assertEqual(cache.get("key2"), "result2")
# now do the invalidation
......@@ -146,6 +150,28 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()
def test_cache_with_sync_exception(self):
"""If the wrapped function throws synchronously, things should continue to work
"""
class Cls(object):
@cached()
def fn(self, arg1):
raise SynapseError(100, "mai spoon iz too big!!1")
obj = Cls()
# this should fail immediately
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
# ... leaving the cache empty
self.assertEqual(len(obj.fn.cache.cache), 0)
# and a second call should result in a second exception
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
def test_cache_logcontexts(self):
"""Check that logcontexts are set and restored correctly when
using the cache."""
......@@ -222,6 +248,9 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(LoggingContext.current_context(), c1)
# the cache should now be empty
self.assertEqual(len(obj.fn.cache.cache), 0)
obj = Cls()
# set off a deferred which will do a cache lookup
......@@ -268,6 +297,61 @@ class DescriptorTestCase(unittest.TestCase):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()
def test_cache_iterable(self):
class Cls(object):
def __init__(self):
self.mock = mock.Mock()
@descriptors.cached(iterable=True)
def fn(self, arg1, arg2):
return self.mock(arg1, arg2)
obj = Cls()
obj.mock.return_value = ["spam", "eggs"]
r = obj.fn(1, 2)
self.assertEqual(r, ["spam", "eggs"])
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = ["chips"]
r = obj.fn(1, 3)
self.assertEqual(r, ["chips"])
obj.mock.assert_called_once_with(1, 3)
obj.mock.reset_mock()
# the two values should now be cached
self.assertEqual(len(obj.fn.cache.cache), 3)
r = obj.fn(1, 2)
self.assertEqual(r, ["spam", "eggs"])
r = obj.fn(1, 3)
self.assertEqual(r, ["chips"])
obj.mock.assert_not_called()
def test_cache_iterable_with_sync_exception(self):
"""If the wrapped function throws synchronously, things should continue to work
"""
class Cls(object):
@descriptors.cached(iterable=True)
def fn(self, arg1):
raise SynapseError(100, "mai spoon iz too big!!1")
obj = Cls()
# this should fail immediately
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
# ... leaving the cache empty
self.assertEqual(len(obj.fn.cache.cache), 0)
# and a second call should result in a second exception
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment