diff --git a/mautrix_facebook/db/message.py b/mautrix_facebook/db/message.py index c1fecbc840c12b0293ac72a8fd554b042cb5a8c4..df36a1b978534b161b9c85444668f48582c09bbd 100644 --- a/mautrix_facebook/db/message.py +++ b/mautrix_facebook/db/message.py @@ -16,11 +16,13 @@ from typing import Optional, Iterable, List from datetime import datetime -from sqlalchemy import Column, String, DateTime, SmallInteger, UniqueConstraint, and_ +from sqlalchemy import Column, String, SmallInteger, UniqueConstraint, and_ from mautrix.types import RoomID, EventID from mautrix.util.db import Base +from .types import UTCDateTime + class Message(Base): __tablename__ = "message" @@ -31,7 +33,7 @@ class Message(Base): fb_chat: str = Column(String(127), nullable=True) fb_receiver: str = Column(String(127), primary_key=True) index: int = Column(SmallInteger, primary_key=True, default=0) - date: Optional[datetime] = Column(DateTime(timezone=True), nullable=True) + date: Optional[datetime] = Column(UTCDateTime(timezone=True), nullable=True) __table_args__ = (UniqueConstraint("mxid", "mx_room", name="_mx_id_room"),) diff --git a/mautrix_facebook/db/types/__init__.py b/mautrix_facebook/db/types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f256ec9da6fa56f7907e94daa53ccb2431b3885e --- /dev/null +++ b/mautrix_facebook/db/types/__init__.py @@ -0,0 +1,30 @@ +from datetime import timezone + +import sqlalchemy.types as types + + +class UTCDateTime(types.TypeDecorator): + """Decorates the SQLAlchemy DateTime type to work with UTCĂ‚ datetimes. + + It supposes we only manipulate UTC datetime. If the timezone is not set when saving or reading + a value, the UTC timezone is set. If a timezone is set, it ensures the datetime is converted to + UTC before saving it. + This is useful when working with SQLite as the SQLalchemy DateTime type loses the timezone + information when saving a datetime on this database. + """ + impl = types.DateTime + + def process_bind_param(self, value, dialect): + if value is not None: + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + elif value.tzinfo != timezone.utc: + value = value.astimezone(timezone.utc) + + return value + + def process_result_value(self, value, dialect): + if value is not None and value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + else: + return value