Skip to content
Snippets Groups Projects
session.py 5.02 KiB
Newer Older
  • Learn to ignore specific revisions
  • Patrick Cloke's avatar
    Patrick Cloke committed
    # This file is licensed under the Affero General Public License (AGPL) version 3.
    #
    # Copyright 2021 The Matrix.org Foundation C.I.C.
    # Copyright (C) 2023 New Vector, Ltd
    #
    # This program is free software: you can redistribute it and/or modify
    # it under the terms of the GNU Affero General Public License as
    # published by the Free Software Foundation, either version 3 of the
    # License, or (at your option) any later version.
    #
    # See the GNU Affero General Public License for more details:
    # <https://www.gnu.org/licenses/agpl-3.0.html>.
    #
    # Originally licensed under the Apache License, Version 2.0:
    # <http://www.apache.org/licenses/LICENSE-2.0>.
    #
    # [This file includes modifications made by New Vector Limited]
    
    #
    #
    from typing import TYPE_CHECKING
    
    import synapse.util.stringutils as stringutils
    from synapse.api.errors import StoreError
    from synapse.metrics.background_process_metrics import wrap_as_background_process
    from synapse.storage._base import SQLBaseStore, db_to_json
    from synapse.storage.database import (
        DatabasePool,
        LoggingDatabaseConnection,
        LoggingTransaction,
    )
    from synapse.types import JsonDict
    from synapse.util import json_encoder
    
    if TYPE_CHECKING:
        from synapse.server import HomeServer
    
    
    class SessionStore(SQLBaseStore):
        """
        A store for generic session data, persisted in the `sessions` table.

        Each type of session should provide a unique `session_type` so that
        sessions of different kinds are kept separate.

        Expired sessions are never returned by `get_session`. The expired rows
        themselves are only removed periodically (every 30 minutes) by a
        background job, which is scheduled only on instances configured to run
        background tasks.
        """

        def __init__(
            self,
            database: DatabasePool,
            db_conn: LoggingDatabaseConnection,
            hs: "HomeServer",
        ):
            super().__init__(database, db_conn, hs)

            # Schedule culling of expired sessions every 30 minutes, but only on
            # instances configured to run background tasks.
            if hs.config.worker.run_background_tasks:
                self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000)

        async def create_session(
            self, session_type: str, value: JsonDict, expiry_ms: int
        ) -> str:
            """
            Creates a new session holding `value` until it expires.

            Args:
                session_type: The type for this session.
                value: The JSON-serialisable value to store.
                expiry_ms: How long, in milliseconds from now, the session remains
                    retrievable via `get_session`. There is no default; a value of
                    0 (or less) makes the session effectively expired immediately.

            Returns:
                The newly created session ID.

            Raises:
                StoreError if a unique session ID cannot be generated.
            """
            # Neither of these depends on the generated session ID, so compute
            # them once rather than on every retry.
            encoded_value = json_encoder.encode(value)
            expiry_time_ms = self._clock.time_msec() + expiry_ms

            # Autogenerate a session ID and try to insert it. A random ID may
            # clash with an existing row (raising IntegrityError), so retry a few
            # times before giving up.
            for _ in range(5):
                session_id = stringutils.random_string(24)
                try:
                    await self.db_pool.simple_insert(
                        table="sessions",
                        values={
                            "session_id": session_id,
                            "session_type": session_type,
                            "value": encoded_value,
                            "expiry_time_ms": expiry_time_ms,
                        },
                        desc="create_session",
                    )
                except self.db_pool.engine.module.IntegrityError:
                    continue
                return session_id

            raise StoreError(500, "Couldn't generate a session ID.")

        async def get_session(self, session_type: str, session_id: str) -> JsonDict:
            """
            Retrieve data stored with `create_session`.

            Args:
                session_type: The type for this session.
                session_id: The session ID returned from `create_session`.

            Returns:
                The value stored for the session, decoded from JSON.

            Raises:
                StoreError (404) if no matching, unexpired session exists.
            """

            def _get_session(
                txn: LoggingTransaction, session_type: str, session_id: str, ts: int
            ) -> JsonDict:
                # Filter on the expiry time too, since expired rows are only
                # deleted periodically rather than at the moment they expire.
                select_sql = """
                SELECT value FROM sessions WHERE
                session_type = ? AND session_id = ? AND expiry_time_ms > ?
                """
                txn.execute(select_sql, [session_type, session_id, ts])
                row = txn.fetchone()

                if not row:
                    raise StoreError(404, "No session")

                return db_to_json(row[0])

            return await self.db_pool.runInteraction(
                "get_session",
                _get_session,
                session_type,
                session_id,
                self._clock.time_msec(),
            )

        @wrap_as_background_process("delete_expired_sessions")
        async def _delete_expired_sessions(self) -> None:
            """Remove sessions whose expiry time is at or before the current time."""

            def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None:
                # `<=` is the complement of the `>` filter in `get_session`, so
                # exactly the rows `get_session` can no longer return are deleted.
                sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?"
                txn.execute(sql, (ts,))

            await self.db_pool.runInteraction(
                "delete_expired_sessions",
                _delete_expired_sessions_txn,
                self._clock.time_msec(),
            )