Skip to content
Snippets Groups Projects
Commit 1f426970 authored by Richard van der Hoff's avatar Richard van der Hoff
Browse files

Push some deferred wrangling down into DeferredCache

parent 7b716953
No related branches found
No related tags found
No related merge requests found
Modify `DeferredCache.get()` to return `Deferred`s instead of `ObservableDeferred`s.
......@@ -57,7 +57,7 @@ class DeferredCache(Generic[KT, VT]):
"""Wraps an LruCache, adding support for Deferred results.
It expects that each entry added with set() will be a Deferred; likewise get()
may return an ObservableDeferred.
will return a Deferred.
"""
__slots__ = (
......@@ -130,16 +130,22 @@ class DeferredCache(Generic[KT, VT]):
key: KT,
callback: Optional[Callable[[], None]] = None,
update_metrics: bool = True,
) -> Union[ObservableDeferred, VT]:
) -> defer.Deferred:
"""Looks the key up in the caches.
For symmetry with set(), this method does *not* follow the synapse logcontext
rules: the logcontext will not be cleared on return, and the Deferred will run
its callbacks in the sentinel context. In other words: wrap the result with
make_deferred_yieldable() before `await`ing it.
Args:
key(tuple)
callback(fn): Gets called when the entry in the cache is invalidated
key:
callback: Gets called when the entry in the cache is invalidated
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
Either an ObservableDeferred or the result itself
A Deferred which completes with the result. Note that this may later fail
if there is an ongoing set() operation which later completes with a failure.
Raises:
KeyError if the key is not found in the cache
......@@ -152,7 +158,7 @@ class DeferredCache(Generic[KT, VT]):
m = self.cache.metrics
assert m # we always have a name, so should always have metrics
m.inc_hits()
return val.deferred
return val.deferred.observe()
val2 = self.cache.get(
key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
......@@ -160,7 +166,7 @@ class DeferredCache(Generic[KT, VT]):
if val2 is _Sentinel.sentinel:
raise KeyError()
else:
return val2
return defer.succeed(val2)
def get_immediate(
self, key: KT, default: T, update_metrics: bool = True
......@@ -173,7 +179,36 @@ class DeferredCache(Generic[KT, VT]):
key: KT,
value: defer.Deferred,
callback: Optional[Callable[[], None]] = None,
) -> ObservableDeferred:
) -> defer.Deferred:
"""Adds a new entry to the cache (or updates an existing one).
The given `value` *must* be a Deferred.
First any existing entry for the same key is invalidated. Then a new entry
is added to the cache for the given key.
Until the `value` completes, calls to `get()` for the key will also result in an
incomplete Deferred, which will ultimately complete with the same result as
`value`.
If `value` completes successfully, subsequent calls to `get()` will then return
a completed deferred with the same result. If it *fails*, the cache is
invalidated and subequent calls to `get()` will raise a KeyError.
If another call to `set()` happens before `value` completes, then (a) any
invalidation callbacks registered in the interim will be called, (b) any
`get()`s in the interim will continue to complete with the result from the
*original* `value`, (c) any future calls to `get()` will complete with the
result from the *new* `value`.
It is expected that `value` does *not* follow the synapse logcontext rules - ie,
if it is incomplete, it runs its callbacks in the sentinel context.
Args:
key: Key to be set
value: a deferred which will complete with a result to add to the cache
callback: An optional callback to be called when the entry is invalidated
"""
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
......@@ -187,6 +222,8 @@ class DeferredCache(Generic[KT, VT]):
if existing_entry:
existing_entry.invalidate()
# XXX: why don't we invalidate the entry in `self.cache` yet?
self._pending_deferred_cache[key] = entry
def compare_and_pop():
......@@ -230,7 +267,9 @@ class DeferredCache(Generic[KT, VT]):
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
return observable
# we return a new Deferred which will be called before any subsequent observers.
return observable.observe()
def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
callbacks = [callback] if callback else []
......
......@@ -23,7 +23,6 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.deferred_cache import DeferredCache
logger = logging.getLogger(__name__)
......@@ -156,7 +155,7 @@ class CacheDescriptor(_CacheDescriptorBase):
keylen=self.num_args,
tree=self.tree,
iterable=self.iterable,
) # type: DeferredCache[Tuple, Any]
) # type: DeferredCache[CacheKey, Any]
def get_cache_key_gen(args, kwargs):
"""Given some args/kwargs return a generator that resolves into
......@@ -208,26 +207,12 @@ class CacheDescriptor(_CacheDescriptorBase):
kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
try:
cached_result_d = cache.get(cache_key, callback=invalidate_callback)
if isinstance(cached_result_d, ObservableDeferred):
observer = cached_result_d.observe()
else:
observer = defer.succeed(cached_result_d)
ret = cache.get(cache_key, callback=invalidate_callback)
except KeyError:
ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
ret = cache.set(cache_key, ret, callback=invalidate_callback)
def onErr(f):
cache.invalidate(cache_key)
return f
ret.addErrback(onErr)
result_d = cache.set(cache_key, ret, callback=invalidate_callback)
observer = result_d.observe()
return make_deferred_yieldable(observer)
return make_deferred_yieldable(ret)
wrapped = cast(_CachedFunction, _wrapped)
......@@ -286,7 +271,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
def __get__(self, obj, objtype=None):
cached_method = getattr(obj, self.cached_method_name)
cache = cached_method.cache
cache = cached_method.cache # type: DeferredCache[CacheKey, Any]
num_args = cached_method.num_args
@functools.wraps(self.orig)
......@@ -326,14 +311,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
for arg in list_args:
try:
res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
if not isinstance(res, ObservableDeferred):
results[arg] = res
elif not res.has_succeeded():
res = res.observe()
if not res.called:
res.addCallback(update_results_dict, arg)
cached_defers.append(res)
else:
results[arg] = res.get_result()
results[arg] = res.result
except KeyError:
missing.add(arg)
......
......@@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from twisted.internet import defer
from synapse.util.caches.deferred_cache import DeferredCache
from tests.unittest import TestCase
class DeferredCacheTestCase(unittest.TestCase):
class DeferredCacheTestCase(TestCase):
def test_empty(self):
cache = DeferredCache("test")
failed = False
......@@ -36,7 +37,7 @@ class DeferredCacheTestCase(unittest.TestCase):
cache = DeferredCache("test")
cache.prefill("foo", 123)
self.assertEquals(cache.get("foo"), 123)
self.assertEquals(self.successResultOf(cache.get("foo")), 123)
def test_get_immediate(self):
cache = DeferredCache("test")
......@@ -82,16 +83,15 @@ class DeferredCacheTestCase(unittest.TestCase):
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
# lookup should return observable deferreds
self.assertFalse(cache.get("key1").has_called())
self.assertFalse(cache.get("key2").has_called())
# lookup should return pending deferreds
self.assertFalse(cache.get("key1").called)
self.assertFalse(cache.get("key2").called)
# let one of the lookups complete
d2.callback("result2")
# for now at least, the cache will return real results rather than an
# observabledeferred
self.assertEqual(cache.get("key2"), "result2")
# now the cache will return a completed deferred
self.assertEqual(self.successResultOf(cache.get("key2")), "result2")
# now do the invalidation
cache.invalidate_all()
......
......@@ -27,7 +27,6 @@ from synapse.logging.context import (
current_context,
make_deferred_yieldable,
)
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import descriptors
from synapse.util.caches.descriptors import cached
......@@ -419,9 +418,9 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
a = A()
a.func.prefill(("foo",), ObservableDeferred(d))
a.func.prefill(("foo",), 456)
self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(a.func("foo").result, 456)
self.assertEquals(callcount[0], 0)
@defer.inlineCallbacks
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment