Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
from typing import Dict, List, Sequence, Tuple
from unittest.mock import patch
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.storage.state import StateFilter
from synapse.types import MutableStateMap, StateMap
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
if typing.TYPE_CHECKING:
from synapse.server import HomeServer
class StateGroupInflightCachingTestCase(HomeserverTestCase):
    """
    Exercises `_get_state_for_group_using_inflight_cache` on the state datastore.

    The datastore's `_get_state_groups_from_groups` is swapped for a fake that
    records every call and returns an unfired Deferred, so each test decides
    exactly when the "database" answers.
    """

    def prepare(
        self, reactor: MemoryReactor, clock: Clock, homeserver: "HomeServer"
    ) -> None:
        self.state_storage = homeserver.get_storage().state
        self.state_datastore = homeserver.get_datastores().state

        # Log of intercepted database calls, one entry per call:
        # (requested groups, state filter, Deferred to fire with the result).
        self.get_state_group_calls: List[
            Tuple[Tuple[int, ...], StateFilter, Deferred[Dict[int, StateMap[str]]]]
        ] = []

        # Replace `_get_state_groups_from_groups` with our fake.
        # The fake never answers on its own, which lets us simulate a slow
        # database and keep a request in flight for as long as we like.
        database_patcher = patch.object(
            self.state_datastore,
            "_get_state_groups_from_groups",
            new=self._fake_get_state_groups_from_groups,
        )
        database_patcher.start()
        self.addCleanup(database_patcher.stop)

    def _fake_get_state_groups_from_groups(
        self, groups: Sequence[int], state_filter: StateFilter
    ) -> "Deferred[Dict[int, StateMap[str]]]":
        """
        Stand-in for the datastore's `_get_state_groups_from_groups`.

        Records the call in `self.get_state_group_calls` and returns a fresh
        Deferred which is left unfired; the test resolves it later.
        """
        pending: Deferred[Dict[int, StateMap[str]]] = Deferred()
        call_record = (tuple(groups), state_filter, pending)
        self.get_state_group_calls.append(call_record)
        return pending

    def _complete_request_fake(
        self,
        groups: Tuple[int, ...],
        state_filter: StateFilter,
        d: "Deferred[Dict[int, StateMap[str]]]",
    ) -> None:
        """
        Build a made-up database response for `groups` and fire `d` with it.

        For each event type in the filter: a wildcard entry (no state keys
        given) yields the state keys "a" and "b" mapped to "xyz", whereas
        explicitly listed state keys are each mapped to "abc".
        If the filter includes other event types, one extra
        ("other.event.type", "state.key") entry mapped to "123" is added.
        """
        response: Dict[int, StateMap[str]] = {}

        for group_id in groups:
            # Every group gets its own, separate state map.
            fake_state: MutableStateMap[str] = {}

            for event_type, wanted_keys in state_filter.types.items():
                if wanted_keys is not None:
                    fake_state.update(
                        ((event_type, wanted_key), "abc")
                        for wanted_key in wanted_keys
                    )
                    continue

                # Wildcard: invent a couple of state keys for this type.
                fake_state[(event_type, "a")] = "xyz"
                fake_state[(event_type, "b")] = "xyz"

            if state_filter.include_others:
                fake_state[("other.event.type", "state.key")] = "123"

            response[group_id] = fake_state

        d.callback(response)

    def test_duplicate_requests_deduplicated(self) -> None:
        """
        Two identical, overlapping state requests must share one database query.

        Steps:
        - ask for state group 42 with the 'all' state filter
        - ask for exactly the same thing while the first request is pending
        - verify that the database was queried only once
        - let the database query finish
        - verify that both requests resolve to the same retrieved state
        """
        first_request = ensureDeferred(
            self.state_datastore._get_state_for_group_using_inflight_cache(
                42, StateFilter.all()
            )
        )
        self.pump(by=0.1)

        # The first request has to hit the database, and must still be waiting.
        self.assertEqual(len(self.get_state_group_calls), 1)
        self.assertFalse(first_request.called)

        second_request = ensureDeferred(
            self.state_datastore._get_state_for_group_using_inflight_cache(
                42, StateFilter.all()
            )
        )
        self.pump(by=0.1)

        # The duplicate must not have triggered another database query,
        # and neither request can have finished yet.
        self.assertEqual(len(self.get_state_group_calls), 1)
        self.assertFalse(first_request.called)
        self.assertFalse(second_request.called)

        (queried_groups, queried_filter, database_deferred) = (
            self.get_state_group_calls[0]
        )
        self.assertEqual(queried_groups, (42,))
        self.assertEqual(queried_filter, StateFilter.all())

        # Let the single database query complete.
        self._complete_request_fake(
            queried_groups, queried_filter, database_deferred
        )

        # Both callers should now see the very same state.
        for request in (first_request, second_request):
            self.assertEqual(
                self.get_success(request),
                {("other.event.type", "state.key"): "123"},
            )