Skip to content
Snippets Groups Projects
Commit b3691810 authored by Tulir Asokan's avatar Tulir Asokan :cat2:
Browse files

Automatically accept portal invites if Matrix puppeting is enabled

parent 500a6d48
Branches
Tags
No related merge requests found
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional from typing import Optional, Iterator
from sqlalchemy import Column, String, Enum, and_ from sqlalchemy import Column, String, Enum, and_
from sqlalchemy.engine.result import RowProxy from sqlalchemy.engine.result import RowProxy
...@@ -52,6 +52,11 @@ class Portal(Base): ...@@ -52,6 +52,11 @@ class Portal(Base):
def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']: def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
return cls._select_one_or_none(cls.c.mxid == mxid) return cls._select_one_or_none(cls.c.mxid == mxid)
@classmethod
def get_all_by_receiver(cls, fb_receiver: str) -> Iterator['Portal']:
return cls._select_all(and_(cls.c.fb_receiver == fb_receiver,
cls.c.fb_type == ThreadType.USER))
@property @property
def _edit_identity(self): def _edit_identity(self):
return and_(self.c.fbid == self.fbid, self.c.fb_receiver == self.fb_receiver) return and_(self.c.fbid == self.fbid, self.c.fb_receiver == self.fb_receiver)
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Deque, Optional, Tuple, Union, Set, TYPE_CHECKING from typing import Dict, Deque, Optional, Tuple, Union, Set, Iterator, TYPE_CHECKING
from collections import deque from collections import deque
import asyncio import asyncio
import logging import logging
...@@ -233,6 +233,9 @@ class Portal: ...@@ -233,6 +233,9 @@ class Portal:
async def _update_matrix_room(self, source: 'u.User', async def _update_matrix_room(self, source: 'u.User',
info: Optional[ThreadClass] = None) -> None: info: Optional[ThreadClass] = None) -> None:
await self.main_intent.invite_user(self.mxid, source.mxid) await self.main_intent.invite_user(self.mxid, source.mxid)
puppet = p.Puppet.get_by_custom_mxid(source.mxid)
if puppet:
await puppet.intent.ensure_joined(self.mxid)
async def create_matrix_room(self, source: 'u.User', info: Optional[ThreadClass] = None async def create_matrix_room(self, source: 'u.User', info: Optional[ThreadClass] = None
) -> RoomID: ) -> RoomID:
...@@ -266,6 +269,10 @@ class Portal: ...@@ -266,6 +269,10 @@ class Portal:
self.by_mxid[self.mxid] = self self.by_mxid[self.mxid] = self
if not self.is_direct: if not self.is_direct:
await self._update_participants(source, info) await self._update_participants(source, info)
else:
puppet = p.Puppet.get_by_custom_mxid(source.mxid)
if puppet:
await puppet.intent.ensure_joined(self.mxid)
# endregion # endregion
# region Matrix room cleanup # region Matrix room cleanup
...@@ -589,6 +596,14 @@ class Portal: ...@@ -589,6 +596,14 @@ class Portal:
return None return None
@classmethod
def get_all_by_receiver(cls, fb_receiver: str) -> Iterator['Portal']:
for db_portal in DBPortal.get_all_by_receiver(fb_receiver):
try:
yield cls.by_fbid[(db_portal.fbid, db_portal.fb_receiver)]
except KeyError:
yield cls.from_db(db_portal)
@classmethod @classmethod
def get_by_thread(cls, thread: Thread, fb_receiver: Optional[str] = None) -> 'Portal': def get_by_thread(cls, thread: Thread, fb_receiver: Optional[str] = None) -> 'Portal':
return cls.get_by_fbid(thread.uid, fb_receiver, thread.type) return cls.get_by_fbid(thread.uid, fb_receiver, thread.type)
......
...@@ -117,6 +117,13 @@ class Puppet(CustomPuppetMixin): ...@@ -117,6 +117,13 @@ class Puppet(CustomPuppetMixin):
portal = p.Portal.get_by_mxid(room_id) portal = p.Portal.get_by_mxid(room_id)
return portal and portal.fbid != self.fbid return portal and portal.fbid != self.fbid
async def _leave_rooms_with_default_user(self) -> None:
await super()._leave_rooms_with_default_user()
# Make the user join all private chat portals.
await asyncio.gather(*[self.intent.ensure_joined(portal.mxid)
for portal in p.Portal.get_all_by_receiver(self.fbid)
if portal.mxid], loop=self.loop)
def intent_for(self, portal: 'p.Portal') -> IntentAPI: def intent_for(self, portal: 'p.Portal') -> IntentAPI:
if portal.fbid == self.fbid: if portal.fbid == self.fbid:
return self.default_mxid_intent return self.default_mxid_intent
......
...@@ -24,6 +24,9 @@ class SQLStateStore(BaseSQLStateStore): ...@@ -24,6 +24,9 @@ class SQLStateStore(BaseSQLStateStore):
puppet = pu.Puppet.get_by_mxid(user_id, create=False) puppet = pu.Puppet.get_by_mxid(user_id, create=False)
if puppet: if puppet:
return puppet.is_registered return puppet.is_registered
custom_puppet = pu.Puppet.get_by_custom_mxid(user_id)
if custom_puppet:
return True
return super().is_registered(user_id) return super().is_registered(user_id)
def registered(self, user_id: UserID) -> None: def registered(self, user_id: UserID) -> None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment