Skip to content
Snippets Groups Projects
Commit 937c1750 authored by Erik Johnston's avatar Erik Johnston
Browse files

Fix up RoomMemberStore to work with the new schema.

parent 6d6a1c34
No related branches found
No related tags found
No related merge requests found
...@@ -64,7 +64,11 @@ class SQLBaseStore(object): ...@@ -64,7 +64,11 @@ class SQLBaseStore(object):
def interaction(txn): def interaction(txn):
cursor = txn.execute(query, args) cursor = txn.execute(query, args)
return decoder(cursor) if decoder:
return decoder(cursor)
else:
return cursor
return self._db_pool.runInteraction(interaction) return self._db_pool.runInteraction(interaction)
def _execut_query(self, query, *args): def _execut_query(self, query, *args):
......
...@@ -31,6 +31,38 @@ logger = logging.getLogger(__name__) ...@@ -31,6 +31,38 @@ logger = logging.getLogger(__name__)
class RoomMemberStore(SQLBaseStore): class RoomMemberStore(SQLBaseStore):
@defer.inlineCallbacks
def _store_room_member(self, event):
"""Store a room member in the database.
"""
domain = self.hs.parse_userid(event.target_user_id).domain
yield self._simple_insert(
"room_memberships",
{
"event_id": event.event_id,
"user_id": event.target_user_id,
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
}
)
# Update room hosts table
if event.membership == Membership.JOIN:
sql = (
"INSERT OR IGNORE INTO room_hosts (room_id, host) "
"VALUES (?, ?)"
)
yield self._execute(None, sql, room_id, domain)
else:
sql = (
"DELETE FROM room_hosts WHERE room_id = ? AND host = ?"
)
yield self._execute(None, sql, room_id, domain)
def get_room_member(self, user_id, room_id): def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member. """Retrieve the current state of a room member.
...@@ -38,36 +70,13 @@ class RoomMemberStore(SQLBaseStore): ...@@ -38,36 +70,13 @@ class RoomMemberStore(SQLBaseStore):
user_id (str): The member's user ID. user_id (str): The member's user ID.
room_id (str): The room the member is in. room_id (str): The room the member is in.
Returns: Returns:
namedtuple: The room member from the database, or None if this Deferred: Results in a MembershipEvent or None.
member does not exist.
""" """
query = RoomMemberTable.select_statement( return self._get_members_by_dict(
"room_id = ? AND user_id = ? ORDER BY id DESC LIMIT 1")
return self._execute(
RoomMemberTable.decode_single_result,
query, room_id, user_id,
)
def store_room_member(self, user_id, sender, room_id, membership, content):
"""Store a room member in the database.
Args:
user_id (str): The member's user ID.
room_id (str): The room in relation to the member.
membership (synapse.api.constants.Membership): The new membership
state.
content (dict): The content of the membership (JSON).
"""
content_json = json.dumps(content)
return self._simple_insert(RoomMemberTable.table_name, dict(
user_id=user_id,
sender=sender,
room_id=room_id, room_id=room_id,
membership=membership, user_id=user_id
content=content_json, )
))
@defer.inlineCallbacks
def get_room_members(self, room_id, membership=None): def get_room_members(self, room_id, membership=None):
"""Retrieve the current room member list for a room. """Retrieve the current room member list for a room.
...@@ -79,17 +88,12 @@ class RoomMemberStore(SQLBaseStore): ...@@ -79,17 +88,12 @@ class RoomMemberStore(SQLBaseStore):
Returns: Returns:
list of namedtuples representing the members in this room. list of namedtuples representing the members in this room.
""" """
query = RoomMemberTable.select_statement(
"id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name where = {"room_id": room_id}
+ " WHERE room_id = ? GROUP BY user_id)"
)
res = yield self._execute(
RoomMemberTable.decode_results, query, room_id,
)
# strip memberships which don't match
if membership: if membership:
res = [entry for entry in res if entry.membership == membership] where["membership"] = membership
defer.returnValue(res)
return self._get_members_by_dict(**membership)
def get_rooms_for_user_where_membership_is(self, user_id, membership_list): def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user """ Get all the rooms for this user where the membership for this user
...@@ -106,67 +110,37 @@ class RoomMemberStore(SQLBaseStore): ...@@ -106,67 +110,37 @@ class RoomMemberStore(SQLBaseStore):
return defer.succeed(None) return defer.succeed(None)
args = [user_id] args = [user_id]
membership_placeholder = ["membership=?"] * len(membership_list) args.extend(membership_list)
where_membership = "(" + " OR ".join(membership_placeholder) + ")"
for membership in membership_list:
args.append(membership)
query = ("SELECT room_id, membership FROM room_memberships"
+ " WHERE user_id=? AND " + where_membership
+ " GROUP BY room_id ORDER BY id DESC")
return self._execute(
self.cursor_to_dict, query, *args
)
@defer.inlineCallbacks where_clause "user_id = ? AND (%s)" % (
def get_joined_hosts_for_room(self, room_id): " OR ".join(["membership = ?" for _ in membership_list]),
query = RoomMemberTable.select_statement(
"id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name
+ " WHERE room_id = ? GROUP BY user_id)"
) )
res = yield self._execute( return self._get_members_query(where_clause, args)
RoomMemberTable.decode_results, query, room_id,
)
def host_from_user_id_string(user_id):
domain = UserID.from_string(entry.user_id, self.hs).domain
return domain
# strip memberships which don't match
hosts = [
host_from_user_id_string(entry.user_id)
for entry in res
if entry.membership == Membership.JOIN
]
logger.debug("Returning hosts: %s from results: %s", hosts, res) def get_joined_hosts_for_room(self, room_id):
return self._simple_select_onecol(
defer.returnValue(hosts) "room_hosts",
{"room_id": room_id},
def get_max_room_member_id(self): "host"
return self._simple_max_id(RoomMemberTable.table_name) )
class RoomMemberTable(Table):
table_name = "room_memberships"
fields = [
"id",
"user_id",
"sender",
"room_id",
"membership",
"content"
]
class EntryType(collections.namedtuple("RoomMemberEntry", fields)): def _get_members_by_dict(self, where_dict):
clause = " AND ".join("%s = ?" % k for k in where.keys())
vals = where.values()
return self._get_members_query(clause, vals)
def as_event(self, event_factory): @defer.inlineCallbacks
return event_factory.create_event( def _get_members_query(self, where_clause, where_values):
etype=RoomMemberEvent.TYPE, sql = (
room_id=self.room_id, "SELECT e.* FROM events as e "
target_user_id=self.user_id, "INNER JOIN room_memberships as m "
user_id=self.sender, "ON e.event_id = m.event_id "
content=json.loads(self.content), "INNER JOIN current_state as c "
) "ON m.event_id = c.event_id "
"WHERE %s "
) % (where_clause,)
rows = yield self._execute_query(sql, where_values)
results = [self._parse_event_from_row(r) for r in rows]
defer.returnValue(results)
...@@ -17,7 +17,6 @@ CREATE TABLE IF NOT EXISTS events( ...@@ -17,7 +17,6 @@ CREATE TABLE IF NOT EXISTS events(
ordering INTEGER PRIMARY KEY AUTOINCREMENT, ordering INTEGER PRIMARY KEY AUTOINCREMENT,
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
type TEXT NOT NULL, type TEXT NOT NULL,
-- sender TEXT,
room_id TEXT, room_id TEXT,
content TEXT, content TEXT,
unrecognized_keys TEXT unrecognized_keys TEXT
...@@ -57,3 +56,8 @@ CREATE TABLE IF NOT EXISTS rooms( ...@@ -57,3 +56,8 @@ CREATE TABLE IF NOT EXISTS rooms(
is_public INTEGER, is_public INTEGER,
creator TEXT creator TEXT
); );
CREATE TABLE IF NOT EXISTS room_hosts(
room_id TEXT NOT NULL,
host TEXT NOT NULL
);
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment