Skip to content
Snippets Groups Projects
Commit c6d768b2 authored by Tulir Asokan's avatar Tulir Asokan :cat2:
Browse files

Use per-message profiles for webhooks

parent 6b706005
No related branches found
No related tags found
No related merge requests found
Pipeline #16314 passed
from typing import TYPE_CHECKING
import asyncio
from sqlalchemy import MetaData, Table, Column, Text
from sqlalchemy.engine.base import Engine
from mautrix.types import ContentURI
if TYPE_CHECKING:
from .bot import GitHubBot
class AvatarManager:
bot: 'GitHubBot'
_avatars: dict[str, ContentURI]
_table: Table
_db: Engine
_lock: asyncio.Lock
def __init__(self, bot: 'GitHubBot', metadata: MetaData) -> None:
self.bot = bot
self._db = bot.database
self._table = Table("avatar", metadata,
Column("url", Text, primary_key=True),
Column("mxc", Text, nullable=False))
self._lock = asyncio.Lock()
self._avatars = {}
def load_db(self) -> None:
self._avatars = {url: ContentURI(mxc)
for url, mxc
in self._db.execute(self._table.select())}
async def get_mxc(self, url: str) -> ContentURI:
try:
return self._avatars[url]
except KeyError:
pass
async with self.bot.http.get(url) as resp:
resp.raise_for_status()
data = await resp.read()
async with self._lock:
try:
return self._avatars[url]
except KeyError:
pass
mxc = await self.bot.client.upload_media(data)
self._avatars[url] = mxc
with self._db.begin() as conn:
conn.execute(self._table.insert().values(url=url, mxc=mxc))
return mxc
...@@ -26,6 +26,7 @@ from .client_manager import ClientManager ...@@ -26,6 +26,7 @@ from .client_manager import ClientManager
from .api import GitHubWebhookReceiver from .api import GitHubWebhookReceiver
from .commands import Commands from .commands import Commands
from .config import Config from .config import Config
from .avatar_manager import AvatarManager
class GitHubBot(Plugin): class GitHubBot(Plugin):
...@@ -33,6 +34,7 @@ class GitHubBot(Plugin): ...@@ -33,6 +34,7 @@ class GitHubBot(Plugin):
webhook_receiver: GitHubWebhookReceiver webhook_receiver: GitHubWebhookReceiver
webhook_manager: WebhookManager webhook_manager: WebhookManager
webhook_handler: WebhookHandler webhook_handler: WebhookHandler
avatars: AvatarManager
clients: ClientManager clients: ClientManager
commands: Commands commands: Commands
config: Config config: Config
...@@ -48,6 +50,7 @@ class GitHubBot(Plugin): ...@@ -48,6 +50,7 @@ class GitHubBot(Plugin):
self.webhook_manager = WebhookManager(self.config["webhook_key"], self.webhook_manager = WebhookManager(self.config["webhook_key"],
self.database, metadata) self.database, metadata)
self.webhook_handler = WebhookHandler(bot=self) self.webhook_handler = WebhookHandler(bot=self)
self.avatars = AvatarManager(bot=self, metadata=metadata)
self.webhook_receiver = GitHubWebhookReceiver(handler=self.webhook_handler, self.webhook_receiver = GitHubWebhookReceiver(handler=self.webhook_handler,
secrets=self.webhook_manager, secrets=self.webhook_manager,
global_secret=self.config["global_webhook_secret"]) global_secret=self.config["global_webhook_secret"])
......
...@@ -23,7 +23,7 @@ from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound ...@@ -23,7 +23,7 @@ from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound
from sqlalchemy.engine.base import Engine from sqlalchemy.engine.base import Engine
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from mautrix.types import UserID, EventID, RoomID from mautrix.types import UserID, EventID, RoomID, ContentURI
Base = declarative_base() Base = declarative_base()
......
...@@ -47,6 +47,7 @@ from ..api.types import ( ...@@ -47,6 +47,7 @@ from ..api.types import (
expand_enum, expand_enum,
ACTION_CLASSES, ACTION_CLASSES,
OTHER_ENUMS, OTHER_ENUMS,
User,
) )
from .manager import WebhookInfo from .manager import WebhookInfo
from .aggregation import PendingAggregation from .aggregation import PendingAggregation
...@@ -182,6 +183,18 @@ class WebhookHandler: ...@@ -182,6 +183,18 @@ class WebhookHandler:
formatted_body=html, formatted_body=html,
body=await parse_html(html.strip()), body=await parse_html(html.strip()),
) )
if hasattr(evt, "sender") and isinstance(evt.sender, User):
mxc = ""
if evt.sender.avatar_url:
try:
mxc = await self.bot.avatars.get_mxc(evt.sender.avatar_url)
except Exception:
self.log.warning("Failed to get avatar URL", exc_info=True)
content["com.beeper.per_message_profile"] = {
"id": str(evt.sender.id),
"displayname": evt.sender.login,
"avatar_url": mxc,
}
content["xyz.maubot.github.webhook"] = { content["xyz.maubot.github.webhook"] = {
"delivery_ids": list(delivery_ids), "delivery_ids": list(delivery_ids),
"event_type": str(evt_type), "event_type": str(evt_type),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment