Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
synapse
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Timo Ley
synapse
Commits
1251d017
Commit
1251d017
authored
10 years ago
by
Mark Haines
Browse files
Options
Downloads
Plain Diff
Merge pull request #38 from matrix-org/new_state_resolution
New state resolution
parents
bd03947c
3d7026e7
No related branches found
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
synapse/state.py
+84
-50
84 additions, 50 deletions
synapse/state.py
tests/test_state.py
+357
-71
357 additions, 71 deletions
tests/test_state.py
with
441 additions
and
121 deletions
synapse/state.py
+
84
−
50
View file @
1251d017
...
@@ -19,6 +19,7 @@ from twisted.internet import defer
...
@@ -19,6 +19,7 @@ from twisted.internet import defer
from
synapse.util.logutils
import
log_function
from
synapse.util.logutils
import
log_function
from
synapse.util.async
import
run_on_reactor
from
synapse.util.async
import
run_on_reactor
from
synapse.api.constants
import
EventTypes
from
synapse.api.constants
import
EventTypes
from
synapse.api.errors
import
AuthError
from
synapse.events.snapshot
import
EventContext
from
synapse.events.snapshot
import
EventContext
from
collections
import
namedtuple
from
collections
import
namedtuple
...
@@ -36,12 +37,16 @@ def _get_state_key_from_event(event):
...
@@ -36,12 +37,16 @@ def _get_state_key_from_event(event):
KeyStateTuple
=
namedtuple
(
"
KeyStateTuple
"
,
(
"
context
"
,
"
type
"
,
"
state_key
"
))
KeyStateTuple
=
namedtuple
(
"
KeyStateTuple
"
,
(
"
context
"
,
"
type
"
,
"
state_key
"
))
AuthEventTypes
=
(
EventTypes
.
Create
,
EventTypes
.
Member
,
EventTypes
.
PowerLevels
,)
class
StateHandler
(
object
):
class
StateHandler
(
object
):
"""
Responsible for doing state conflict resolution.
"""
Responsible for doing state conflict resolution.
"""
"""
def
__init__
(
self
,
hs
):
def
__init__
(
self
,
hs
):
self
.
store
=
hs
.
get_datastore
()
self
.
store
=
hs
.
get_datastore
()
self
.
hs
=
hs
@defer.inlineCallbacks
@defer.inlineCallbacks
def
get_current_state
(
self
,
room_id
,
event_type
=
None
,
state_key
=
""
):
def
get_current_state
(
self
,
room_id
,
event_type
=
None
,
state_key
=
""
):
...
@@ -210,64 +215,93 @@ class StateHandler(object):
...
@@ -210,64 +215,93 @@ class StateHandler(object):
else
:
else
:
prev_states
=
[]
prev_states
=
[]
auth_events
=
{
k
:
e
for
k
,
e
in
unconflicted_state
.
items
()
if
k
[
0
]
in
AuthEventTypes
}
try
:
try
:
new_state
=
{}
resolved_state
=
self
.
_resolve_state_events
(
new_state
.
update
(
unconflicted_state
)
conflicted_state
,
auth_events
for
key
,
events
in
conflicted_state
.
items
():
)
new_state
[
key
]
=
self
.
_resolve_state_events
(
events
)
except
:
except
:
logger
.
exception
(
"
Failed to resolve state
"
)
logger
.
exception
(
"
Failed to resolve state
"
)
raise
raise
defer
.
returnValue
((
None
,
new_state
,
prev_states
))
new_state
=
unconflicted_state
new_state
.
update
(
resolved_state
)
def
_get_power_level_from_event_state
(
self
,
event
,
user_id
):
if
hasattr
(
event
,
"
old_state_events
"
)
and
event
.
old_state_events
:
key
=
(
EventTypes
.
PowerLevels
,
""
,
)
power_level_event
=
event
.
old_state_events
.
get
(
key
)
level
=
None
if
power_level_event
:
level
=
power_level_event
.
content
.
get
(
"
users
"
,
{}).
get
(
user_id
)
if
not
level
:
level
=
power_level_event
.
content
.
get
(
"
users_default
"
,
0
)
return
level
defer
.
returnValue
((
None
,
new_state
,
prev_states
))
else
:
return
0
@log_function
@log_function
def
_resolve_state_events
(
self
,
events
):
def
_resolve_state_events
(
self
,
conflicted_state
,
auth_events
):
curr_events
=
events
"""
This is where we actually decide which of the conflicted state to
use.
new_powers
=
[
self
.
_get_power_level_from_event_state
(
e
,
e
.
user_id
)
We resolve conflicts in the following order:
for
e
in
curr_events
1. power levels
]
2. memberships
3. other events.
new_powers
=
[
"""
int
(
p
)
if
p
else
0
for
p
in
new_powers
resolved_state
=
{}
]
power_key
=
(
EventTypes
.
PowerLevels
,
""
)
if
power_key
in
conflicted_state
.
items
():
power_levels
=
conflicted_state
[
power_key
]
resolved_state
[
power_key
]
=
self
.
_resolve_auth_events
(
power_levels
)
auth_events
.
update
(
resolved_state
)
for
key
,
events
in
conflicted_state
.
items
():
if
key
[
0
]
==
EventTypes
.
Member
:
resolved_state
[
key
]
=
self
.
_resolve_auth_events
(
events
,
auth_events
)
max_power
=
max
(
new_powers
)
auth_events
.
update
(
resolved_state
)
curr_events
=
[
for
key
,
events
in
conflicted_state
.
items
():
z
[
0
]
for
z
in
zip
(
curr_events
,
new_powers
)
if
key
not
in
resolved_state
:
if
z
[
1
]
==
max_power
resolved_state
[
key
]
=
self
.
_resolve_normal_events
(
]
events
,
auth_events
)
if
not
curr_events
:
return
resolved_state
raise
RuntimeError
(
"
Max didn
'
t get a max?
"
)
elif
len
(
curr_events
)
==
1
:
def
_resolve_auth_events
(
self
,
events
,
auth_events
):
return
curr_events
[
0
]
reverse
=
[
i
for
i
in
reversed
(
self
.
_ordered_events
(
events
))]
# TODO: For now, just choose the one with the largest event_id.
auth_events
=
dict
(
auth_events
)
return
(
sorted
(
prev_event
=
reverse
[
0
]
curr_events
,
for
event
in
reverse
[
1
:]:
key
=
lambda
e
:
hashlib
.
sha1
(
auth_events
[(
prev_event
.
type
,
prev_event
.
state_key
)]
=
prev_event
e
.
event_id
+
e
.
user_id
+
e
.
room_id
+
e
.
type
try
:
).
hexdigest
()
# FIXME: hs.get_auth() is bad style, but we need to do it to
)[
0
]
# get around circular deps.
)
self
.
hs
.
get_auth
().
check
(
event
,
auth_events
)
prev_event
=
event
except
AuthError
:
return
prev_event
return
event
def
_resolve_normal_events
(
self
,
events
,
auth_events
):
for
event
in
self
.
_ordered_events
(
events
):
try
:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
self
.
hs
.
get_auth
().
check
(
event
,
auth_events
)
return
event
except
AuthError
:
pass
# Use the last event (the one with the least depth) if they all fail
# the auth check.
return
event
def
_ordered_events
(
self
,
events
):
def
key_func
(
e
):
return
-
int
(
e
.
depth
),
hashlib
.
sha1
(
e
.
event_id
).
hexdigest
()
return
sorted
(
events
,
key
=
key_func
)
This diff is collapsed.
Click to expand it.
tests/test_state.py
+
357
−
71
View file @
1251d017
...
@@ -16,11 +16,120 @@
...
@@ -16,11 +16,120 @@
from
tests
import
unittest
from
tests
import
unittest
from
twisted.internet
import
defer
from
twisted.internet
import
defer
from
synapse.events
import
FrozenEvent
from
synapse.api.auth
import
Auth
from
synapse.api.constants
import
EventTypes
,
Membership
from
synapse.state
import
StateHandler
from
synapse.state
import
StateHandler
from
mock
import
Mock
from
mock
import
Mock
_next_event_id
=
1000
def
create_event
(
name
=
None
,
type
=
None
,
state_key
=
None
,
depth
=
2
,
event_id
=
None
,
prev_events
=
[],
**
kwargs
):
global
_next_event_id
if
not
event_id
:
_next_event_id
+=
1
event_id
=
str
(
_next_event_id
)
if
not
name
:
if
state_key
is
not
None
:
name
=
"
<%s-%s, %s>
"
%
(
type
,
state_key
,
event_id
,)
else
:
name
=
"
<%s, %s>
"
%
(
type
,
event_id
,)
d
=
{
"
event_id
"
:
event_id
,
"
type
"
:
type
,
"
sender
"
:
"
@user_id:example.com
"
,
"
room_id
"
:
"
!room_id:example.com
"
,
"
depth
"
:
depth
,
"
prev_events
"
:
prev_events
,
}
if
state_key
is
not
None
:
d
[
"
state_key
"
]
=
state_key
d
.
update
(
kwargs
)
event
=
FrozenEvent
(
d
)
return
event
class
StateGroupStore
(
object
):
def
__init__
(
self
):
self
.
_event_to_state_group
=
{}
self
.
_group_to_state
=
{}
self
.
_next_group
=
1
def
get_state_groups
(
self
,
event_ids
):
groups
=
{}
for
event_id
in
event_ids
:
group
=
self
.
_event_to_state_group
.
get
(
event_id
)
if
group
:
groups
[
group
]
=
self
.
_group_to_state
[
group
]
return
defer
.
succeed
(
groups
)
def
store_state_groups
(
self
,
event
,
context
):
if
context
.
current_state
is
None
:
return
state_events
=
context
.
current_state
if
event
.
is_state
():
state_events
[(
event
.
type
,
event
.
state_key
)]
=
event
state_group
=
context
.
state_group
if
not
state_group
:
state_group
=
self
.
_next_group
self
.
_next_group
+=
1
self
.
_group_to_state
[
state_group
]
=
state_events
.
values
()
self
.
_event_to_state_group
[
event
.
event_id
]
=
state_group
class
DictObj
(
dict
):
def
__init__
(
self
,
**
kwargs
):
super
(
DictObj
,
self
).
__init__
(
kwargs
)
self
.
__dict__
=
self
class
Graph
(
object
):
def
__init__
(
self
,
nodes
,
edges
):
events
=
{}
clobbered
=
set
(
events
.
keys
())
for
event_id
,
fields
in
nodes
.
items
():
refs
=
edges
.
get
(
event_id
)
if
refs
:
clobbered
.
difference_update
(
refs
)
prev_events
=
[(
r
,
{})
for
r
in
refs
]
else
:
prev_events
=
[]
events
[
event_id
]
=
create_event
(
event_id
=
event_id
,
prev_events
=
prev_events
,
**
fields
)
self
.
_leaves
=
clobbered
self
.
_events
=
sorted
(
events
.
values
(),
key
=
lambda
e
:
e
.
depth
)
def
walk
(
self
):
return
iter
(
self
.
_events
)
def
get_leaves
(
self
):
return
(
self
.
_events
[
i
]
for
i
in
self
.
_leaves
)
class
StateTestCase
(
unittest
.
TestCase
):
class
StateTestCase
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
store
=
Mock
(
self
.
store
=
Mock
(
...
@@ -29,20 +138,188 @@ class StateTestCase(unittest.TestCase):
...
@@ -29,20 +138,188 @@ class StateTestCase(unittest.TestCase):
"
add_event_hashes
"
,
"
add_event_hashes
"
,
]
]
)
)
hs
=
Mock
(
spec
=
[
"
get_datastore
"
])
hs
=
Mock
(
spec
=
[
"
get_datastore
"
,
"
get_auth
"
,
"
get_state_handler
"
])
hs
.
get_datastore
.
return_value
=
self
.
store
hs
.
get_datastore
.
return_value
=
self
.
store
hs
.
get_state_handler
.
return_value
=
None
hs
.
get_auth
.
return_value
=
Auth
(
hs
)
self
.
state
=
StateHandler
(
hs
)
self
.
state
=
StateHandler
(
hs
)
self
.
event_id
=
0
self
.
event_id
=
0
@defer.inlineCallbacks
def
test_branch_no_conflict
(
self
):
graph
=
Graph
(
nodes
=
{
"
START
"
:
DictObj
(
type
=
EventTypes
.
Create
,
state_key
=
""
,
depth
=
1
,
),
"
A
"
:
DictObj
(
type
=
EventTypes
.
Message
,
depth
=
2
,
),
"
B
"
:
DictObj
(
type
=
EventTypes
.
Message
,
depth
=
3
,
),
"
C
"
:
DictObj
(
type
=
EventTypes
.
Name
,
state_key
=
""
,
depth
=
3
,
),
"
D
"
:
DictObj
(
type
=
EventTypes
.
Message
,
depth
=
4
,
),
},
edges
=
{
"
A
"
:
[
"
START
"
],
"
B
"
:
[
"
A
"
],
"
C
"
:
[
"
A
"
],
"
D
"
:
[
"
B
"
,
"
C
"
]
}
)
store
=
StateGroupStore
()
self
.
store
.
get_state_groups
.
side_effect
=
store
.
get_state_groups
context_store
=
{}
for
event
in
graph
.
walk
():
context
=
yield
self
.
state
.
compute_event_context
(
event
)
store
.
store_state_groups
(
event
,
context
)
context_store
[
event
.
event_id
]
=
context
self
.
assertEqual
(
2
,
len
(
context_store
[
"
D
"
].
current_state
))
@defer.inlineCallbacks
def
test_branch_basic_conflict
(
self
):
graph
=
Graph
(
nodes
=
{
"
START
"
:
DictObj
(
type
=
EventTypes
.
Create
,
state_key
=
"
creator
"
,
content
=
{
"
membership
"
:
"
@user_id:example.com
"
},
depth
=
1
,
),
"
A
"
:
DictObj
(
type
=
EventTypes
.
Member
,
state_key
=
"
@user_id:example.com
"
,
content
=
{
"
membership
"
:
Membership
.
JOIN
},
membership
=
Membership
.
JOIN
,
depth
=
2
,
),
"
B
"
:
DictObj
(
type
=
EventTypes
.
Name
,
state_key
=
""
,
depth
=
3
,
),
"
C
"
:
DictObj
(
type
=
EventTypes
.
Name
,
state_key
=
""
,
depth
=
4
,
),
"
D
"
:
DictObj
(
type
=
EventTypes
.
Message
,
depth
=
5
,
),
},
edges
=
{
"
A
"
:
[
"
START
"
],
"
B
"
:
[
"
A
"
],
"
C
"
:
[
"
A
"
],
"
D
"
:
[
"
B
"
,
"
C
"
]
}
)
store
=
StateGroupStore
()
self
.
store
.
get_state_groups
.
side_effect
=
store
.
get_state_groups
context_store
=
{}
for
event
in
graph
.
walk
():
context
=
yield
self
.
state
.
compute_event_context
(
event
)
store
.
store_state_groups
(
event
,
context
)
context_store
[
event
.
event_id
]
=
context
self
.
assertSetEqual
(
{
"
START
"
,
"
A
"
,
"
C
"
},
{
e
.
event_id
for
e
in
context_store
[
"
D
"
].
current_state
.
values
()}
)
@defer.inlineCallbacks
def
test_branch_have_banned_conflict
(
self
):
graph
=
Graph
(
nodes
=
{
"
START
"
:
DictObj
(
type
=
EventTypes
.
Create
,
state_key
=
"
creator
"
,
content
=
{
"
membership
"
:
"
@user_id:example.com
"
},
depth
=
1
,
),
"
A
"
:
DictObj
(
type
=
EventTypes
.
Member
,
state_key
=
"
@user_id:example.com
"
,
content
=
{
"
membership
"
:
Membership
.
JOIN
},
membership
=
Membership
.
JOIN
,
depth
=
2
,
),
"
B
"
:
DictObj
(
type
=
EventTypes
.
Name
,
state_key
=
""
,
depth
=
3
,
),
"
C
"
:
DictObj
(
type
=
EventTypes
.
Member
,
state_key
=
"
@user_id_2:example.com
"
,
content
=
{
"
membership
"
:
Membership
.
BAN
},
membership
=
Membership
.
BAN
,
depth
=
4
,
),
"
D
"
:
DictObj
(
type
=
EventTypes
.
Name
,
state_key
=
""
,
depth
=
4
,
sender
=
"
@user_id_2:example.com
"
,
),
"
E
"
:
DictObj
(
type
=
EventTypes
.
Message
,
depth
=
5
,
),
},
edges
=
{
"
A
"
:
[
"
START
"
],
"
B
"
:
[
"
A
"
],
"
C
"
:
[
"
B
"
],
"
D
"
:
[
"
B
"
],
"
E
"
:
[
"
C
"
,
"
D
"
]
}
)
store
=
StateGroupStore
()
self
.
store
.
get_state_groups
.
side_effect
=
store
.
get_state_groups
context_store
=
{}
for
event
in
graph
.
walk
():
context
=
yield
self
.
state
.
compute_event_context
(
event
)
store
.
store_state_groups
(
event
,
context
)
context_store
[
event
.
event_id
]
=
context
self
.
assertSetEqual
(
{
"
START
"
,
"
A
"
,
"
B
"
,
"
C
"
},
{
e
.
event_id
for
e
in
context_store
[
"
E
"
].
current_state
.
values
()}
)
@defer.inlineCallbacks
@defer.inlineCallbacks
def
test_annotate_with_old_message
(
self
):
def
test_annotate_with_old_message
(
self
):
event
=
self
.
create_event
(
type
=
"
test_message
"
,
name
=
"
event
"
)
event
=
create_event
(
type
=
"
test_message
"
,
name
=
"
event
"
)
old_state
=
[
old_state
=
[
self
.
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
),
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
),
self
.
create_event
(
type
=
"
test1
"
,
state_key
=
"
2
"
),
create_event
(
type
=
"
test1
"
,
state_key
=
"
2
"
),
self
.
create_event
(
type
=
"
test2
"
,
state_key
=
""
),
create_event
(
type
=
"
test2
"
,
state_key
=
""
),
]
]
context
=
yield
self
.
state
.
compute_event_context
(
context
=
yield
self
.
state
.
compute_event_context
(
...
@@ -62,12 +339,12 @@ class StateTestCase(unittest.TestCase):
...
@@ -62,12 +339,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
@defer.inlineCallbacks
def
test_annotate_with_old_state
(
self
):
def
test_annotate_with_old_state
(
self
):
event
=
self
.
create_event
(
type
=
"
state
"
,
state_key
=
""
,
name
=
"
event
"
)
event
=
create_event
(
type
=
"
state
"
,
state_key
=
""
,
name
=
"
event
"
)
old_state
=
[
old_state
=
[
self
.
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
),
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
),
self
.
create_event
(
type
=
"
test1
"
,
state_key
=
"
2
"
),
create_event
(
type
=
"
test1
"
,
state_key
=
"
2
"
),
self
.
create_event
(
type
=
"
test2
"
,
state_key
=
""
),
create_event
(
type
=
"
test2
"
,
state_key
=
""
),
]
]
context
=
yield
self
.
state
.
compute_event_context
(
context
=
yield
self
.
state
.
compute_event_context
(
...
@@ -88,13 +365,12 @@ class StateTestCase(unittest.TestCase):
...
@@ -88,13 +365,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
@defer.inlineCallbacks
def
test_trivial_annotate_message
(
self
):
def
test_trivial_annotate_message
(
self
):
event
=
self
.
create_event
(
type
=
"
test_message
"
,
name
=
"
event
"
)
event
=
create_event
(
type
=
"
test_message
"
,
name
=
"
event
"
)
event
.
prev_events
=
[]
old_state
=
[
old_state
=
[
self
.
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
),
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
),
self
.
create_event
(
type
=
"
test1
"
,
state_key
=
"
2
"
),
create_event
(
type
=
"
test1
"
,
state_key
=
"
2
"
),
self
.
create_event
(
type
=
"
test2
"
,
state_key
=
""
),
create_event
(
type
=
"
test2
"
,
state_key
=
""
),
]
]
group_name
=
"
group_name_1
"
group_name
=
"
group_name_1
"
...
@@ -119,13 +395,12 @@ class StateTestCase(unittest.TestCase):
...
@@ -119,13 +395,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
@defer.inlineCallbacks
def
test_trivial_annotate_state
(
self
):
def
test_trivial_annotate_state
(
self
):
event
=
self
.
create_event
(
type
=
"
state
"
,
state_key
=
""
,
name
=
"
event
"
)
event
=
create_event
(
type
=
"
state
"
,
state_key
=
""
,
name
=
"
event
"
)
event
.
prev_events
=
[]
old_state
=
[
old_state
=
[
self
.
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
),
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
),
self
.
create_event
(
type
=
"
test1
"
,
state_key
=
"
2
"
),
create_event
(
type
=
"
test1
"
,
state_key
=
"
2
"
),
self
.
create_event
(
type
=
"
test2
"
,
state_key
=
""
),
create_event
(
type
=
"
test2
"
,
state_key
=
""
),
]
]
group_name
=
"
group_name_1
"
group_name
=
"
group_name_1
"
...
@@ -150,30 +425,21 @@ class StateTestCase(unittest.TestCase):
...
@@ -150,30 +425,21 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
@defer.inlineCallbacks
def
test_resolve_message_conflict
(
self
):
def
test_resolve_message_conflict
(
self
):
event
=
self
.
create_event
(
type
=
"
test_message
"
,
name
=
"
event
"
)
event
=
create_event
(
type
=
"
test_message
"
,
name
=
"
event
"
)
event
.
prev_events
=
[]
old_state_1
=
[
old_state_1
=
[
self
.
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
),
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
),
self
.
create_event
(
type
=
"
test1
"
,
state_key
=
"
2
"
),
create_event
(
type
=
"
test1
"
,
state_key
=
"
2
"
),
self
.
create_event
(
type
=
"
test2
"
,
state_key
=
""
),
create_event
(
type
=
"
test2
"
,
state_key
=
""
),
]
]
old_state_2
=
[
old_state_2
=
[
self
.
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
),
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
),
self
.
create_event
(
type
=
"
test3
"
,
state_key
=
"
2
"
),
create_event
(
type
=
"
test3
"
,
state_key
=
"
2
"
),
self
.
create_event
(
type
=
"
test4
"
,
state_key
=
""
),
create_event
(
type
=
"
test4
"
,
state_key
=
""
),
]
]
group_name_1
=
"
group_name_1
"
context
=
yield
self
.
_get_context
(
event
,
old_state_1
,
old_state_2
)
group_name_2
=
"
group_name_2
"
self
.
store
.
get_state_groups
.
return_value
=
{
group_name_1
:
old_state_1
,
group_name_2
:
old_state_2
,
}
context
=
yield
self
.
state
.
compute_event_context
(
event
)
self
.
assertEqual
(
len
(
context
.
current_state
),
5
)
self
.
assertEqual
(
len
(
context
.
current_state
),
5
)
...
@@ -181,56 +447,76 @@ class StateTestCase(unittest.TestCase):
...
@@ -181,56 +447,76 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
@defer.inlineCallbacks
def
test_resolve_state_conflict
(
self
):
def
test_resolve_state_conflict
(
self
):
event
=
self
.
create_event
(
type
=
"
test4
"
,
state_key
=
""
,
name
=
"
event
"
)
event
=
create_event
(
type
=
"
test4
"
,
state_key
=
""
,
name
=
"
event
"
)
event
.
prev_events
=
[]
old_state_1
=
[
old_state_1
=
[
self
.
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
),
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
),
self
.
create_event
(
type
=
"
test1
"
,
state_key
=
"
2
"
),
create_event
(
type
=
"
test1
"
,
state_key
=
"
2
"
),
self
.
create_event
(
type
=
"
test2
"
,
state_key
=
""
),
create_event
(
type
=
"
test2
"
,
state_key
=
""
),
]
]
old_state_2
=
[
old_state_2
=
[
self
.
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
),
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
),
self
.
create_event
(
type
=
"
test3
"
,
state_key
=
"
2
"
),
create_event
(
type
=
"
test3
"
,
state_key
=
"
2
"
),
self
.
create_event
(
type
=
"
test4
"
,
state_key
=
""
),
create_event
(
type
=
"
test4
"
,
state_key
=
""
),
]
]
group_name_1
=
"
group_name_1
"
context
=
yield
self
.
_get_context
(
event
,
old_state_1
,
old_state_2
)
group_name_2
=
"
group_name_2
"
self
.
store
.
get_state_groups
.
return_value
=
{
group_name_1
:
old_state_1
,
group_name_2
:
old_state_2
,
}
context
=
yield
self
.
state
.
compute_event_context
(
event
)
self
.
assertEqual
(
len
(
context
.
current_state
),
5
)
self
.
assertEqual
(
len
(
context
.
current_state
),
5
)
self
.
assertIsNone
(
context
.
state_group
)
self
.
assertIsNone
(
context
.
state_group
)
def
create_event
(
self
,
name
=
None
,
type
=
None
,
state_key
=
None
):
@defer.inlineCallbacks
self
.
event_id
+=
1
def
test_standard_depth_conflict
(
self
):
event_id
=
str
(
self
.
event_id
)
event
=
create_event
(
type
=
"
test4
"
,
name
=
"
event
"
)
member_event
=
create_event
(
type
=
EventTypes
.
Member
,
state_key
=
"
@user_id:example.com
"
,
content
=
{
"
membership
"
:
Membership
.
JOIN
,
}
)
if
not
name
:
old_state_1
=
[
if
state_key
is
not
None
:
member_event
,
name
=
"
<%s-%s>
"
%
(
type
,
state_key
)
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
,
depth
=
1
),
else
:
]
name
=
"
<%s>
"
%
(
type
,
)
old_state_2
=
[
member_event
,
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
,
depth
=
2
),
]
event
=
Mock
(
name
=
name
,
spec
=
[])
context
=
yield
self
.
_get_context
(
event
,
old_state_1
,
old_state_2
)
event
.
type
=
type
if
state_key
is
not
None
:
self
.
assertEqual
(
old_state_2
[
1
],
context
.
current_state
[(
"
test1
"
,
"
1
"
)])
event
.
state_key
=
state_key
event
.
event_id
=
event_id
# Reverse the depth to make sure we are actually using the depths
# during state resolution.
old_state_1
=
[
member_event
,
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
,
depth
=
2
),
]
old_state_2
=
[
member_event
,
create_event
(
type
=
"
test1
"
,
state_key
=
"
1
"
,
depth
=
1
),
]
context
=
yield
self
.
_get_context
(
event
,
old_state_1
,
old_state_2
)
self
.
assertEqual
(
old_state_1
[
1
],
context
.
current_state
[(
"
test1
"
,
"
1
"
)])
event
.
is_state
=
lambda
:
(
state_key
is
not
None
)
def
_get_context
(
self
,
event
,
old_state_1
,
old_state_2
):
event
.
unsigned
=
{}
group_name_1
=
"
group_name_1
"
group_name_2
=
"
group_name_2
"
event
.
user_id
=
"
@user_id:example.com
"
self
.
store
.
get_state_groups
.
return_value
=
{
event
.
room_id
=
"
!room_id:example.com
"
group_name_1
:
old_state_1
,
group_name_2
:
old_state_2
,
}
return
event
return
self
.
state
.
compute_event_context
(
event
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment